TabPFN is a tabular foundation model that makes accurate predictions in a single forward pass - no dataset-specific training, no hyperparameter search, no separate model artifact per target variable. The promise: one registered model that handles both classification and regression across any tabular dataset.
Less Operational Complexity
Manage one model instead of many. No per-dataset training runs, no hyperparameter searches, no separate artifacts per outcome.
Higher Productivity
Use a single model for classification and regression across any tabular dataset - from a notebook or SQL ai_query().
By the end of this tutorial you will have:
Wrapped TabPFN as an MLflow PythonModel
Registered it to Unity Catalog with a champion alias
Tested it locally on classification and regression tasks
Run an end-to-end example on the Lending Club loan dataset
Deployed it to a GPU-accelerated Mosaic AI Model Serving endpoint
For the full list of supported TabPFN parameters (estimators, output types, configuration options) see the TabPFN GitHub repository. For broader context on TabPFN and Databricks, read the Databricks blog post.
TabPFN runs inference via PyTorch, so a GPU significantly speeds up predictions - especially on larger datasets. In the notebook toolbar, select Serverless with a GPU-enabled instance (A10G or better). Serverless handles provisioning and scaling automatically.
TabPFN model weights are gated and require authentication. Store your token as a Databricks secret so it never lives in plain text in your notebook. Run these commands in the Databricks CLI:
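A minimal sketch using the unified Databricks CLI; the scope name `tabpfn` and key `tabpfn_token` match the secret reference used later in the endpoint config (older CLI versions use `databricks secrets put --scope … --key …` instead):

```shell
# One-time setup: create a secret scope for the TabPFN token
databricks secrets create-scope tabpfn

# Store the token under the scope; the CLI prompts for the secret value
databricks secrets put-secret tabpfn tabpfn_token
```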
The TabPFNWrapper is a custom mlflow.pyfunc.PythonModel. It is the heart of this integration and handles three concerns:
Dual-format input - accepts both raw Python objects (from notebooks) and JSON strings (from SQL ai_query())
Task routing - classification or regression, controlled by task_config
Flexible output - class labels, probabilities, or regression predictions based on output_type
All input columns use DataType.string so the same endpoint works from Python, REST, and SQL without any changes to the registered model. The _maybe_parse_json() helper transparently handles both formats at predict time.
```python
import os
import json
import inspect
import mlflow
import numpy as np
import pandas as pd
from typing import Literal
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import (
    Array,
    ColSpec,
    DataType,
    Object,
    ParamSchema,
    ParamSpec,
    Property,
    Schema,
)
from tabpfn import TabPFNClassifier, TabPFNRegressor


class TabPFNWrapper(mlflow.pyfunc.PythonModel):
    """MLflow PythonModel wrapper for TabPFN"""

    _CLASSIFICATION_OUTPUT_TYPES = {"preds", "probas"}
    """The model output types for classification."""

    _REGRESSION_OUTPUT_TYPES = {"mean", "mode", "median", "quantiles", "main", "full"}
    """The model output types for regression."""

    @staticmethod
    def _maybe_parse_json(value: str | dict | list) -> dict | list:
        """Helper function to parse a JSON string if needed.

        Args:
            value: The value to parse. Can be a string, dict, or list.

        Returns:
            The parsed value.
        """
        if isinstance(value, str):
            return json.loads(value)
        return value

    def _get_output_type(
        self,
        task: Literal["classification", "regression"],
        output_type: str | None,
    ) -> str:
        """Get the prediction output type.

        Args:
            output_type: The output type to get.

        Returns:
            The prediction output type.
        """
        if output_type is not None:
            supported_output_types = (
                self._CLASSIFICATION_OUTPUT_TYPES | self._REGRESSION_OUTPUT_TYPES
            )
            if output_type not in supported_output_types:
                raise ValueError(
                    f"Unknown output_type: {output_type!r}. "
                    f"Must be one of {supported_output_types}"
                )
            return output_type
        # Fallback to the defaults
        return "preds" if task == "classification" else "mean"

    def _init_estimator(
        self,
        task: Literal["classification", "regression"],
        config: dict,
    ) -> TabPFNClassifier | TabPFNRegressor:
        """Initialize a TabPFN estimator.

        Args:
            task: The task to initialize the estimator for.
            config: The configuration for the estimator.

        Returns:
            The initialized estimator.
        """
        Estimator = TabPFNClassifier if task == "classification" else TabPFNRegressor

        # Validate provided config keys against the estimator constructor
        sig = inspect.signature(Estimator.__init__)
        constructor_params = set(sig.parameters.keys()) - {"self"}
        supplied_keys = set(config.keys())
        invalid_keys = supplied_keys - constructor_params

        # Raise an error if any invalid keys are provided
        if invalid_keys:
            msg = (
                f"Config contains invalid parameters for {Estimator.__name__}: {sorted(invalid_keys)}.\n"
                f"Allowed parameters: {sorted(constructor_params)}"
            )
            raise ValueError(msg)

        return Estimator(**config)

    def predict(self, model_input, params=None):
        """Run predictions.

        TabPFN runs predictions in a single forward pass. The model is fitted
        on the training data and then used to predict on the test data.

        Args:
            model_input: The input data to predict on.
            params: The parameters for the model.

        Returns:
            The predictions.
        """
        # Accept both DataFrame (local pyfunc) and list-of-dicts (serving endpoint)
        if isinstance(model_input, pd.DataFrame):
            model_input = model_input.to_dict(orient="records")
        assert isinstance(model_input, list), "model_input must be a list with a single row"
        assert len(model_input) == 1, "model_input must have a single row"
        model_input: dict = model_input[0]

        # Parse the task configuration
        task_config = self._maybe_parse_json(model_input["task_config"]) or {}
        task = task_config.get("task")
        if task is None:
            raise KeyError("Task is required, must be 'classification' or 'regression'")
        tabpfn_config = task_config.get("tabpfn_config") or {}
        predict_params = task_config.get("predict_params") or {}
        output_type = self._get_output_type(task, predict_params.get("output_type"))

        # Parse the input data
        X_train = np.array(self._maybe_parse_json(model_input["X_train"]), dtype=np.float64)
        y_train = np.array(self._maybe_parse_json(model_input["y_train"]), dtype=np.float64)
        X_test = np.array(self._maybe_parse_json(model_input["X_test"]), dtype=np.float64)

        # Initialize the estimator
        estimator = self._init_estimator(task, tabpfn_config)

        # Fit the estimator
        estimator.fit(X_train, y_train)

        # Run predictions
        if task == "classification":
            if output_type == "probas":
                return estimator.predict_proba(X_test).tolist()
            return estimator.predict(X_test).tolist()
        predictions = estimator.predict(X_test, output_type=output_type)
        return predictions.tolist()
```
The model signature tells MLflow (and Unity Catalog) the expected input and output shapes. All columns are DataType.string to support both raw Python and JSON-serialized inputs from SQL ai_query().
```python
# All inputs are STRING for dual-format support:
# - JSON strings from SQL ai_query() via to_json()
# - Raw Python objects, such as `pd.DataFrame`, from notebook calls
input_schema = Schema([
    ColSpec(DataType.string, "task_config"),
    ColSpec(DataType.string, "X_train"),
    ColSpec(DataType.string, "y_train"),
    ColSpec(DataType.string, "X_test"),
])

# Output schema is required by Unity Catalog.
# STRING is the best fit - the wrapper returns variable shapes
# (flat list for preds, nested list for probas) and the serving
# layer serializes whatever .tolist() produces into a JSON string.
output_schema = Schema([ColSpec(DataType.string, name="predictions")])

signature = ModelSignature(inputs=input_schema, outputs=output_schema)
```
The output_schema is required for Unity Catalog registration. Using DataType.string is the right choice here because the wrapper can return either a flat list (for preds) or a nested list (for probas/regression quantiles), and the serving layer serializes the result to JSON regardless.
Log the model with MLflow and register it under a fully qualified Unity Catalog path (catalog.schema.tabpfn). The input_example uses json.dumps() to match the all-string signature - the wrapper deserializes at predict time via _maybe_parse_json(). After registration, we tag the latest version with a "champion" alias. The serving endpoint references this alias, so you can promote future versions without touching the endpoint config.
```python
# Tag the latest version with a "champion" alias for serving
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Set alias 'champion' → version {latest.version} of {REGISTERED_MODEL_NAME}")
```
Using the champion alias decouples your serving endpoint from version numbers. When you retrain or update TabPFN, simply reassign the alias - the endpoint continues to route requests without any configuration change.
Before deploying, verify the registered model works end-to-end from the notebook. The same model handles both raw Python objects and JSON strings - just load it and call .predict().
Classification (raw Python)
Classification (JSON strings)
Regression
```python
# Load the registered model from MLflow
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

# DataFrame with raw Python objects, the notebook-friendly example
predictions = loaded_model.predict(pd.DataFrame([{
    "task_config": {"task": "classification"},
    "X_train": [
        [1.0, 2.0, 0.0],
        [3.0, 4.0, 1.0],
        [5.0, 6.0, 0.0],
        [7.0, 8.0, 1.0],
    ],
    "y_train": [0.0, 1.0, 0.0, 1.0],
    "X_test": [
        [2.0, 3.0, 0.0],
    ],
}]))
predictions
```
The wrapper’s _maybe_parse_json() transparently handles both formats - no code changes required between calling from a notebook and calling from a REST endpoint.
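For comparison, the JSON-string variant of the same call can be sketched as follows. Only the payload construction differs from the raw-Python example; the predict call itself is unchanged:

```python
import json

import pandas as pd

# Same example as the raw-Python call, with every field JSON-serialized -
# the shape that SQL ai_query() (via to_json()) delivers to the endpoint.
payload = {
    "task_config": json.dumps({"task": "classification"}),
    "X_train": json.dumps([
        [1.0, 2.0, 0.0],
        [3.0, 4.0, 1.0],
        [5.0, 6.0, 0.0],
        [7.0, 8.0, 1.0],
    ]),
    "y_train": json.dumps([0.0, 1.0, 0.0, 1.0]),
    "X_test": json.dumps([[2.0, 3.0, 0.0]]),
}
json_input = pd.DataFrame([payload])

# _maybe_parse_json() recovers the original objects at predict time,
# so this payload is equivalent to the raw-Python DataFrame above.
```

Calling `loaded_model.predict(json_input)` then returns the same prediction as the raw-Python call.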
Step 6: End-to-End Example - Lending Club Loan Data
Let’s run a real-world classification task: predicting loan default (Good vs Bad) on the Lending Club Q2 2018 dataset. This dataset ships with every Databricks workspace at /databricks-datasets/, so you can run this without any additional data download.
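To give the payload a concrete shape, here is a sketch of the feature and label preparation. The rows, column names, and the Good-vs-Bad rule (Bad = any status other than "Fully Paid") are illustrative assumptions; in the workspace you would read the real data from /databricks-datasets/ with Spark.

```python
import json

import pandas as pd

# Illustrative stand-in rows for the Lending Club data; the column
# names here are assumptions for the sketch.
df = pd.DataFrame({
    "loan_amnt": [10000.0, 5000.0, 15000.0, 8000.0],
    "int_rate": [11.5, 7.9, 15.2, 9.4],
    "annual_inc": [60000.0, 45000.0, 30000.0, 52000.0],
    "loan_status": ["Fully Paid", "Fully Paid", "Charged Off", "Fully Paid"],
})

# Binary target: "Good" loans are fully paid, everything else is "Bad" (1.0)
df["label"] = (df["loan_status"] != "Fully Paid").astype(float)

# Build the all-string payload the registered model expects
features = ["loan_amnt", "int_rate", "annual_inc"]
payload = {
    "task_config": json.dumps({"task": "classification"}),
    "X_train": json.dumps(df[features].iloc[:3].values.tolist()),
    "y_train": json.dumps(df["label"].iloc[:3].tolist()),
    "X_test": json.dumps(df[features].iloc[3:].values.tolist()),
}
```

The same DataFrame-of-one-row pattern from the previous step then carries this payload to `loaded_model.predict()` or to the serving endpoint.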
Deploy the registered model to a GPU-accelerated serving endpoint. The TabPFN token is securely passed via Databricks Secrets - it is never stored in the endpoint configuration.
```python
import mlflow.deployments

# Resolve the "champion" alias to a version number
client = mlflow.MlflowClient()
champion = client.get_model_version_by_alias(REGISTERED_MODEL_NAME, "champion")
print(f"Deploying {REGISTERED_MODEL_NAME} version {champion.version}")

# Get the MLflow deployment client
client = mlflow.deployments.get_deploy_client("databricks")

# Create the endpoint. This returns immediately and continues initializing
# the endpoint asynchronously - check the status in your Databricks console.
endpoint = client.create_endpoint(
    name="tabpfn-endpoint",
    config={
        "served_entities": [{
            "entity_name": REGISTERED_MODEL_NAME,
            "entity_version": str(champion.version),
            "workload_size": "Medium",
            "workload_type": "GPU_MEDIUM",
            "scale_to_zero_enabled": True,
            "environment_vars": {
                "TABPFN_TOKEN": "{{secrets/tabpfn/tabpfn_token}}",
            },
        }],
    },
)
print(f"Endpoint created: {endpoint['name']}")
```
create_endpoint returns immediately and initializes the endpoint asynchronously. Monitor the status in your Databricks console under Serving → tabpfn-endpoint.
When you want to update the model (e.g. after a new TabPFN release), simply re-run the registration cell and reassign the alias. The endpoint keeps routing to champion with no configuration change needed.
```python
# After re-running the registration cell with a new TabPFNWrapper...
client = mlflow.MlflowClient()
versions = client.search_model_versions(f"name='{REGISTERED_MODEL_NAME}'")
latest = max(versions, key=lambda v: int(v.version))
client.set_registered_model_alias(REGISTERED_MODEL_NAME, "champion", latest.version)
print(f"Promoted version {latest.version} to 'champion'")
```