Whether you’re just getting started with TabPFN or pushing it into production, this FAQ highlights practical answers to common questions about model limits, performance, reproducibility, and API best practices.
At minimum, TabPFN requires an NVIDIA T4 GPU to run efficiently. For best performance, we recommend A100 or H100 GPUs. Any dedicated GPU supported by PyTorch is compatible, but some GPU models may not have enough memory for larger datasets or may run slowly. CPU, MPS (Apple Silicon), and integrated GPUs are also supported, but are only suitable for small datasets. For larger datasets, fit + predict time can be dramatically reduced by parallelizing inference over several GPUs. To enable this, set the device parameter of TabPFNClassifier and TabPFNRegressor.
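As a rough sketch of the multi-GPU setup, the helper below builds a list of CUDA device strings. Passing such a list to the device parameter is an assumption based on recent versions of the tabpfn package, and cuda_device_list is a hypothetical helper name. The estimator call is left commented out so the snippet runs without a GPU.

```python
def cuda_device_list(n_gpus: int) -> list[str]:
    """Return PyTorch-style device strings for the first n_gpus CUDA devices."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return [f"cuda:{i}" for i in range(n_gpus)]


devices = cuda_device_list(2)
print(devices)

# Assumed usage (requires the tabpfn package and two visible GPUs):
# from tabpfn import TabPFNClassifier
# clf = TabPFNClassifier(device=devices)
# clf.fit(X_train, y_train)
# proba = clf.predict_proba(X_test)
```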
TabPFN is optimized for tabular datasets of up to 1M rows, since it performs in-context learning using transformer attention. See the Models page for an overview of capabilities.
In the local package version, text features are encoded as categoricals without considering their semantic meaning. Our API automatically detects text features and incorporates their semantic meaning into its predictions.
In the local package version, date features are encoded as categoricals without considering their semantic meaning. Our API automatically detects date features and creates an optimized embedding.
With a fixed seed and in the same environment, TabPFN inference is deterministic. Across different hardware configurations (CPU, GPU, MPS), small differences are expected.
Yes. TabPFN’s estimators can handle missing values internally, including pd.NA, without requiring manual imputation.
TabPFN is inherently robust to imbalance. However, you can further optimize for specific objectives:
  • Use balance_probabilities=True to optimize for balanced accuracy, balanced loss or other evaluation metrics weighting each class equally regardless of its frequency.
  • Use eval_metric="f1" (or other supported metrics) for specific precision-recall tradeoffs.
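To illustrate what "weighting each class equally" means, here is a minimal sketch in plain Python that compares ordinary accuracy with balanced accuracy (the mean of per-class recall). The function names and the toy labels are illustrative only and are not part of TabPFN.

```python
from collections import defaultdict


def accuracy(y_true, y_pred):
    """Fraction of samples predicted correctly."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so every class counts equally."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return sum(correct[c] / total[c] for c in total) / len(total)


# 8 negatives, 2 positives; the model always predicts the majority class.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 10

print(accuracy(y_true, y_pred))           # 0.8
print(balanced_accuracy(y_true, y_pred))  # 0.5
```

A majority-class predictor looks strong on plain accuracy but scores only 0.5 on balanced accuracy, which is the kind of gap the balance_probabilities option is meant to address.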
By default, TabPFN performs most of its fitting inside the predict() step, so latency scales roughly with the number of training rows. For fast inference, you can use KV-caching (TabPFNClassifier(fit_mode="fit_with_cache")). We can also create a tree-based or small MLP-based model that yields almost the same accuracy as TabPFN. Contact sales@priorlabs.ai to access this solution.
Unless the fitted-model cache is enabled, the model is retrained each time .predict() is called. It is much faster to make a prediction for all your test points in a single .predict() call rather than calling it repeatedly. If you run out of memory, split the test points into batches of 1,000 to 10,000 and call .predict() for each batch. You can also tune the memory_saving_mode and n_preprocessing_jobs parameters of TabPFNClassifier and TabPFNRegressor for additional speed improvements. See the code documentation for details.
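The batching advice can be sketched as a small helper. predict_in_batches is a hypothetical name, not part of the TabPFN package, and the stand-in model exists only so the snippet runs without TabPFN installed. The helper assumes X_test supports len() and row slicing, as Python lists and NumPy arrays do.

```python
def predict_in_batches(model, X_test, batch_size=5000):
    """Call model.predict() on consecutive row slices and join the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    predictions = []
    for start in range(0, len(X_test), batch_size):
        batch = X_test[start:start + batch_size]
        predictions.extend(model.predict(batch))
    return predictions


class EchoModel:
    """Stand-in for a fitted estimator; records the batch sizes it receives."""

    def __init__(self):
        self.batch_sizes = []

    def predict(self, X):
        self.batch_sizes.append(len(X))
        return [row[0] for row in X]


model = EchoModel()
rows = [[i] for i in range(12)]
preds = predict_in_batches(model, rows, batch_size=5)
print(model.batch_sizes)  # [5, 5, 2]
```

With a real estimator, a batch size in the 1,000 to 10,000 range keeps the number of .predict() calls small while bounding peak memory.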
The fitted-model cache stores the model during .fit(), making subsequent .predict() calls fast by using a KV-Cache. Enable it by setting the fit_mode parameter of TabPFNClassifier or TabPFNRegressor to fit_with_cache. For TabPFN-3, this setting consumes about 7GB of GPU memory per estimator (so 56GB for our default 8 estimators) for 1M rows (and roughly half for 500K rows). For previous model versions, this setting consumes a lot of memory and is not recommended except for very small datasets.
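For a back-of-the-envelope check of the TabPFN-3 figures, the sketch below scales the quoted 7GB per estimator at 1M rows linearly with row count. Linear scaling is an assumption extrapolated from the two data points given here, and estimate_cache_memory_gb is a hypothetical helper, so treat the result as a rough guide rather than a guarantee.

```python
def estimate_cache_memory_gb(n_rows, n_estimators=8, gb_per_estimator_at_1m=7.0):
    """Linearly scale the quoted 7 GB per estimator at 1M rows (assumption)."""
    return gb_per_estimator_at_1m * n_estimators * (n_rows / 1_000_000)


print(estimate_cache_memory_gb(1_000_000))                # 56.0
print(estimate_cache_memory_gb(500_000))                  # 28.0
print(estimate_cache_memory_gb(250_000, n_estimators=4))  # 7.0
```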
TabPFN-3 requires PyTorch 2.5 or above; we recommend 2.8 or above for best performance.
Each API request consumes usage credits; the cost grows with the number of rows and columns in your dataset. You can check your current usage at ux.priorlabs.ai/account/usage, or from Python:
from tabpfn_client.config import get_api_usage
get_api_usage()
See API metering for the full breakdown of token pools and per-model limits.
After making predictions with a TabPFN model using the client, you can access model-level metadata, such as the exact model version used for inference. Each estimator exposes a .last_meta attribute containing this information:
# After model.predict():
meta = model.last_meta
print(meta)
# Example output:
# {'model_version': 'v2', 'info': 'Mi4yLjB8MjAyNS0xMC0zMVQxMjowMTo0My42Mzc3ODcrMDA6MDA='}
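The info value in the example output happens to be valid Base64, so it can be inspected with the standard library. The field's format is not documented here, so treat the decoded layout as an observation about this one example rather than a stable contract.

```python
import base64

# Copied from the example output above; a real run would use model.last_meta.
meta = {
    "model_version": "v2",
    "info": "Mi4yLjB8MjAyNS0xMC0zMVQxMjowMTo0My42Mzc3ODcrMDA6MDA=",
}

decoded = base64.b64decode(meta["info"]).decode("utf-8")
print(decoded)  # 2.2.0|2025-10-31T12:01:43.637787+00:00
```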