# TabPFN with MLflow

> Learn how to wrap TabPFN as an MLflow PythonModel, register it to Unity Catalog, and deploy it to a Mosaic AI Model serving endpoint.

## Overview

[TabPFN](https://github.com/priorlabs/tabpfn) is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable.

The promise: **one registered model** that handles both classification and regression across any tabular dataset.

<CardGroup cols={2}>
  <Card title="Less Operational Complexity" icon="cubes">
    Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome.
  </Card>

  <Card title="Higher Productivity" icon="bolt">
    Use a single model for classification and regression across any tabular dataset - from a notebook or SQL `ai_query()`.
  </Card>
</CardGroup>

By the end of this tutorial you will have:

* Wrapped TabPFN as an MLflow `PythonModel`
* Registered it to Unity Catalog with a `champion` alias
* Tested it locally on classification and regression tasks
* Run an end-to-end example on the Lending Club loan dataset
* Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint

<Note>
  For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the [TabPFN GitHub repository](https://github.com/priorlabs/tabpfn). For broader context on TabPFN and Databricks, read the [Databricks blog post](https://www.databricks.com/blog/tappfn-ai-accelerates-business-transformation-databricks).
</Note>

***

## Prerequisites

* A Databricks workspace with Unity Catalog enabled
* Access to a **serverless GPU environment** with at least an **NVIDIA A10G** GPU
* A TabPFN token from [Prior Labs](https://github.com/priorlabs/tabpfn?tab=readme-ov-file#installation--setup)

<Tip>
  TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select **Serverless** with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically.
</Tip>

***

## Step 1: Install Dependencies

Run the following in your Databricks notebook cell:

```python theme={null}
%pip install mlflow tabpfn
dbutils.library.restartPython()
```

***

## Step 2: Configure Your TabPFN Token

TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook.

Run these commands in the **Databricks CLI**:

```bash theme={null}
databricks secrets create-scope tabpfn
databricks secrets put-secret tabpfn tabpfn_token --string-value "<your-token>"
```

Then reference the secret in your notebook:

```python theme={null}
import os

os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token")
```
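A quick sanity check can confirm that the secret reached the environment without echoing its value. The sketch below is only an illustration; the helper name `require_env` is hypothetical and not part of TabPFN, MLflow, or Databricks.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, failing fast if it is unset or blank."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Example: confirm the token is present without printing its value
# token = require_env("TABPFN_TOKEN")
# print(f"TABPFN_TOKEN is set ({len(token)} characters)")
```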

***

## Step 3: Define the Wrapper and Signature

The `TabPFNWrapper` is a custom `mlflow.pyfunc.PythonModel`. It is the heart of this integration and handles three concerns:

* **Dual-format input** - accepts both raw Python objects (from notebooks) and JSON strings (from SQL `ai_query()`)
* **Task routing** - classification or regression, controlled by `task_config`
* **Flexible output** - class labels, probabilities, or regression predictions based on `output_type`

All input columns use `DataType.string` so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The `_maybe_parse_json()` helper transparently handles both formats at predict time.

```python theme={null}
import os
import json
import inspect

import mlflow
import numpy as np
import pandas as pd

from typing import Literal

from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, DataType, Schema
from tabpfn import TabPFNClassifier, TabPFNRegressor


class TabPFNWrapper(mlflow.pyfunc.PythonModel):
    """MLflow PythonModel wrapper for TabPFN"""

    _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
    """The model output types for classification."""

    _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"}
    """The model output types for regression."""

    @staticmethod
    def _maybe_parse_json(value: str | dict | list) -> dict | list:
        """Helper function to parse a JSON string if needed.

        Args:
            value: The value to parse. Can be a string, dict, or list.

        Returns:
            The parsed value.
        """
        if isinstance(value, str):
            return json.loads(value)
        return value

    def _get_output_type(
        self,
        task: Literal["classification", "regression"],
        output_type: str | None
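    ) -> str:
        """Get the prediction output type.

        Args:
            task: The task, either "classification" or "regression".
            output_type: The requested output type, or None to use the task default.

        Returns:
            The prediction output type.

        Raises:
            ValueError: If output_type is not supported for the given task.
        """
        if output_type is not None:
            # Validate against the output types of the requested task only
            supported_output_types = (
                self._CLASSIFICATION_OUTPUT_TYPES
                if task == "classification"
                else self._REGRESSION_OUTPUT_TYPES
            )
            if output_type not in supported_output_types:
                raise ValueError(
                    f"Unknown output_type for {task}: {output_type!r}. "
                    f"Must be one of {sorted(supported_output_types)}"
                )
            return output_type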

        # Fallback to the defaults
        return "preds" if task == "classification" else "mean"

    def _init_estimator(
        self,
        task: Literal["classification", "regression"],
        config: dict
    ) -> TabPFNClassifier | TabPFNRegressor:
        """Initialize a TabPFN estimator.

        Args:
            task: The task to initialize the estimator for.
            config: The configuration for the estimator.

        Returns:
            The initialized estimator.
        """
        Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor

        # Validate provided config keys against the estimator constructor
        sig = inspect.signature(Estimator.__init__)

        constructor_params = set(sig.parameters.keys()) - {"self"}
        supplied_keys = set(config.keys())
        invalid_keys = supplied_keys - constructor_params

        # Raise an error if any invalid keys are provided
        if invalid_keys:
            msg = (
                f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n"
                f"Allowed parameters: {sorted(constructor_params)}"
            )
            raise ValueError(msg)

        return Estimator(**config)

    def predict(self, model_input, params=None):
        """Run predictions.

        TabPFN runs predictions in a single forward pass. The model is
        fitted on the training data and then used to predict on the test data.

        Args:
            model_input: The input data to predict on.
            params: The parameters for the model.

        Returns:
            The predictions.
        """
        # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint)
        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.to_dict(orient="records")

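        # Each request carries exactly one row holding the full train/test payload.
        # Raise explicitly: assert statements are skipped when Python runs with -O.
        if not isinstance(model_input, list) or len(model_input) != 1:
            raise ValueError("model_input must be a list containing exactly one row")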

        model_input: dict = model_input[0]

        # Parse the task configuration
        task_config = self._maybe_parse_json(model_input["task_config"]) or {}

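        task = task_config.get("task")
        # Reject missing or misspelled tasks instead of silently treating them as regression
        if task not in ("classification", "regression"):
            raise ValueError(f"task must be 'classification' or 'regression', got {task!r}")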

        tabpfn_config = task_config.get("tabpfn_config") or {}
        predict_params = task_config.get("predict_params") or {}

        output_type = self._get_output_type(task, predict_params.get("output_type"))

        # Parse the input data
        X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64)
        y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64)
        X_test  = np.array(self._maybe_parse_json(model_input["X_test"]),  dtype=np.float64)

        # Initialize the estimator
        estimator = self._init_estimator(task, tabpfn_config)

        # Fit the estimator
        estimator.fit(X_train, y_train)

        # Run predictions
        if task == "classification":
            if output_type == "probas":
                return estimator.predict_proba(X_test).tolist()
            return estimator.predict(X_test).tolist()

        predictions = estimator.predict(X_test, output_type=output_type)
        return predictions.tolist()
```
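To see the dual-format handling in isolation, the sketch below copies the `_maybe_parse_json()` logic into a standalone function (the name `maybe_parse_json` is only for illustration) and checks that a JSON string and the equivalent Python object resolve to the same value.

```python
import json


def maybe_parse_json(value):
    """Parse JSON strings; pass dicts and lists through unchanged."""
    if isinstance(value, str):
        return json.loads(value)
    return value


raw = {"task": "classification", "predict_params": {"output_type": "probas"}}
as_json = json.dumps(raw)  # what SQL to_json() or a REST client would send

# Both representations resolve to the same Python object
assert maybe_parse_json(raw) == maybe_parse_json(as_json)
print(maybe_parse_json(as_json)["task"])  # classification
```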

### Define the Model Signature

The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are `DataType.string` to support both raw Python and JSON-serialized inputs from SQL `ai_query()`.

```python theme={null}
# All inputs are STRING for dual-format support:
#   - JSON strings from SQL ai_query() via to_json()
#   - Raw Python lists and dicts in a `pd.DataFrame`, from notebook calls

input_schema = Schema([
    ColSpec(DataType.string, "task_config"),
    ColSpec(DataType.string, "X_train"),
    ColSpec(DataType.string, "y_train"),
    ColSpec(DataType.string, "X_test"),
])

# Output schema is required by Unity Catalog.
# STRING is the best fit - the wrapper returns variable shapes
# (flat list for preds, nested list for probas) and the serving
# layer serializes whatever .tolist() produces into a JSON string.
output_schema = Schema([ColSpec(DataType.string, name="predictions")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
```

<Note>
  The `output_schema` is **required** for Unity Catalog registration. Using `DataType.string` is the right choice here because the wrapper can return either a flat list (for `preds`) or a nested list (for `probas`/regression quantiles), and the serving layer serializes the result to JSON regardless.
</Note>
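The regression branch of the wrapper calls `.tolist()` on whatever `predict()` returns, which assumes a single array. Some output types may return containers instead (for instance a list of arrays for `quantiles`, or a dictionary for `main` and `full`; treat this as an assumption and check the return types of your installed TabPFN version). A recursive converter along these lines could handle those cases. The name `to_jsonable` is hypothetical.

```python
def to_jsonable(value):
    """Recursively convert array-like values into plain Python lists."""
    if isinstance(value, dict):
        return {key: to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    # NumPy arrays and scalars expose .tolist(); rely on that instead of a type check
    if hasattr(value, "tolist"):
        return value.tolist()
    return value
```

Inside `predict()`, `return to_jsonable(predictions)` would then take the place of `predictions.tolist()`. Objects without a `.tolist()` method pass through unchanged, so they still need to be JSON-serializable on their own.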

***

## Step 4: Register to Unity Catalog

Log the model with MLflow and register it under a fully qualified Unity Catalog path (`catalog.schema.tabpfn`). The `input_example` uses `json.dumps()` to match the all-string signature - the wrapper deserializes at predict time via `_maybe_parse_json()`.

After registration, we tag the latest version with a `"champion"` alias. Step 7 resolves this alias to a version number at deployment time, so the alias acts as a single, stable pointer to the version you intend to serve.

```python theme={null}
# Input example uses json.dumps() to match the all-string signature
# The wrapper's _maybe_parse_json() deserializes these at predict time
input_example = pd.DataFrame([{
    "task_config": json.dumps({
        "task": "classification",
        "tabpfn_config": {
            "n_estimators": 8,
            "softmax_temperature": 0.9,
        },
        "predict_params": {
            "output_type": "preds",
        },
    }),
    "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
    "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
    "X_test": json.dumps([[2.0, 3.0, 0.0]]),
}])

# Fully qualified Unity Catalog path (portable across workspaces)
CATALOG = spark.catalog.currentCatalog()
SCHEMA = spark.catalog.currentDatabase()
REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn"

with mlflow.start_run(run_name="tabpfn-registration") as run:
    model_info = mlflow.pyfunc.log_model(
        name="tabpfn",
        python_model=TabPFNWrapper(),
        signature=signature,
        input_example=input_example,
        pip_requirements=["tabpfn", "numpy", "pandas"],
        registered_model_name=REGISTERED_MODEL_NAME,
    )
    print(f"Model URI: {model_info.model_uri}")
    print(f"Run ID:    {run.info.run_id}")
```

Once logged, tag the latest version as `champion`:

```python theme={null}
# Tag the latest version with a "champion" alias for serving
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}")
```

<Tip>
  Using the `champion` alias decouples your code from version numbers: anything that looks the model up by alias follows it automatically. The serving endpoint in this tutorial is pinned to the version the alias resolved to at deployment time, so after reassigning the alias, update the endpoint's served entity as well.
</Tip>
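The `int(v.version)` cast in the alias cell matters when version numbers arrive as strings: compared as text, `"9"` sorts after `"10"`. The sketch below uses stand-in objects (`SimpleNamespace`, not real MLflow model versions) to show the difference.

```python
from types import SimpleNamespace

# Stand-ins for model versions; only the .version attribute is modeled here
versions = [SimpleNamespace(version=str(n)) for n in (1, 2, 9, 10)]

latest_by_text = max(versions, key=lambda v: v.version)
latest_by_number = max(versions, key=lambda v: int(v.version))

print(latest_by_text.version)    # 9  (string comparison: "9" > "10")
print(latest_by_number.version)  # 10
```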

***

## Step 5: Test Locally

Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call `.predict()`.

<Tabs>
  <Tab title="Classification (raw Python)">
    ```python theme={null}
    # Load the registered model from MLflow
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    # DataFrame with raw Python objects, the notebook-friendly example
    predictions = loaded_model.predict(pd.DataFrame([{
        "task_config": {"task": "classification"},
        "X_train": [
            [1.0, 2.0, 0.0],
            [3.0, 4.0, 1.0],
            [5.0, 6.0, 0.0],
            [7.0, 8.0, 1.0]
        ],
        "y_train": [0.0, 1.0, 0.0, 1.0],
        "X_test": [
            [2.0, 3.0, 0.0]
        ],
    }]))
    predictions
    ```
  </Tab>

  <Tab title="Classification (JSON strings)">
    ```python theme={null}
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    predictions = loaded_model.predict(pd.DataFrame([{
        "task_config": json.dumps({"task": "classification", "predict_params": {"output_type": "probas"}}),
        "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
        "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
        "X_test":  json.dumps([[2.0, 3.0, 0.0]]),
    }]))
    predictions
    ```
  </Tab>

  <Tab title="Regression">
    ```python theme={null}
    loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

    predictions = loaded_model.predict(pd.DataFrame([{
        "task_config": {"task": "regression", "predict_params": {"output_type": "mean"}},
        "X_train": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
        "y_train": [1.5, 3.5, 5.5, 7.5],
        "X_test":  [[2.0, 3.0]],
    }]))
    predictions
    ```
  </Tab>
</Tabs>

The wrapper's `_maybe_parse_json()` transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint.
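
When the same request has to be sent from several places, a small helper that always emits the JSON-string form keeps payloads consistent with the all-string signature. This is a sketch: `build_record` is a hypothetical name, and it assumes the features are plain nested lists of numbers.

```python
import json


def build_record(task, X_train, y_train, X_test, output_type=None, tabpfn_config=None):
    """Build one request row in which every column is a JSON string."""
    task_config = {"task": task}
    if tabpfn_config:
        task_config["tabpfn_config"] = tabpfn_config
    if output_type:
        task_config["predict_params"] = {"output_type": output_type}
    return {
        "task_config": json.dumps(task_config),
        "X_train": json.dumps(X_train),
        "y_train": json.dumps(y_train),
        "X_test": json.dumps(X_test),
    }


record = build_record(
    "regression",
    X_train=[[1.0, 2.0], [3.0, 4.0]],
    y_train=[1.5, 3.5],
    X_test=[[2.0, 3.0]],
    output_type="mean",
)
print(record["task_config"])
```

The resulting dict can be wrapped as `pd.DataFrame([record])` for `loaded_model.predict()` or sent as one element of `dataframe_records` to a serving endpoint.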

***

## Step 6: End-to-End Example - Lending Club Loan Data

Let's run a real-world classification task: predicting loan default (Good vs Bad) on the [Lending Club Q2 2018](https://www.kaggle.com/datasets/wordsforthewise/lending-club) dataset. This dataset ships with every Databricks workspace at `/databricks-datasets/`, so you can run this without any additional data download.

### Load and Prepare the Data

```python theme={null}
from sklearn.model_selection import train_test_split

# Load Lending Club Q2 2018 (ships with every Databricks workspace)
df = (
    spark.read.csv(
        "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv",
        header=True, inferSchema=True,
    )
    .select(
        "loan_status",
        "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti",
        "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths",
        "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt",
    )
    .dropna(subset=["loan_status"])
    .toPandas()
)

# Binary target: Good (0) vs Bad (1)
good_statuses = ["Current", "Fully Paid"]
df["target"] = (~df["loan_status"].isin(good_statuses)).astype(int)
df = df.drop(columns=["loan_status"]).dropna()

# Sample and split
df_sample = df.sample(n=6_000, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    df_sample.drop(columns=["target"]), df_sample["target"],
    test_size=0.2, random_state=42, stratify=df_sample["target"],
)

print(f"Train: {len(X_train):,} × {X_train.shape[1]}  (bad-loan rate: {y_train.mean():.2%})")
print(f"Test:  {len(X_test):,} × {X_test.shape[1]}")
```

### Run Predictions

Pass the sampled training and test splits through the registered MLflow model in a single call. We request `probas` so we can compute ROC-AUC alongside accuracy and F1.

```python theme={null}
# Predict via the registered MLflow model
predictions = loaded_model.predict(pd.DataFrame([{
    "task_config": {
        "task": "classification",
        "predict_params": {"output_type": "probas"},
    },
    "X_train": X_train.values.tolist(),
    "y_train": y_train.values.tolist(),
    "X_test":  X_test.values.tolist(),
}]))

probas = np.array(predictions)
y_pred = probas.argmax(axis=1)

print(f"Predicted {len(y_pred):,} samples")
```

### Evaluate Results

```python theme={null}
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

print(f"Accuracy:      {accuracy_score(y_test, y_pred):.4f}")
print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}")
print(f"ROC-AUC:       {roc_auc_score(y_test, probas[:, 1]):.4f}")

# Materialize as a Spark DataFrame
results_sdf = spark.createDataFrame(pd.DataFrame({
    "actual":              y_test.values,
    "predicted":           y_pred,
    "probability_bad_loan": np.round(probas[:, 1], 4),
}))

display(results_sdf)
```

No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass.

***

## Step 7: Deploy to Mosaic AI Model Serving

Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is passed through Databricks Secrets: the endpoint configuration stores only a reference to the secret, never the token value itself.

```python theme={null}
import mlflow.deployments

# Resolve the "champion" alias to a version number
client = mlflow.MlflowClient()
champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion")
print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}")

# Get the MLflow deployment client (separate from the registry client above)
deploy_client = mlflow.deployments.get_deploy_client("databricks")

# create_endpoint returns immediately; the endpoint keeps initializing asynchronously.
# Check its status in your Databricks console.
endpoint = deploy_client.create_endpoint(
    name="tabpfn-endpoint",
    config={
        "served_entities": [{
            "entity_name": REGISTERED_MODEL_NAME,
            "entity_version": str(champion.version),
            "workload_size": "Medium",
            "workload_type": "GPU_MEDIUM",
            "scale_to_zero_enabled": True,
            "environment_vars": {
                "TABPFN_TOKEN": "{{secrets/tabpfn/tabpfn_token}}",
            },
        }],
    },
)

print(f"Endpoint created: {endpoint['name']}")
```

<Note>
  `create_endpoint` returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under **Serving** → **tabpfn-endpoint**.
</Note>
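
If you would rather wait in the notebook than watch the console, a polling loop is one option. The sketch below is generic: `wait_until_ready` is a hypothetical helper that takes any zero-argument callable returning a status string. Wiring it to the deployment client's `get_endpoint()` and reading `state["ready"]` is an assumption about the response shape that should be checked against your workspace.

```python
import time


def wait_until_ready(get_status, ready_value="READY", timeout_s=1800, interval_s=30):
    """Poll get_status() until it returns ready_value or the timeout elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status == ready_value:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint not ready after {timeout_s}s (last status: {status!r})"
            )
        time.sleep(interval_s)


# Assumed wiring (response shape not confirmed by this tutorial):
# deploy_client = mlflow.deployments.get_deploy_client("databricks")
# wait_until_ready(lambda: deploy_client.get_endpoint("tabpfn-endpoint")["state"]["ready"])
```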

### Calling the Endpoint

Once the endpoint is live, you can reach it from Python, REST, or SQL:

<Tabs>
  <Tab title="Python (MLflow Deployments)">
    ```python theme={null}
    import json

    import mlflow.deployments

    client = mlflow.deployments.get_deploy_client("databricks")
    response = client.predict(
        endpoint="tabpfn-endpoint",
        inputs={"dataframe_records": [{
            "task_config": json.dumps({"task": "classification"}),
            "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]),
            "y_train": json.dumps([0.0, 1.0]),
            "X_test":  json.dumps([[2.0, 3.0]]),
        }]},
    )
    print(response)
    ```
  </Tab>

  <Tab title="REST API">
    ```bash theme={null}
    curl -X POST \
      https://<workspace-url>/serving-endpoints/tabpfn-endpoint/invocations \
      -H "Authorization: Bearer $DATABRICKS_TOKEN" \
      -H "Content-Type: application/json" \
      -d '{
        "dataframe_records": [{
          "task_config": "{\"task\": \"classification\"}",
          "X_train": "[[1.0, 2.0], [3.0, 4.0]]",
          "y_train": "[0.0, 1.0]",
          "X_test":  "[[2.0, 3.0]]"
        }]
      }'
    ```
  </Tab>

  <Tab title="SQL (ai_query)">
    ```sql theme={null}
    SELECT ai_query(
      'tabpfn-endpoint',
      named_struct(
        'task_config', to_json(named_struct('task', 'classification')),
        'X_train',     to_json(array(array(1.0, 2.0), array(3.0, 4.0))),
        'y_train',     to_json(array(0.0, 1.0)),
        'X_test',      to_json(array(array(2.0, 3.0)))
      )
    ) AS prediction
    ```
  </Tab>
</Tabs>
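
Because the output column is declared as a string, the value under `predictions` may arrive either as an already-parsed list or as a JSON-encoded string, depending on the client. The sketch below normalizes both cases. `extract_predictions` is a hypothetical helper, and the `{"predictions": ...}` response envelope is an assumption to verify against your endpoint's actual output.

```python
import json


def extract_predictions(response):
    """Return predictions as Python lists, whether they arrive parsed or JSON-encoded."""
    predictions = response["predictions"] if isinstance(response, dict) else response
    if isinstance(predictions, str):
        return json.loads(predictions)
    return predictions


print(extract_predictions({"predictions": [0.0, 1.0]}))      # [0.0, 1.0]
print(extract_predictions({"predictions": "[[0.9, 0.1]]"}))  # [[0.9, 0.1]]
```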

***

## How the Input Format Works

Understanding `task_config` is key to using the endpoint effectively. It controls both what TabPFN does and how it does it.

| Field            | Type     | Description                                                                                                                  |
| ---------------- | -------- | ---------------------------------------------------------------------------------------------------------------------------- |
| `task`           | `string` | Required. `"classification"` or `"regression"`                                                                               |
| `tabpfn_config`  | `object` | Optional. Passed directly to `TabPFNClassifier` / `TabPFNRegressor` constructor (e.g. `n_estimators`, `softmax_temperature`) |
| `predict_params` | `object` | Optional. Controls output format via `output_type`                                                                           |

### Output Types

<Tabs>
  <Tab title="Classification">
    | `output_type`     | Description                       |
    | ----------------- | --------------------------------- |
    | `preds` (default) | Predicted class labels            |
    | `probas`          | Class probabilities (nested list) |
  </Tab>

  <Tab title="Regression">
    | `output_type`    | Description                     |
    | ---------------- | ------------------------------- |
    | `mean` (default) | Mean prediction                 |
    | `mode`           | Mode of predictive distribution |
    | `median`         | Median prediction               |
    | `quantiles`      | Quantile predictions            |
    | `main`           | Main summary statistics         |
    | `full`           | Full predictive distribution    |
  </Tab>
</Tabs>
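
The two tables can be turned into a small client-side check that catches a mismatched `task` and `output_type` before a request is sent. This is a sketch with hypothetical names (`OUTPUT_TYPES`, `DEFAULTS`, `validate_task_config`); the allowed values are copied from the tables above.

```python
OUTPUT_TYPES = {
    "classification": {"preds", "probas"},
    "regression": {"mean", "mode", "median", "quantiles", "main", "full"},
}
DEFAULTS = {"classification": "preds", "regression": "mean"}


def validate_task_config(task_config):
    """Return the effective output_type, raising ValueError on an invalid combination."""
    task = task_config.get("task")
    if task not in OUTPUT_TYPES:
        raise ValueError(f"task must be one of {sorted(OUTPUT_TYPES)}, got {task!r}")
    output_type = (task_config.get("predict_params") or {}).get("output_type")
    if output_type is None:
        return DEFAULTS[task]
    if output_type not in OUTPUT_TYPES[task]:
        raise ValueError(
            f"output_type {output_type!r} is not valid for {task}; "
            f"expected one of {sorted(OUTPUT_TYPES[task])}"
        )
    return output_type


print(validate_task_config({"task": "regression"}))  # mean
```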

***

## Promoting a New Model Version

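When you want to update the model (e.g. after a new TabPFN release), re-run the registration cell and reassign the alias. Because Step 7 resolves `champion` to a concrete version number when it creates the endpoint, also update the endpoint's served entity to the new version so that traffic actually moves to it.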

```python theme={null}
# After re-running the registration cell with a new TabPFNWrapper...
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Promoted version {latest.version} to 'champion'")
```
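
Since the endpoint in Step 7 was created with an explicit `entity_version`, the served entity also has to be pointed at the new version. The sketch below only builds the updated configuration. `served_entities_for` is a hypothetical helper that mirrors the Step 7 config, and how you submit it (for example through the deployment client's endpoint-update call) depends on your MLflow version.

```python
def served_entities_for(model_name, version, secret_ref="{{secrets/tabpfn/tabpfn_token}}"):
    """Build a served_entities config that pins the endpoint to a specific model version."""
    return {
        "served_entities": [{
            "entity_name": model_name,
            "entity_version": str(version),
            "workload_size": "Medium",
            "workload_type": "GPU_MEDIUM",
            "scale_to_zero_enabled": True,
            "environment_vars": {"TABPFN_TOKEN": secret_ref},
        }],
    }


# Example model name and version, for illustration only
config = served_entities_for("main.default.tabpfn", 2)
print(config["served_entities"][0]["entity_version"])  # 2
```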

***

## Next Steps

<CardGroup cols={2}>
  <Card title="TabPFN GitHub" icon="github" href="https://github.com/priorlabs/tabpfn">
    Explore all supported parameters, estimator options, and output types.
  </Card>

  <Card title="Databricks Blog" icon="newspaper" href="https://www.databricks.com/blog/tappfn-ai-accelerates-business-transformation-databricks">
    Learn how TabPFN accelerates business transformation on Databricks.
  </Card>

  <Card title="MLflow PythonModel Docs" icon="book" href="https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html">
    Understand the `mlflow.pyfunc.PythonModel` interface used by the wrapper.
  </Card>

  <Card title="Mosaic AI Model Serving" icon="server" href="https://docs.databricks.com/en/machine-learning/model-serving/index.html">
    Configure autoscaling, traffic splitting, and monitoring for your endpoint.
  </Card>
</CardGroup>
