Prerequisites
- Databricks workspace with a SQL warehouse running
- Create a Delta table named `customer_analytics`
- Get your Prior Labs API key
- Install the required Python packages:
Copy
Ask AI
pip install openai-agents databricks-sdk databricks-sql-connector httpx pandas scikit-learn
Overview
Here’s what happens end-to-end:
- Sample customer data is pulled from a Databricks Delta table using a SQL connector.
- The data is split into train/test sets and saved to disk locally.
- An agent - built with the OpenAI Agents SDK - takes over.
- You get a churn model evaluation summary printed to stdout.
Databricks Delta table
If you don’t have a customer_analytics table yet, you can create one from Databricks’ built-in
TPC-DS sample data. Run this in a Databricks notebook or SQL editor:
query.sql
Copy
Ask AI
-- Build a per-customer feature table from the TPC-DS sample data.
-- Each row is one customer with purchase-history features and a binary
-- churn label relative to the fixed cutoff date 2003-01-02.
CREATE OR REPLACE TABLE customer_analytics AS
SELECT
c.c_customer_sk,
-- Approximate age in years at the cutoff date.
DATEDIFF(
DATE('2003-01-02'),
MAKE_DATE(c.c_birth_year, c.c_birth_month, c.c_birth_day)
) / 365 AS customer_age,
c.c_preferred_cust_flag,
MIN(d.d_date) AS first_purchase_date,
MAX(d.d_date) AS last_purchase_date,
-- Recency: days between the most recent purchase and the cutoff.
DATEDIFF(DATE('2003-01-02'), MAX(d.d_date)) AS days_since_last_purchase,
COUNT(DISTINCT d.d_date) AS total_purchase_days,
COUNT(*) AS total_transactions,
SUM(s.ss_net_paid) AS lifetime_value,
AVG(s.ss_net_paid) AS avg_transaction_value,
-- Label: a customer is "churned" if their last purchase was more than
-- 180 days before the cutoff.
CASE
WHEN DATEDIFF(DATE('2003-01-02'), MAX(d.d_date)) > 180 THEN 1
ELSE 0
END AS is_churned
FROM samples.tpcds_sf1.customer c
JOIN samples.tpcds_sf1.store_sales s ON c.c_customer_sk = s.ss_customer_sk
JOIN samples.tpcds_sf1.date_dim d ON s.ss_sold_date_sk = d.d_date_sk
WHERE d.d_year BETWEEN 1998 AND 2003
AND s.ss_customer_sk IS NOT NULL
GROUP BY
c.c_customer_sk, c.c_preferred_cust_flag,
c.c_birth_year, c.c_birth_month, c.c_birth_day
Example
The following code block contains all the necessary code to run the OpenAI Agent with Databricks and the TabPFN MCP server.
example.py
Copy
Ask AI
import asyncio
import json
import os
import tempfile
import httpx
import pandas as pd
from databricks import sql as dbsql
from databricks.sdk import WorkspaceClient
from sklearn.metrics import (
accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
from sklearn.model_selection import train_test_split
from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from agents.mcp import MCPServerStreamableHttp
# Enable verbose stdout logging
enable_verbose_stdout_logging()
# ---------------------------------------------------------------------------
# Configuration - edit these to match your environment
# ---------------------------------------------------------------------------
# TabPFN API key from Prior Labs; KeyError at startup if unset.
PRIOR_LABS_API_KEY = os.environ["PRIORLABS_API_KEY"]
DATABRICKS_PROFILE = "DEFAULT" # Databricks CLI profile (~/.databrickscfg)
# Fully qualified Delta table name, e.g. "catalog.schema.customer_analytics".
TABLE = os.environ["DATABRICKS_TABLE"]
# Binary churn label column produced by the SQL shown above.
TARGET_COLUMN = "is_churned"
# Identifier and raw-date columns excluded from the feature set.
DROP_COLS = ["c_customer_sk", "first_purchase_date", "last_purchase_date"]
LIMIT = 1_000 # rows to fetch for training context
# Scratch directory for the CSV splits and predictions.
TEMP_DIR = tempfile.mkdtemp()
# ---------------------------------------------------------------------------
# Step 1 - Fetch data from Databricks
# ---------------------------------------------------------------------------
def _get_warehouse_http_path() -> str:
    """Return the HTTP path of the first SQL warehouse listed in the workspace.

    To target a specific warehouse instead, replace this with a hardcoded path:
    return "/sql/1.0/warehouses/<your-warehouse-id>"

    Raises:
        RuntimeError: if the workspace exposes no SQL warehouses.
    """
    client = WorkspaceClient(profile=DATABRICKS_PROFILE)
    available = list(client.warehouses.list())
    if not available:
        raise RuntimeError("No SQL warehouses found in your Databricks workspace.")
    first = available[0]
    return f"/sql/1.0/warehouses/{first.id}"
def fetch_data() -> pd.DataFrame:
    """Pull up to LIMIT rows from the customer_analytics Delta table.

    Credentials come from the Databricks CLI profile; results arrive as
    Arrow and are converted to a pandas DataFrame.
    """
    workspace = WorkspaceClient(profile=DATABRICKS_PROFILE)
    hostname = workspace.config.host.replace("https://", "").rstrip("/")
    connection = dbsql.connect(
        server_hostname=hostname,
        http_path=_get_warehouse_http_path(),
        access_token=workspace.config.token,
    )
    with connection as conn:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {TABLE} LIMIT {LIMIT}")
            return cursor.fetchall_arrow().to_pandas()
# ---------------------------------------------------------------------------
# Step 2 - Preprocess and split
# ---------------------------------------------------------------------------
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Remove non-feature columns and rows that are entirely null.

    DROP_COLS holds identifiers and date columns that would either leak
    information or add noise - adjust it to match your table schema.
    """
    present = [col for col in DROP_COLS if col in df.columns]
    trimmed = df.drop(columns=present)
    return trimmed.dropna(how="all")
def split(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Partition the dataset into an 80/20 train/test split.

    Returns:
        train_df: training rows including the target column.
        test_features: test rows with the target column removed (what
            TabPFN predicts on).
        test_labels: held-out ground-truth target values for evaluation.

    A time-based split would avoid leakage in production; a random 80/20
    split is sufficient for demonstration purposes.
    """
    train_part, test_part = train_test_split(df, test_size=0.2, random_state=42)
    train_df = train_part.reset_index(drop=True)
    # Withhold the label from the test features so the agent must predict it.
    test_features = test_part.drop(columns=[TARGET_COLUMN]).reset_index(drop=True)
    test_labels = test_part[[TARGET_COLUMN]].reset_index(drop=True)
    return train_df, test_features, test_labels
def save_splits(
    train_df: pd.DataFrame,
    test_features: pd.DataFrame,
    test_labels: pd.DataFrame,
) -> tuple[str, str, str]:
    """Persist the three splits as CSV files in TEMP_DIR and return their paths.

    The agent uploads these files to TabPFN via presigned URLs, so it needs
    real local paths; the temp directory keeps the workspace tidy and is
    removed when the process exits.
    """
    paths = (
        os.path.join(TEMP_DIR, "train.csv"),
        os.path.join(TEMP_DIR, "test.csv"),
        os.path.join(TEMP_DIR, "test_labels.csv"),
    )
    for frame, destination in zip((train_df, test_features, test_labels), paths):
        frame.to_csv(destination, index=False)
    return paths
# ---------------------------------------------------------------------------
# Step 3 - Agent tools
#
# The agent has three local tools alongside the TabPFN MCP tools:
# - put_file_to_signed_url: uploads a CSV to the presigned URL TabPFN provides
# - save_predictions: persists the raw TabPFN output to disk
# - compute_metrics: evaluates churn-specific classification performance
# ---------------------------------------------------------------------------
@function_tool
async def put_file_to_signed_url(upload_url: str, filepath: str) -> str:
    """PUT a local CSV to the presigned URL from TabPFN's upload_dataset tool.

    upload_dataset returns a short-lived presigned cloud-storage URL; this
    tool performs the actual HTTP PUT so the agent never handles raw bytes.

    Returns "ok" on success, or raises on an HTTP error status.
    """
    with open(filepath, "rb") as source:
        payload = source.read()
    async with httpx.AsyncClient() as client:
        response = await client.put(upload_url, content=payload)
    response.raise_for_status()
    return "ok"
@function_tool
def save_predictions(predictions_json: str) -> str:
    """Write the raw JSON output of fit_and_predict_from_dataset to a CSV.

    TabPFN returns churn probabilities as a JSON array of
    [p_active, p_churned] pairs. The parsed output is written to
    predictions.csv with columns class_0 and class_1, where class_1 is the
    probability of churn.

    Returns `predictions_path=<path>` on success, or an ERROR string on failure.
    """
    try:
        parsed = json.loads(predictions_json)
        first = parsed[0]
        # Probability rows become class_0..class_{k-1} columns; anything
        # else is stored in a single "prediction" column.
        if isinstance(first, list) and len(first) > 1:
            columns = [f"class_{i}" for i in range(len(first))]
            frame = pd.DataFrame(parsed, columns=columns)
        else:
            frame = pd.DataFrame({"prediction": parsed})
        destination = os.path.join(TEMP_DIR, "predictions.csv")
        frame.to_csv(destination, index=False)
        return f"predictions_path={destination}"
    except Exception as e:
        return f"ERROR saving predictions: {type(e).__name__}: {e}"
@function_tool
def compute_metrics(predictions_path: str, labels_path: str) -> str:
    """Evaluate churn predictions and return the metrics as a JSON string.

    Loads probability outputs (columns class_0, class_1) from
    predictions_path and ground truth from labels_path, then computes:
        roc_auc - how well the model ranks churners vs. actives across
            all thresholds
        accuracy - overall fraction of correctly classified customers
        f1_churn - F1 for the churned class; more informative than accuracy
            when churners are a small fraction of the customer base
        precision_churn - of flagged churners, how many actually churned
        recall_churn - of actual churners, how many were identified

    Returns a JSON metrics object, or an ERROR string if files can't be
    loaded or metrics can't be computed.
    """
    try:
        probabilities = pd.read_csv(predictions_path).values
        labels = pd.read_csv(labels_path).iloc[:, 0].values
    except Exception as e:
        return f"ERROR loading files: {type(e).__name__}: {e}"
    try:
        churn_scores = probabilities[:, 1]
        predicted = probabilities.argmax(axis=1)
        metrics = {
            "roc_auc": round(float(roc_auc_score(labels, churn_scores)), 4),
            "accuracy": round(float(accuracy_score(labels, predicted)), 4),
            "f1_churn": round(float(f1_score(labels, predicted, pos_label=1)), 4),
            "precision_churn": round(float(precision_score(labels, predicted, pos_label=1)), 4),
            "recall_churn": round(float(recall_score(labels, predicted, pos_label=1)), 4),
        }
    except Exception as e:
        return f"ERROR computing metrics: {type(e).__name__}: {e}"
    return json.dumps(metrics, indent=2)
# ---------------------------------------------------------------------------
# Step 4 - Build the agent
# ---------------------------------------------------------------------------
# Prompt template for the agent. {target_column} is filled in by build_agent;
# the literal file paths (train_path, test_path, test_labels_path) are
# supplied in the user message at run time.
INSTRUCTIONS = """You are a churn prediction assistant. The training and test CSV files
have already been prepared locally. Execute the following steps without pausing:
STEP 1 - Upload training data.
Call `upload_dataset` with filename="train.csv" → get upload_url and train_dataset_id.
Call `put_file_to_signed_url(upload_url=<url>, filepath=<train_path>)`.
STEP 2 - Upload test data.
Call `upload_dataset` with filename="test.csv" → get upload_url and test_dataset_id.
Call `put_file_to_signed_url(upload_url=<url>, filepath=<test_path>)`.
STEP 3 - Run prediction.
Call `fit_and_predict_from_dataset` with:
- train_dataset_id: from STEP 1
- test_dataset_id: from STEP 2
- target_column: "{target_column}"
- task_type: "classification"
- output_type: "probas"
STEP 4 - Save predictions.
Call `save_predictions(predictions_json=<raw JSON string from STEP 3>)`.
Parse the result to extract predictions_path.
STEP 5 - Evaluate.
Call `compute_metrics(predictions_path=<path from STEP 4>, labels_path=<test_labels_path>)`.
STEP 6 - Report.
Return a concise churn model evaluation summary using the metrics from STEP 5.
"""
def build_agent(tabpfn_server: MCPServerStreamableHttp) -> Agent:
    """Assemble the churn prediction agent.

    Combines the two TabPFN MCP tools (upload_dataset and
    fit_and_predict_from_dataset) with three local function tools for file
    upload, saving predictions, and computing metrics.
    """
    prompt = INSTRUCTIONS.format(target_column=TARGET_COLUMN)
    local_tools = [put_file_to_signed_url, save_predictions, compute_metrics]
    return Agent(
        name="Churn Predictor",
        instructions=prompt,
        mcp_servers=[tabpfn_server],
        tools=local_tools,
    )
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
async def main() -> None:
    """Orchestrate the full pipeline: fetch → preprocess → split → agent → print results."""
    # Fetch the data using the Databricks SQL connector.
    df = fetch_data()
    # Drop columns that shouldn't be used as features and remove all-null rows.
    df = preprocess(df)
    # Split the data into train and test sets and save them to disk.
    train_df, test_features, test_labels = split(df)
    train_path, test_path, labels_path = save_splits(train_df, test_features, test_labels)
    print(f"Data ready - Train: {len(train_df)} rows | Test: {len(test_features)} rows")
    # The TabPFN MCP server. The tool_filter limits the agent to only the two tools it needs.
    async with MCPServerStreamableHttp(
        name="TabPFN",
        params={
            "url": "https://api.priorlabs.ai/mcp/server",
            "headers": {"Authorization": f"Bearer {PRIOR_LABS_API_KEY}"},
            "timeout": 60,
            "sse_read_timeout": 300,
        },
        client_session_timeout_seconds=120,
        require_approval="never",
        tool_filter={"allowed_tool_names": ["upload_dataset", "fit_and_predict_from_dataset"]},
    ) as tabpfn_server:
        result = await Runner.run(
            build_agent(tabpfn_server),
            (
                # The instructions for the agent; change this to your own instructions.
                "Which customers are most at risk of churning, and how well does the model perform?\n\n"
                f"train_path={train_path}\n"
                f"test_path={test_path}\n"
                f"test_labels_path={labels_path}"
            ),
        )
        print(result.final_output)


if __name__ == "__main__":
    # Guard the entry point so importing this module doesn't start the pipeline.
    asyncio.run(main())
Running the script
Set the required environment variables and run the script:
Copy
Ask AI
export PRIORLABS_API_KEY="your-prior-labs-api-key"
export DATABRICKS_TABLE="catalog.schema.customer_analytics"
python example.py