This tutorial shows you how to connect your Databricks Delta table to TabPFN’s MCP server and run a churn prediction pipeline with an AI agent.

Prerequisites

  • A Databricks workspace with a running SQL warehouse
  • A Delta table named customer_analytics (instructions for creating one from sample data are below)
  • A Prior Labs API key
  • The required Python packages installed:
pip install openai-agents databricks-sdk databricks-sql-connector httpx pandas scikit-learn

Overview

Here’s what happens end-to-end:
  1. Sample customer data is pulled from a Databricks Delta table using a SQL connector.
  2. The data is split into train/test sets and saved to disk locally.
  3. An agent built with the OpenAI Agents SDK takes over: it uploads the splits to TabPFN's MCP server, runs fit-and-predict, and evaluates the predictions.
  4. You get a churn model evaluation summary printed to stdout.
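Before the agent is involved, the local data preparation in steps 1–3 is plain pandas and scikit-learn. A minimal, self-contained sketch of the split-and-save logic, using toy data and illustrative column names rather than the real Delta table:

```python
import os
import tempfile

import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for the Delta table pull (step 1); column names are illustrative.
df = pd.DataFrame({
    "customer_age":   [34, 51, 27, 45, 62, 38, 29, 55, 41, 33],
    "lifetime_value": [120.0, 980.5, 45.2, 300.0, 1500.0, 210.0, 75.0, 640.0, 410.0, 95.0],
    "is_churned":     [0, 1, 0, 0, 1, 0, 1, 1, 0, 0],
})

# Step 2: 80/20 split, then drop the target from the test features so the
# model has to predict it. The files are what the agent later uploads.
train, test = train_test_split(df, test_size=0.2, random_state=42)
test_features = test.drop(columns=["is_churned"])
test_labels = test[["is_churned"]]

tmp = tempfile.mkdtemp()
train.to_csv(os.path.join(tmp, "train.csv"), index=False)
test_features.to_csv(os.path.join(tmp, "test.csv"), index=False)
test_labels.to_csv(os.path.join(tmp, "test_labels.csv"), index=False)

print(len(train), len(test_features))  # 8 2
```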

Databricks Delta table

If you don’t have a customer_analytics table yet, you can create one from Databricks’ built-in TPC-DS sample data. Run this in a Databricks notebook or SQL editor:
query.sql
CREATE OR REPLACE TABLE customer_analytics AS
SELECT
    c.c_customer_sk,
    DATEDIFF(
        DATE('2003-01-02'),
        MAKE_DATE(c.c_birth_year, c.c_birth_month, c.c_birth_day)
    ) / 365 AS customer_age,
    c.c_preferred_cust_flag,
    MIN(d.d_date) AS first_purchase_date,
    MAX(d.d_date) AS last_purchase_date,
    DATEDIFF(DATE('2003-01-02'), MAX(d.d_date)) AS days_since_last_purchase,
    COUNT(DISTINCT d.d_date) AS total_purchase_days,
    COUNT(*) AS total_transactions,
    SUM(s.ss_net_paid) AS lifetime_value,
    AVG(s.ss_net_paid) AS avg_transaction_value,
    CASE
        WHEN DATEDIFF(DATE('2003-01-02'), MAX(d.d_date)) > 180 THEN 1
        ELSE 0
    END AS is_churned
FROM samples.tpcds_sf1.customer c
JOIN samples.tpcds_sf1.store_sales s ON c.c_customer_sk = s.ss_customer_sk
JOIN samples.tpcds_sf1.date_dim d    ON s.ss_sold_date_sk = d.d_date_sk
WHERE d.d_year BETWEEN 1998 AND 2003
    AND s.ss_customer_sk IS NOT NULL
GROUP BY
    c.c_customer_sk, c.c_preferred_cust_flag,
    c.c_birth_year, c.c_birth_month, c.c_birth_day
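The CASE expression above labels a customer as churned when their last purchase falls more than 180 days before the 2003-01-02 reference date. If you want to sanity-check that threshold, the same rule in pandas on toy recency values looks like this:

```python
import pandas as pd

# Toy recency values: days since last purchase as of the reference date.
df = pd.DataFrame({"days_since_last_purchase": [12, 90, 181, 400, 180]})

# Mirror of the SQL CASE: churned iff strictly more than 180 days inactive.
# Note the boundary: exactly 180 days still counts as active.
df["is_churned"] = (df["days_since_last_purchase"] > 180).astype(int)

print(df["is_churned"].tolist())  # [0, 0, 1, 1, 0]
```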

Example

The following script contains all the code needed to run the OpenAI agent against Databricks and the TabPFN MCP server.
example.py
import asyncio
import json
import os
import tempfile

import httpx
import pandas as pd
from databricks import sql as dbsql
from databricks.sdk import WorkspaceClient
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split

from agents import Agent, Runner, function_tool, enable_verbose_stdout_logging
from agents.mcp import MCPServerStreamableHttp


# Enable verbose stdout logging
enable_verbose_stdout_logging()

# ---------------------------------------------------------------------------
# Configuration - edit these to match your environment
# ---------------------------------------------------------------------------

PRIOR_LABS_API_KEY = os.environ["PRIORLABS_API_KEY"]
DATABRICKS_PROFILE = "DEFAULT"   # Databricks CLI profile (~/.databrickscfg)
TABLE              = os.environ["DATABRICKS_TABLE"]
TARGET_COLUMN      = "is_churned"
DROP_COLS          = ["c_customer_sk", "first_purchase_date", "last_purchase_date"]
LIMIT              = 1_000       # rows to fetch for training context
TEMP_DIR           = tempfile.mkdtemp()

# ---------------------------------------------------------------------------
# Step 1 - Fetch data from Databricks
# ---------------------------------------------------------------------------

def _get_warehouse_http_path() -> str:
    """Return the HTTP path of the first available SQL warehouse in your workspace.

    If you want to target a specific warehouse, replace this with a hardcoded path:
        return "/sql/1.0/warehouses/<your-warehouse-id>"
    """
    ws = WorkspaceClient(profile=DATABRICKS_PROFILE)
    warehouses = list(ws.warehouses.list())
    if not warehouses:
        raise RuntimeError("No SQL warehouses found in your Databricks workspace.")
    return f"/sql/1.0/warehouses/{warehouses[0].id}"


def fetch_data() -> pd.DataFrame:
    """Query the customer_analytics Delta table and return the result as a DataFrame.

    Credentials are read from your Databricks CLI profile.
    """
    ws   = WorkspaceClient(profile=DATABRICKS_PROFILE)
    host = ws.config.host.replace("https://", "").rstrip("/")

    with dbsql.connect(
        server_hostname=host,
        http_path=_get_warehouse_http_path(),
        access_token=ws.config.token,
    ) as conn:
        with conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {TABLE} LIMIT {LIMIT}")
            return cursor.fetchall_arrow().to_pandas()

# ---------------------------------------------------------------------------
# Step 2 - Preprocess and split
# ---------------------------------------------------------------------------

def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Drop columns that shouldn't be used as features and remove all-null rows.

    DROP_COLS contains identifiers and date columns that would either leak information
    or add noise - adjust it to match your table schema.
    """
    return (
        df.drop(columns=[c for c in DROP_COLS if c in df.columns])
          .dropna(how="all")
    )


def split(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split the dataset 80/20 into train and test sets.

    Returns three DataFrames:
      - train_df:      full training set including the target column
      - test_features: test set with the target column removed (what TabPFN predicts on)
      - test_labels:   held-out ground truth for evaluation

    In production you'd typically do a time-based split to avoid data leakage, but an
    80/20 random split works fine for demonstration purposes.
    """
    # Split the data into train and test sets
    train, test   = train_test_split(df, test_size=0.2, random_state=42)
    train_df      = train.reset_index(drop=True)

    # Drop the target column from the test set so that the agent can predict on it
    test_features = test.drop(columns=[TARGET_COLUMN]).reset_index(drop=True)
    test_labels   = test[[TARGET_COLUMN]].reset_index(drop=True)
    return train_df, test_features, test_labels


def save_splits(
    train_df: pd.DataFrame,
    test_features: pd.DataFrame,
    test_labels: pd.DataFrame,
) -> tuple[str, str, str]:
    """Write the three splits to CSV files in a temp directory and return their paths.

    The agent needs local file paths to upload the data to TabPFN via presigned URLs.
    A temp directory keeps things clean - it's automatically removed when the process exits.
    """
    train_path  = os.path.join(TEMP_DIR, "train.csv")
    test_path   = os.path.join(TEMP_DIR, "test.csv")
    labels_path = os.path.join(TEMP_DIR, "test_labels.csv")
    train_df.to_csv(train_path, index=False)
    test_features.to_csv(test_path, index=False)
    test_labels.to_csv(labels_path, index=False)
    return train_path, test_path, labels_path

# ---------------------------------------------------------------------------
# Step 3 - Agent tools
#
# The agent has three local tools alongside the TabPFN MCP tools:
#   - put_file_to_signed_url: uploads a CSV to the presigned URL TabPFN provides
#   - save_predictions:       persists the raw TabPFN output to disk
#   - compute_metrics:        evaluates churn-specific classification performance
# ---------------------------------------------------------------------------

@function_tool
async def put_file_to_signed_url(upload_url: str, filepath: str) -> str:
    """Upload a local CSV file to the presigned URL returned by TabPFN's upload_dataset tool.

    TabPFN's upload_dataset MCP tool returns a short-lived presigned URL pointing to cloud
    storage. This tool does the actual HTTP PUT so the agent doesn't have to handle raw bytes.
    Returns "ok" on success, or raises on HTTP error.
    """
    with open(filepath, "rb") as f:
        async with httpx.AsyncClient() as client:
            resp = await client.put(upload_url, content=f.read())
            resp.raise_for_status()
    return "ok"


@function_tool
def save_predictions(predictions_json: str) -> str:
    """Persist the raw JSON output from fit_and_predict_from_dataset to a CSV file.

    TabPFN returns churn probabilities as a JSON array of [p_active, p_churned] pairs.
    This tool parses that output and writes it to predictions.csv with columns class_0
    and class_1, where class_1 is the probability of churn.

    Returns `predictions_path=<path>` on success, or an ERROR string on failure.
    """
    try:
        predictions = json.loads(predictions_json)

        # If the predictions are a list of probabilities, create a DataFrame with the columns
        # class_0, class_1, ..., otherwise create a DataFrame with a single column prediction
        n_classes = len(predictions[0]) if isinstance(predictions[0], list) else 1
        if n_classes > 1:
            df = pd.DataFrame(predictions, columns=[f"class_{i}" for i in range(n_classes)])
        else:
            df = pd.DataFrame({"prediction": predictions})

        # Save the predictions to a CSV file
        path = os.path.join(TEMP_DIR, "predictions.csv")
        df.to_csv(path, index=False)
        return f"predictions_path={path}"
    except Exception as e:
        return f"ERROR saving predictions: {type(e).__name__}: {e}"


@function_tool
def compute_metrics(predictions_path: str, labels_path: str) -> str:
    """Compute churn-specific evaluation metrics and return them as a JSON string.

    Reads the probability outputs from predictions_path (columns class_0, class_1) and the
    ground truth from labels_path, then computes:

      roc_auc         - how well the model ranks churners vs. actives across all thresholds
      accuracy        - overall fraction of correctly classified customers
      f1_churn        - F1 score for the churned class specifically; more informative than
                        accuracy when churners are a small fraction of your customer base
      precision_churn - of customers the model flags as churners, how many actually churned
      recall_churn    - of actual churners, how many the model successfully identified

    Returns a JSON metrics object, or an ERROR string if files can't be loaded.
    """
    try:
        probas = pd.read_csv(predictions_path).values
        y_true = pd.read_csv(labels_path).iloc[:, 0].values
    except Exception as e:
        return f"ERROR loading files: {type(e).__name__}: {e}"

    try:
        y_score = probas[:, 1]
        y_pred  = probas.argmax(axis=1)
        metrics = {
            "roc_auc":         round(float(roc_auc_score(y_true, y_score)), 4),
            "accuracy":        round(float(accuracy_score(y_true, y_pred)), 4),
            "f1_churn":        round(float(f1_score(y_true, y_pred, pos_label=1)), 4),
            "precision_churn": round(float(precision_score(y_true, y_pred, pos_label=1)), 4),
            "recall_churn":    round(float(recall_score(y_true, y_pred, pos_label=1)), 4),
        }
    except Exception as e:
        return f"ERROR computing metrics: {type(e).__name__}: {e}"

    return json.dumps(metrics, indent=2)

# ---------------------------------------------------------------------------
# Step 4 - Build the agent
# ---------------------------------------------------------------------------

INSTRUCTIONS = """You are a churn prediction assistant. The training and test CSV files
have already been prepared locally. Execute the following steps without pausing:

STEP 1 - Upload training data.
Call `upload_dataset` with filename="train.csv" → get upload_url and train_dataset_id.
Call `put_file_to_signed_url(upload_url=<url>, filepath=<train_path>)`.

STEP 2 - Upload test data.
Call `upload_dataset` with filename="test.csv" → get upload_url and test_dataset_id.
Call `put_file_to_signed_url(upload_url=<url>, filepath=<test_path>)`.

STEP 3 - Run prediction.
Call `fit_and_predict_from_dataset` with:
  - train_dataset_id: from STEP 1
  - test_dataset_id:  from STEP 2
  - target_column:    "{target_column}"
  - task_type:        "classification"
  - output_type:      "probas"

STEP 4 - Save predictions.
Call `save_predictions(predictions_json=<raw JSON string from STEP 3>)`.
Parse the result to extract predictions_path.

STEP 5 - Evaluate.
Call `compute_metrics(predictions_path=<path from STEP 4>, labels_path=<test_labels_path>)`.

STEP 6 - Report.
Return a concise churn model evaluation summary using the metrics from STEP 5.
"""


def build_agent(tabpfn_server: MCPServerStreamableHttp) -> Agent:
    """Construct the churn prediction agent.

    The agent has access to two TabPFN MCP tools (upload_dataset and
    fit_and_predict_from_dataset) plus three local function tools for file upload,
    saving predictions, and computing metrics.
    """
    return Agent(
        name="Churn Predictor",
        instructions=INSTRUCTIONS.format(target_column=TARGET_COLUMN),
        mcp_servers=[tabpfn_server],
        tools=[put_file_to_signed_url, save_predictions, compute_metrics],
    )

# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

async def main() -> None:
    """Orchestrate the full pipeline: fetch → preprocess → split → agent → print results."""
    # Fetch the data using the Databricks SQL connector
    df = fetch_data()

    # Drop columns that shouldn't be used as features and remove all-null rows
    df = preprocess(df)

    # Split the data into train and test sets and save them to disk
    train_df, test_features, test_labels = split(df)
    train_path, test_path, labels_path = save_splits(train_df, test_features, test_labels)

    print(f"Data ready - Train: {len(train_df)} rows | Test: {len(test_features)} rows")

    # The TabPFN MCP server. The tool_filter limits the agent to only the two tools it needs.
    async with MCPServerStreamableHttp(
        name="TabPFN",
        params={
            "url": "https://api.priorlabs.ai/mcp/server",
            "headers": {"Authorization": f"Bearer {PRIOR_LABS_API_KEY}"},
            "timeout": 60,
            "sse_read_timeout": 300,
        },
        client_session_timeout_seconds=120,
        tool_filter={"allowed_tool_names": ["upload_dataset", "fit_and_predict_from_dataset"]},
    ) as tabpfn_server:

        result = await Runner.run(
            build_agent(tabpfn_server),
            (
                # The instructions for the agent; change this to your own instructions
                "Which customers are most at risk of churning, and how well does the model perform?\n\n"
                f"train_path={train_path}\n"
                f"test_path={test_path}\n"
                f"test_labels_path={labels_path}"
            )
        )

    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
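The split helper's docstring notes that production pipelines usually prefer a time-based split over a random one. A minimal sketch of that variant is below; the helper name `time_based_split` is hypothetical, and it assumes a recency column like days_since_last_purchase is still available (the script drops the raw date columns):

```python
import pandas as pd

def time_based_split(df: pd.DataFrame, recency_col: str, test_frac: float = 0.2):
    """Hypothetical helper: hold out the most recently active rows as the test set.

    Larger recency means older activity, so sorting descending puts the most
    recent customers at the tail, which becomes the held-out set.
    """
    ordered = df.sort_values(recency_col, ascending=False).reset_index(drop=True)
    cutoff = int(len(ordered) * (1 - test_frac))
    return ordered.iloc[:cutoff], ordered.iloc[cutoff:]

# Toy data: the two customers with the smallest recency (most recent activity)
# end up in the test set.
toy = pd.DataFrame({
    "days_since_last_purchase": [400, 10, 250, 5, 90, 30, 700, 60, 15, 120],
    "is_churned":               [1,   0,  1,   0, 0,  0,  1,   0,  0,  0],
})
train, test = time_based_split(toy, "days_since_last_purchase")
print(len(train), len(test))                       # 8 2
print(sorted(test["days_since_last_purchase"]))    # [5, 10]
```

This way the model is always evaluated on customers whose behavior is newest, which is closer to how it would be used in production.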

Running the script

Set the required environment variables and run the script:
export PRIORLABS_API_KEY="your-prior-labs-api-key"
export DATABRICKS_TABLE="catalog.schema.customer_analytics"

python example.py
On success, the agent prints a churn model evaluation summary with ROC-AUC, accuracy, F1, precision, and recall for the churned class.
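The shape of that summary can be reproduced offline. A toy example of how the [p_active, p_churned] probability pairs map onto the metrics the compute_metrics tool reports (the probabilities below are made up for illustration, not real model output):

```python
import json

import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

# Made-up probability pairs in the same [p_active, p_churned] layout as
# predictions.csv (columns class_0, class_1).
probas = pd.DataFrame(
    [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6], [0.3, 0.7], [0.8, 0.2]],
    columns=["class_0", "class_1"],
)
y_true = [0, 1, 0, 1, 0]

y_score = probas["class_1"].values     # churn probability drives ROC-AUC
y_pred = probas.values.argmax(axis=1)  # argmax over each pair gives the label

metrics = {
    "roc_auc":  round(float(roc_auc_score(y_true, y_score)), 4),
    "accuracy": round(float(accuracy_score(y_true, y_pred)), 4),
    "f1_churn": round(float(f1_score(y_true, y_pred, pos_label=1)), 4),
}
print(json.dumps(metrics))  # {"roc_auc": 1.0, "accuracy": 0.8, "f1_churn": 0.8}
```

Note how the two can disagree: the third customer is misclassified at the default 0.5 threshold (hurting accuracy and F1), yet ROC-AUC is still perfect because every true churner is ranked above every active customer.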