When calling .predict() (or .predict_proba() for classifiers) on a TabPFNClassifier or TabPFNRegressor, you may see:
TabPFNCUDAOutOfMemoryError: CUDA out of memory with <n> test samples.
TabPFNMPSOutOfMemoryError: MPS out of memory with <n> test samples.
TabPFN performs in-context learning: it processes your training set alongside each test point in a single attention pass. The memory bottleneck is the attention computation, which grows with the number of training rows, test rows, and features. The appropriate workaround depends on which dimension is the primary constraint.

Large training set

Use SUBSAMPLE_SAMPLES to draw a balanced subset of training rows for each estimator. As a starting point, set SUBSAMPLE_SAMPLES=50_000 and increase n_estimators so that n_estimators × SUBSAMPLE_SAMPLES covers your full training set; samples beyond this product may not be seen by any estimator. Lower SUBSAMPLE_SAMPLES if you still encounter OOM errors.
from tabpfn import TabPFNClassifier

# X_train, y_train, and X_test are assumed to be loaded already.
model = TabPFNClassifier(
    device="auto",
    ignore_pretraining_limits=True,
    n_estimators=16,
    inference_config={
        "SUBSAMPLE_SAMPLES": 50_000,  # lower if still OOM
    },
)
model.fit(X_train, y_train)
predictions = model.predict_proba(X_test)
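
The coverage rule above (n_estimators × SUBSAMPLE_SAMPLES should reach the size of the training set) can be turned into a small calculation. The helper below is a hypothetical illustration, not part of the TabPFN API, and it only gives the smallest n_estimators that satisfies the product rule. How the subsets overlap in practice is not specified here, so treat the result as a lower bound.

```python
import math


def estimators_to_cover(n_train_rows: int, subsample_samples: int) -> int:
    """Smallest n_estimators with n_estimators * subsample_samples >= n_train_rows.

    Hypothetical helper for illustration; not a TabPFN function.
    """
    if n_train_rows <= 0 or subsample_samples <= 0:
        raise ValueError("both arguments must be positive")
    return math.ceil(n_train_rows / subsample_samples)


# 800,000 training rows with 50,000 rows per estimator
print(estimators_to_cover(800_000, 50_000))  # 16
# 120,000 training rows need a third estimator for the remaining 20,000 rows
print(estimators_to_cover(120_000, 50_000))  # 3
```

If you lower SUBSAMPLE_SAMPLES to escape an OOM error, rerun the calculation, since the required n_estimators grows as the subsample shrinks.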

Large test set

Batch your test set to avoid loading too many test rows into memory at once. Start with CHUNK_SIZE = 1000 and increase if memory allows.
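import numpy as np

# `model` is a fitted TabPFNClassifier; X_test holds the rows to predict.
predictions = []
CHUNK_SIZE = 1000
X_test_arr = np.asarray(X_test)

# Predict one slice of test rows at a time to cap peak memory.
for i in range(0, len(X_test_arr), CHUNK_SIZE):
    chunk_preds = model.predict_proba(X_test_arr[i : i + CHUNK_SIZE])
    predictions.append(chunk_preds)

# Stack the per-chunk probability arrays back into one (n_test, n_classes) array.
predictions = np.vstack(predictions)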
By default, the transformer re-encodes the training context on every .predict() call, so batching reduces peak memory usage but increases total prediction time. Setting fit_mode="fit_with_cache" caches the transformer's key-value (KV) cache after .fit(), which skips this re-encoding on subsequent .predict() calls.
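
The batching loop can also be wrapped in a reusable function. The sketch below is one possible shape for such a helper: predict_in_chunks is a hypothetical name, and predict_fn stands for any callable such as model.predict_proba or model.predict. It joins chunks with np.concatenate along axis 0, which handles both 2-D probability outputs and 1-D regression outputs. The demo uses stand-in functions so that it runs without a GPU or a fitted model.

```python
import numpy as np


def predict_in_chunks(predict_fn, X, chunk_size=1000):
    """Apply predict_fn to X in row chunks and join the results along axis 0.

    Hypothetical helper for illustration; predict_fn could be
    model.predict_proba or model.predict.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    X_arr = np.asarray(X)
    outputs = [
        predict_fn(X_arr[i : i + chunk_size])
        for i in range(0, len(X_arr), chunk_size)
    ]
    return np.concatenate(outputs, axis=0)


# Stand-ins for a fitted model, used only to demonstrate the output shapes.
def fake_predict_proba(X):
    return np.tile([0.25, 0.75], (len(X), 1))  # 2-D, like a classifier


def fake_predict(X):
    return X.sum(axis=1)  # 1-D, like a regressor


X_demo = np.ones((2500, 4))
print(predict_in_chunks(fake_predict_proba, X_demo).shape)  # (2500, 2)
print(predict_in_chunks(fake_predict, X_demo).shape)  # (2500,)
```

With a real model, the call would look like predict_in_chunks(model.predict_proba, X_test, chunk_size=1000).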

Large training set and test set

Combine both approaches: subsample the training context and batch the test set.
import numpy as np
from tabpfn import TabPFNClassifier

model = TabPFNClassifier(
    device="auto",
    ignore_pretraining_limits=True,
    n_estimators=16,
    fit_mode="fit_with_cache",
    inference_config={
        "SUBSAMPLE_SAMPLES": 50_000,  # lower if still OOM
    },
)
model.fit(X_train, y_train)

predictions = []
CHUNK_SIZE = 1000
X_test_arr = np.asarray(X_test)

for i in range(0, len(X_test_arr), CHUNK_SIZE):
    chunk_preds = model.predict_proba(X_test_arr[i : i + CHUNK_SIZE])
    predictions.append(chunk_preds)

predictions = np.vstack(predictions)
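
When a suitable CHUNK_SIZE is hard to guess, one option is to halve it whenever a chunk runs out of memory and retry. The following is a minimal sketch of that idea, not a TabPFN feature: predict_with_backoff and its parameters are hypothetical, and the exception classes to catch are passed in by the caller because the import path of the TabPFN out-of-memory errors is not given on this page. The demo simulates OOM with Python's built-in MemoryError.

```python
import numpy as np


def predict_with_backoff(predict_fn, X, chunk_size=1000, min_chunk_size=1,
                         oom_errors=(MemoryError,)):
    """Predict in chunks, halving chunk_size each time an OOM error is raised.

    Hypothetical helper. Pass the out-of-memory exception classes your
    setup raises via oom_errors (assumption: they can be caught like
    ordinary Python exceptions).
    Returns the joined predictions and the chunk size that finally worked.
    """
    X_arr = np.asarray(X)
    outputs = []
    start = 0
    while start < len(X_arr):
        try:
            outputs.append(predict_fn(X_arr[start : start + chunk_size]))
        except oom_errors:
            if chunk_size <= min_chunk_size:
                raise  # cannot shrink further, surface the error
            chunk_size = max(min_chunk_size, chunk_size // 2)
            continue  # retry the same rows with a smaller chunk
        start += chunk_size
    return np.concatenate(outputs, axis=0), chunk_size


# Stand-in predictor that "runs out of memory" above 300 rows per call.
def fake_predict_proba(X):
    if len(X) > 300:
        raise MemoryError("simulated OOM")
    return np.tile([0.5, 0.5], (len(X), 1))


preds, used = predict_with_backoff(fake_predict_proba, np.zeros((1000, 3)))
print(preds.shape, used)  # (1000, 2) 250
```

The returned chunk size can be reused as the starting value for later calls on the same hardware and training set.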

Reduce memory footprint

If the above doesn’t resolve the error, try enabling memory_saving_mode. This trades compute for lower peak memory usage:
from tabpfn import TabPFNClassifier

model = TabPFNClassifier(
    device="auto",
    memory_saving_mode=True,  # lower peak memory at the cost of extra compute
)

Hardware and memory requirements

The table below shows rough VRAM estimates for TabPFN v3 on a T4 GPU with 1 estimator, 200 features, 1,000 test rows, and no subsampling. Note that these are not the default settings. Actual usage scales with the number of features, training rows, and test rows.

Training rows    Test rows    Approximate VRAM
1,000            1,000        ~2 GB
5,000            1,000        ~2 GB
10,000           1,000        ~2 GB
50,000           1,000        ~2.5 GB
100,000          1,000        ~3 GB
200,000          1,000        ~4 GB
500,000          1,000        ~10 GB

Use the workarounds above to reduce memory usage. If none of them resolves the error, your GPU likely does not have enough VRAM for your dataset size. If you cannot get access to more powerful hardware, consider the TabPFN API Client, which requires no local GPU.
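
For training-set sizes that fall between the rows of the table, a rough figure can be read off by interpolation. The sketch below does this with np.interp. Linear interpolation between table rows is an assumption made here for illustration, and the numbers apply only to the configuration the table was measured with (1 estimator, 200 features, 1,000 test rows, no subsampling, T4 GPU). The function name estimate_vram_gb is hypothetical.

```python
import numpy as np

# Values copied from the table above (1 estimator, 200 features, 1,000 test rows).
TRAIN_ROWS = np.array([1_000, 5_000, 10_000, 50_000, 100_000, 200_000, 500_000])
APPROX_VRAM_GB = np.array([2.0, 2.0, 2.0, 2.5, 3.0, 4.0, 10.0])


def estimate_vram_gb(n_train_rows: int) -> float:
    """Linearly interpolate the table; refuses to extrapolate outside it.

    Assumption: usage between table rows is roughly linear. This is a
    reading aid, not a measured guarantee.
    """
    if not TRAIN_ROWS[0] <= n_train_rows <= TRAIN_ROWS[-1]:
        raise ValueError("n_train_rows is outside the range covered by the table")
    return float(np.interp(n_train_rows, TRAIN_ROWS, APPROX_VRAM_GB))


print(f"{estimate_vram_gb(100_000):.1f} GB")  # 3.0 GB
print(f"{estimate_vram_gb(150_000):.1f} GB")  # 3.5 GB
```

Compare the estimate against the memory actually free on your device, and leave headroom, since other features counts and test-set sizes shift the numbers.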
