
Overview

TabPFN is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable. The promise: one registered model that handles both classification and regression across any tabular dataset.

Less Operational Complexity

Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome.

Higher Productivity

Use a single model for classification and regression across any tabular dataset, called from a notebook or from SQL via ai_query().

By the end of this tutorial, you will have:
  • Wrapped TabPFN as an MLflow PythonModel
  • Registered it to Unity Catalog with a champion alias
  • Tested it locally on classification and regression tasks
  • Run an end-to-end example on the Lending Club loan dataset
  • Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint
For the full list of supported TabPFN parameters (estimators, output types, configuration options), see the TabPFN GitHub repository. For broader context on TabPFN and Databricks, read the Databricks blog post.

Prerequisites

  • A Databricks workspace with Unity Catalog enabled
  • Access to a serverless GPU environment with at least an NVIDIA A10G GPU
  • A TabPFN token from Prior Labs
TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select Serverless with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically.

Step 1: Install Dependencies

Run the following in a Databricks notebook cell:
%pip install mlflow tabpfn
dbutils.library.restartPython()

Step 2: Configure Your TabPFN Token

TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never appears in plain text in your notebook. Run these commands with the Databricks CLI:
databricks secrets create-scope tabpfn
databricks secrets put-secret tabpfn tabpfn_token --string-value "<your-token>"
Then reference the secret in your notebook:
import os

os.environ["TABPFN_TOKEN"] = dbutils.secrets.get(scope="tabpfn", key="tabpfn_token")

Step 3: Define the Wrapper and Signature

The TabPFNWrapper is a custom mlflow.pyfunc.PythonModel. It is the heart of this integration and handles three concerns:
  • Dual-format input - accepts both raw Python objects (from notebooks) and JSON strings (from SQL ai_query())
  • Task routing - classification or regression, controlled by task_config
  • Flexible output - class labels, probabilities, or regression predictions based on output_type
All input columns use DataType.string so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The _maybe_parse_json() helper transparently handles both formats at predict time.
import json
import inspect

import mlflow
import numpy as np
import pandas as pd

from typing import Literal

from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, DataType, Schema
from tabpfn import TabPFNClassifier, TabPFNRegressor


class TabPFNWrapper(mlflow.pyfunc.PythonModel):
    """MLflow PythonModel wrapper for TabPFN"""

    _TASKS = {"classification", "regression"}
    """The supported tasks."""

    _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
    """The model output types for classification."""

    # "full" is excluded: TabPFNRegressor returns a torch criterion object for it,
    # which the serving layer cannot serialize to JSON
    _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main"}
    """The model output types for regression."""

    @staticmethod
    def _maybe_parse_json(value: str | dict | list) -> dict | list:
        """Helper function to parse a JSON string if needed.

        Args:
            value: The value to parse. Can be a string, dict, or list.

        Returns:
            The parsed value.
        """
        if isinstance(value, str):
            return json.loads(value)
        return value

    @staticmethod
    def _to_serializable(value):
        """Recursively convert TabPFN outputs to plain Python objects.

        Regression returns an array for "mean"/"median"/"mode", a list of
        arrays for "quantiles", and a dict for "main".
        """
        if isinstance(value, dict):
            return {key: TabPFNWrapper._to_serializable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [TabPFNWrapper._to_serializable(item) for item in value]
        if hasattr(value, "tolist"):
            return value.tolist()
        return value

    def _get_output_type(
        self,
        task: Literal["classification", "regression"],
        output_type: str | None
    ) -> str:
        """Get the prediction output type for a task.

        Args:
            task: The task the output type must be valid for.
            output_type: The requested output type, or None for the task default.

        Returns:
            The prediction output type.

        Raises:
            ValueError: If output_type is not supported for the task.
        """
        if output_type is None:
            # Fall back to the task defaults
            return "preds" if task == "classification" else "mean"

        supported_output_types = (
            self._CLASSIFICATION_OUTPUT_TYPES
            if task == "classification"
            else self._REGRESSION_OUTPUT_TYPES
        )
        if output_type not in supported_output_types:
            raise ValueError(
                f"Unknown output_type for {task}: {output_type!r}. "
                f"Must be one of {sorted(supported_output_types)}"
            )
        return output_type

    def _init_estimator(
        self,
        task: Literal["classification", "regression"],
        config: dict
    ) -> TabPFNClassifier | TabPFNRegressor:
        """Initialize a TabPFN estimator.

        Args:
            task: The task to initialize the estimator for.
            config: The configuration for the estimator.

        Returns:
            The initialized estimator.
        """
        Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor

        # Validate provided config keys against the estimator constructor
        sig = inspect.signature(Estimator.__init__)

        constructor_params = set(sig.parameters.keys()) - {"self"}
        supplied_keys = set(config.keys())
        invalid_keys = supplied_keys - constructor_params

        # Raise an error if any invalid keys are provided
        if invalid_keys:
            msg = (
                f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n"
                f"Allowed parameters: {sorted(constructor_params)}"
            )
            raise ValueError(msg)

        return Estimator(**config)

    def predict(self, model_input, params=None):
        """Run predictions.

        TabPFN needs no dataset-specific training: fit() conditions the model
        on the training rows sent with the request, and prediction on the test
        rows runs in a single forward pass.

        Args:
            model_input: The input data to predict on (exactly one row).
            params: Unused; kept for the PythonModel interface.

        Returns:
            The predictions as plain Python lists (or a dict for "main").
        """
        # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint)
        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.to_dict(orient="records")

        if not isinstance(model_input, list) or len(model_input) != 1:
            raise ValueError("model_input must contain exactly one row")

        row: dict = model_input[0]

        # Parse the task configuration
        task_config = self._maybe_parse_json(row["task_config"]) or {}

        task = task_config.get("task")
        if task not in self._TASKS:
            raise ValueError(f"task must be 'classification' or 'regression', got {task!r}")

        tabpfn_config = task_config.get("tabpfn_config") or {}
        predict_params = task_config.get("predict_params") or {}

        output_type = self._get_output_type(task, predict_params.get("output_type"))

        # Parse the input data
        X_train = np.array(self._maybe_parse_json(row["X_train"]), dtype=np.float64)
        y_train = np.array(self._maybe_parse_json(row["y_train"]), dtype=np.float64)
        X_test  = np.array(self._maybe_parse_json(row["X_test"]),  dtype=np.float64)

        # Initialize the estimator and fit it on the request's training rows
        estimator = self._init_estimator(task, tabpfn_config)
        estimator.fit(X_train, y_train)

        # Run predictions
        if task == "classification":
            if output_type == "probas":
                return estimator.predict_proba(X_test).tolist()
            return estimator.predict(X_test).tolist()

        # Regression output may be an array, a list of arrays, or a dict,
        # so convert it recursively instead of calling .tolist() directly
        predictions = estimator.predict(X_test, output_type=output_type)
        return self._to_serializable(predictions)
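To see why one all-string signature can serve both notebook and SQL callers, the sketch below reproduces the normalization logic of the wrapper's `_maybe_parse_json()` as a standalone function (`maybe_parse_json` is a local copy for illustration only) and shows that a JSON string and the equivalent Python list normalize to the same value.

```python
import json


def maybe_parse_json(value):
    """Standalone copy of TabPFNWrapper._maybe_parse_json, for illustration."""
    # JSON strings (e.g. from SQL to_json() or json.dumps()) are decoded
    if isinstance(value, str):
        return json.loads(value)
    # Raw Python objects (lists, dicts) pass through unchanged
    return value


raw_X = [[1.0, 2.0, 0.0], [3.0, 4.0, 1.0]]
serialized_X = json.dumps(raw_X)

# Both input formats normalize to the same nested list
assert maybe_parse_json(raw_X) == maybe_parse_json(serialized_X)
print(maybe_parse_json(serialized_X))  # [[1.0, 2.0, 0.0], [3.0, 4.0, 1.0]]
```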

Define the Model Signature

The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are DataType.string to support both raw Python and JSON-serialized inputs from SQL ai_query().
# All inputs are STRING for dual-format support:
#   - JSON strings from SQL ai_query() via to_json()
#   - Raw Python objects, such as `pd.DataFrame`, from notebook calls

input_schema = Schema([
    ColSpec(DataType.string, "task_config"),
    ColSpec(DataType.string, "X_train"),
    ColSpec(DataType.string, "y_train"),
    ColSpec(DataType.string, "X_test"),
])

# Output schema is required by Unity Catalog.
# STRING is the best fit - the wrapper returns variable shapes
# (flat list for preds, nested list for probas) and the serving
# layer serializes whatever .tolist() produces into a JSON string.
output_schema = Schema([ColSpec(DataType.string, name="predictions")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
The output_schema is required for Unity Catalog registration. Using DataType.string is the right choice here because the wrapper can return either a flat list (for preds) or a nested list (for probas/regression quantiles), and the serving layer serializes the result to JSON regardless.
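Because every input column is declared as a string, callers that cannot send raw Python objects (REST clients, SQL) must JSON-encode each field. A small helper keeps that encoding in one place. This is a sketch: `to_request_row` is a hypothetical name, not part of MLflow or TabPFN, and it assumes the inputs are already plain Python lists (call `.tolist()` on NumPy arrays or pandas objects first).

```python
import json


def to_request_row(task_config, X_train, y_train, X_test):
    """Hypothetical helper: encode one request row to match the all-string signature."""
    return {
        "task_config": json.dumps(task_config),
        "X_train": json.dumps(X_train),
        "y_train": json.dumps(y_train),
        "X_test": json.dumps(X_test),
    }


row = to_request_row(
    {"task": "classification", "predict_params": {"output_type": "probas"}},
    [[1.0, 2.0], [3.0, 4.0]],
    [0.0, 1.0],
    [[2.0, 3.0]],
)

# Every value is a string, as the input schema requires
assert all(isinstance(value, str) for value in row.values())
print(row["task_config"])
```

The resulting dict can be wrapped in `pd.DataFrame([row])` for an input example, or sent as one element of `dataframe_records` to a serving endpoint.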

Step 4: Register to Unity Catalog

Log the model with MLflow and register it under a fully qualified Unity Catalog path (catalog.schema.tabpfn). The input_example uses json.dumps() to match the all-string signature - the wrapper deserializes at predict time via _maybe_parse_json(). After registration, we tag the latest version with a "champion" alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config.
# Input example uses json.dumps() to match the all-string signature
# The wrapper's _maybe_parse_json() deserializes these at predict time
input_example = pd.DataFrame([{
    "task_config": json.dumps({
        "task": "classification",
        "tabpfn_config": {
            "n_estimators": 8,
            "softmax_temperature": 0.9,
        },
        "predict_params": {
            "output_type": "preds",
        },
    }),
    "X_train": json.dumps([[1.0, 2.0, 0.0], [3.0, 4.0, 1.0], [5.0, 6.0, 0.0], [7.0, 8.0, 1.0]]),
    "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
    "X_test": json.dumps([[2.0, 3.0, 0.0]]),
}])

from importlib.metadata import version

# Fully qualified Unity Catalog path (portable across workspaces)
CATALOG = spark.catalog.currentCatalog()
SCHEMA = spark.catalog.currentDatabase()
REGISTERED_MODEL_NAME = f"{CATALOG}.{SCHEMA}.tabpfn"

# Pin the serving environment to the package versions tested in this notebook
PIP_REQUIREMENTS = [f"{pkg}=={version(pkg)}" for pkg in ("tabpfn", "numpy", "pandas")]

with mlflow.start_run(run_name="tabpfn-registration") as run:
    model_info = mlflow.pyfunc.log_model(
        name="tabpfn",
        python_model=TabPFNWrapper(),
        signature=signature,
        input_example=input_example,
        pip_requirements=PIP_REQUIREMENTS,
        registered_model_name=REGISTERED_MODEL_NAME,
    )
    print(f"Model URI: {model_info.model_uri}")
    print(f"Run ID:    {run.info.run_id}")
Once logged, tag the latest version as champion:
# Tag the latest version with a "champion" alias for serving
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}")
Using the champion alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change.
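The `int(v.version)` cast in the alias cell matters: model versions may be reported as strings, and comparing strings lexicographically ranks "9" above "10". The sketch below reproduces the comparison with plain strings standing in for `ModelVersion.version` values.

```python
# Stand-ins for ModelVersion.version values, which may arrive as strings
version_strings = ["1", "2", "9", "10"]

# Lexicographic comparison picks the wrong "latest" version
assert max(version_strings) == "9"

# Casting to int, as the alias cell does, gives numeric ordering
latest = max(version_strings, key=int)
assert latest == "10"
print(f"latest version: {latest}")
```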

Step 5: Test Locally

Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call .predict().
# Load the "champion" version of the registered model from Unity Catalog
loaded_model = mlflow.pyfunc.load_model(f"models:/{REGISTERED_MODEL_NAME}@champion")

# DataFrame with raw Python objects, the notebook-friendly example
predictions = loaded_model.predict(pd.DataFrame([{
    "task_config": {"task": "classification"},
    "X_train": [
        [1.0, 2.0, 0.0],
        [3.0, 4.0, 1.0],
        [5.0, 6.0, 0.0],
        [7.0, 8.0, 1.0]
    ],
    "y_train": [0.0, 1.0, 0.0, 1.0],
    "X_test": [
        [2.0, 3.0, 0.0]
    ],
}]))
predictions
The wrapper’s _maybe_parse_json() transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint.
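The same loaded model also covers regression. The sketch below builds a regression request on a small synthetic dataset (the data, seed, and linear target are made up purely for illustration) and asks for the median via `predict_params`. It only constructs the row; in the notebook, pass it as `loaded_model.predict(pd.DataFrame([regression_row]))` to get predictions.

```python
import random

# Synthetic data for illustration only: y = 2*x0 - x1 + small noise
rng = random.Random(0)
reg_X_train = [[rng.uniform(0, 10), rng.uniform(0, 10)] for _ in range(50)]
reg_y_train = [2 * x0 - x1 + rng.gauss(0, 0.1) for x0, x1 in reg_X_train]
reg_X_test = [[1.0, 2.0], [5.0, 5.0]]

regression_row = {
    # Raw Python objects are fine for local calls; _maybe_parse_json passes them through
    "task_config": {
        "task": "regression",
        "predict_params": {"output_type": "median"},
    },
    "X_train": reg_X_train,
    "y_train": reg_y_train,
    "X_test": reg_X_test,
}

print(len(regression_row["X_train"]), len(regression_row["X_test"]))  # 50 2
```

With `output_type` set to "median", the wrapper returns a flat list with one prediction per test row.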

Step 6: End-to-End Example - Lending Club Loan Data

Let’s run a real-world classification task: predicting loan default (Good vs Bad) on the Lending Club Q2 2018 dataset. This dataset ships with every Databricks workspace at /databricks-datasets/, so you can run this without any additional data download.

Load and Prepare the Data

from sklearn.model_selection import train_test_split

# Load Lending Club Q2 2018 (ships with every Databricks workspace)
df = (
    spark.read.csv(
        "/databricks-datasets/lending-club-loan-stats/LoanStats_2018Q2.csv",
        header=True, inferSchema=True,
    )
    .select(
        "loan_status",
        "loan_amnt", "funded_amnt", "installment", "annual_inc", "dti",
        "open_acc", "revol_bal", "total_acc", "delinq_2yrs", "inq_last_6mths",
        "pub_rec", "mort_acc", "tot_cur_bal", "total_pymnt", "last_pymnt_amnt",
    )
    .dropna(subset=["loan_status"])
    .toPandas()
)

# Binary target: Good (0) vs Bad (1)
df["target"] = (df["loan_status"].apply(
    lambda s: 0 if s in ("Current", "Fully Paid") else 1
).astype(int))
df = df.drop(columns=["loan_status"]).dropna()

# Sample and split
df_sample = df.sample(n=6_000, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    df_sample.drop(columns=["target"]), df_sample["target"],
    test_size=0.2, random_state=42, stratify=df_sample["target"],
)

print(f"Train: {len(X_train):,} × {X_train.shape[1]}  (bad-loan rate: {y_train.mean():.2%})")
print(f"Test:  {len(X_test):,} × {X_test.shape[1]}")

Run Predictions

Pass the training and test splits to the loaded model in a single call: TabPFN conditions on the 4,800 training rows and predicts the 1,200 test rows in one request. We request probas so we can compute ROC-AUC alongside accuracy and F1.
# Predict via the registered MLflow model
predictions = loaded_model.predict(pd.DataFrame([{
    "task_config": {
        "task": "classification",
        "predict_params": {"output_type": "probas"},
    },
    "X_train": X_train.values.tolist(),
    "y_train": y_train.values.tolist(),
    "X_test":  X_test.values.tolist(),
}]))

probas = np.array(predictions)
y_pred = probas.argmax(axis=1)

print(f"Predicted {len(y_pred):,} samples")
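`probas.argmax(axis=1)` returns a column index, not a label; it equals the label here only because the targets are 0 and 1. Assuming TabPFN follows the scikit-learn convention of ordering probability columns by sorted class label, a sketch for arbitrary numeric labels maps each index back through the sorted training labels. `labels_from_probas` is a hypothetical helper.

```python
def labels_from_probas(probas, y_train):
    """Hypothetical helper: map each row's most likely column back to its class label.

    Assumes probability columns follow sorted class order (the scikit-learn convention).
    """
    classes = sorted(set(y_train))
    labels = []
    for row in probas:
        # Index of the highest probability in this row (first one on ties, like argmax)
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best])
    return labels


# Example with non-contiguous numeric labels 1 and 3
y_train_example = [3, 1, 3, 1]
probas_example = [[0.8, 0.2], [0.3, 0.7]]
print(labels_from_probas(probas_example, y_train_example))  # [1, 3]
```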

Evaluate Results

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

print(f"Accuracy:      {accuracy_score(y_test, y_pred):.4f}")
print(f"F1 (weighted): {f1_score(y_test, y_pred, average='weighted'):.4f}")
print(f"ROC-AUC:       {roc_auc_score(y_test, probas[:, 1]):.4f}")

# Materialize as a Spark DataFrame
results_sdf = spark.createDataFrame(pd.DataFrame({
    "actual":              y_test.values,
    "predicted":           y_pred,
    "probability_bad_loan": np.round(probas[:, 1], 4),
}))

display(results_sdf)
No fine-tuning, no feature engineering pipeline, no hyperparameter search - TabPFN fits and predicts in a single forward pass.
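Bad loans are the minority class, so accuracy on its own can look strong even for a model that never flags a bad loan. A quick sanity check, sketched here with made-up labels, compares accuracy against a majority-class baseline; in the notebook you would pass `y_test.tolist()` instead.

```python
from collections import Counter


def majority_baseline_accuracy(y_true):
    """Accuracy of always predicting the most frequent label."""
    most_common_count = Counter(y_true).most_common(1)[0][1]
    return most_common_count / len(y_true)


# Made-up labels for illustration: 1 bad loan out of 10
y_example = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
print(f"Majority-class baseline accuracy: {majority_baseline_accuracy(y_example):.2f}")  # 0.90
```

ROC-AUC does not have this blind spot: a constant predictor scores 0.5 regardless of class balance.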

Step 7: Deploy to Mosaic AI Model Serving

Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is injected from Databricks Secrets at runtime: the endpoint configuration stores only a reference to the secret, never the token itself.
import mlflow.deployments

# Resolve the "champion" alias to a version number
client = mlflow.MlflowClient()
champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion")
print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}")

# Get the MLflow deployment client for Databricks
deploy_client = mlflow.deployments.get_deploy_client("databricks")

# create_endpoint returns immediately; the endpoint keeps initializing asynchronously.
# Check its status in your Databricks console.
endpoint = deploy_client.create_endpoint(
    name="tabpfn-endpoint",
    config={
        "served_entities": [{
            "entity_name": REGISTERED_MODEL_NAME,
            "entity_version": str(champion.version),
            "workload_size": "Medium",
            "workload_type": "GPU_MEDIUM",
            "scale_to_zero_enabled": True,
            "environment_vars": {
                "TABPFN_TOKEN": "{{secrets/tabpfn/tabpfn_token}}",
            },
        }],
    },
)

print(f"Endpoint created: {endpoint['name']}")
create_endpoint returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under Serving → tabpfn-endpoint.
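If you prefer to wait in the notebook rather than watch the console, you can poll the endpoint state. The sketch below keeps the polling logic independent of any client: `wait_until_ready` is a hypothetical helper that takes a `get_state` callable. In a Databricks notebook, that callable could read `mlflow.deployments.get_deploy_client("databricks").get_endpoint("tabpfn-endpoint")["state"]["ready"]`, assuming the response mirrors the Serving Endpoints REST API, where `state.ready` becomes `"READY"` once the endpoint can serve traffic.

```python
import time


def wait_until_ready(get_state, timeout_s=1800, poll_interval_s=30, sleep=time.sleep):
    """Hypothetical helper: poll get_state() until it returns "READY" or time runs out."""
    deadline = time.monotonic() + timeout_s
    while True:
        state = get_state()
        if state == "READY":
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint not ready after {timeout_s}s (last state: {state!r})")
        sleep(poll_interval_s)


# Usage in a Databricks notebook (assumed response shape: {"state": {"ready": ...}}):
# import mlflow.deployments
# deploy = mlflow.deployments.get_deploy_client("databricks")
# wait_until_ready(lambda: deploy.get_endpoint("tabpfn-endpoint")["state"]["ready"])

# Local demonstration with a fake state sequence and no real sleeping
states = iter(["NOT_READY", "NOT_READY", "READY"])
print(wait_until_ready(lambda: next(states), sleep=lambda _: None))  # READY
```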

Calling the Endpoint

Once the endpoint is live, you can reach it from Python, REST, or SQL (via ai_query()). From Python, use the MLflow deployment client:
import json
import mlflow.deployments

client = mlflow.deployments.get_deploy_client("databricks")
response = client.predict(
    endpoint="tabpfn-endpoint",
    inputs={"dataframe_records": [{
        "task_config": json.dumps({"task": "classification"}),
        "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]),
        "y_train": json.dumps([0.0, 1.0]),
        "X_test":  json.dumps([[2.0, 3.0]]),
    }]},
)
print(response)
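For REST callers, the endpoint accepts the same `dataframe_records` payload over HTTPS. The sketch below uses only the Python standard library and builds the request without sending it. It assumes the standard Model Serving invocation URL `https://<workspace-host>/serving-endpoints/tabpfn-endpoint/invocations` and a Databricks personal access token sent as a bearer token; `build_invocation_request` is a hypothetical helper name, and the host and token are placeholders.

```python
import json
import urllib.request


def build_invocation_request(host, token, records, endpoint="tabpfn-endpoint"):
    """Hypothetical helper: build (but do not send) a Model Serving invocation request."""
    body = json.dumps({"dataframe_records": records}).encode("utf-8")
    return urllib.request.Request(
        f"https://{host}/serving-endpoints/{endpoint}/invocations",
        data=body,
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )


# Same all-string record as the MLflow deployment-client example
records = [{
    "task_config": json.dumps({"task": "classification"}),
    "X_train": json.dumps([[1.0, 2.0], [3.0, 4.0]]),
    "y_train": json.dumps([0.0, 1.0]),
    "X_test": json.dumps([[2.0, 3.0]]),
}]

request = build_invocation_request("<workspace-host>", "<token>", records)
print(request.get_method(), request.full_url)

# Sending requires a live endpoint and real credentials:
# with urllib.request.urlopen(request) as response:
#     print(json.loads(response.read()))
```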

How the Input Format Works

Understanding task_config is key to using the endpoint effectively. It controls both what TabPFN does and how it does it.
| Field | Type | Description |
|---|---|---|
| task | string | Required. "classification" or "regression" |
| tabpfn_config | object | Optional. Passed directly to the TabPFNClassifier / TabPFNRegressor constructor (e.g. n_estimators, softmax_temperature) |
| predict_params | object | Optional. Controls the output format via output_type |

Output Types

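| output_type | Task | Description |
|---|---|---|
| preds (default) | Classification | Predicted class labels |
| probas | Classification | Class probabilities (nested list, one row per test sample) |
| mean (default) | Regression | Predicted mean |
| median | Regression | Predicted median |
| mode | Regression | Predicted mode |
| quantiles | Regression | Predicted quantiles (one list per quantile) |
| main | Regression | Object with mean, median, mode, and quantiles |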
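A client can catch malformed requests before they reach the endpoint by applying the rules from the two tables above. The sketch below is a hypothetical client-side check (`validate_task_config` is not part of the wrapper); it applies the output-type rules per task and returns the effective output_type. The regression set omits "full", whose TabPFN output is not JSON-serializable. Unknown tabpfn_config keys are still rejected server-side by the wrapper, since checking them requires the TabPFN constructor.

```python
CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
# "full" is left out: TabPFNRegressor returns a non-JSON-serializable criterion object for it
REGRESSION_OUTPUT_TYPES = {"mean", "median", "mode", "quantiles", "main"}


def validate_task_config(task_config):
    """Hypothetical client-side check; returns the effective output_type."""
    task = task_config.get("task")
    if task not in ("classification", "regression"):
        raise ValueError(f"task must be 'classification' or 'regression', got {task!r}")

    if not isinstance(task_config.get("tabpfn_config") or {}, dict):
        raise ValueError("tabpfn_config must be an object")

    predict_params = task_config.get("predict_params") or {}
    output_type = predict_params.get("output_type")
    if output_type is None:
        # Same per-task defaults as the wrapper
        return "preds" if task == "classification" else "mean"

    allowed = CLASSIFICATION_OUTPUT_TYPES if task == "classification" else REGRESSION_OUTPUT_TYPES
    if output_type not in allowed:
        raise ValueError(f"output_type {output_type!r} is not valid for {task}; use one of {sorted(allowed)}")
    return output_type


print(validate_task_config({"task": "regression"}))  # mean
print(validate_task_config({"task": "classification", "predict_params": {"output_type": "probas"}}))  # probas
```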

Promoting a New Model Version

When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to champion with no configuration change needed.
# After re-running the registration cell with a new TabPFNWrapper...
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Promoted version {latest.version} to 'champion'")

Next Steps

TabPFN GitHub

Explore all supported parameters, estimator options, and output types.

Databricks Blog

Learn how TabPFN accelerates business transformation on Databricks.

MLflow PythonModel Docs

Understand the mlflow.pyfunc.PythonModel interface used by the wrapper.

Mosaic AI Model Serving

Configure autoscaling, traffic splitting, and monitoring for your endpoint.