TabPFNEmbedding is a scikit-learn style transformer with the familiar fit / fit_transform / transform API. It supports two extraction modes:
- Out-of-fold embeddings (n_fold >= 2, recommended): robust, leakage-free training-set embeddings extracted via K-fold cross-validation. These generalize better and give stronger downstream performance.
- Vanilla embeddings (n_fold=0): a single model is trained on the full dataset and used for everything; cheaper, but the training embeddings leak label information.
Getting Started
The embedding module ships in the base tabpfn-extensions package (no extra needed). Install it alongside the local tabpfn engine, for example with pip install tabpfn tabpfn-extensions.
The Interface
TabPFNEmbedding follows the scikit-learn transformer contract, and the method you call depends on whose embeddings you want:
| Method | Use for | Returns |
|---|---|---|
| fit_transform(X_train, y_train) | the training data | Out-of-fold embeddings when n_fold >= 2 (no label leakage); full-model embeddings when n_fold == 0 |
| transform(X) | unseen / held-out data | Embeddings from the model trained on the full training set |
transform always runs through the final full-data model. It never returns cached training embeddings, even if X happens to equal the training set, so for leakage-free training embeddings always use fit_transform (or read the train_embeddings_ attribute after fit).

The TabPFN estimator used for embedding is passed through the model= parameter. Use a classifier or regressor depending on your task; the examples below show both.
1. Out-of-fold (robust) embeddings — recommended
With n_fold >= 2, the training data is split into K folds. For each fold, a fresh model is trained on the other K - 1 folds and used to embed the held-out fold. The out-of-fold (OOF) embeddings are reassembled into the original sample order, and a final model is refit on the full training set to embed unseen data via transform. This is the robust variant introduced in “A Closer Look at TabPFN v2: Strength, Limitation, and Extension” (arXiv:2502.17361). Larger n_fold values yield more robust embeddings, at the cost of K + 1 model fits in total.
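The K-fold procedure can be sketched in plain NumPy and scikit-learn. This is not TabPFN's implementation: CentroidEmbedder is a hypothetical stand-in for a fitted TabPFN model (it embeds rows as distances to class centroids), oof_embeddings is an illustrative helper, and the toy returns a 2D array rather than the per-ensemble-member 3D output. What it does show is the mechanics: a fresh model per fold, held-out embeddings written back into the original row order, and a final full-data model reserved for unseen data.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


class CentroidEmbedder:
    """Hypothetical stand-in for a fitted TabPFN model (NOT TabPFN itself):
    embeds each row as its distances to the class centroids seen at fit time."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def embed(self, X):
        # (n_samples, n_classes): Euclidean distance from each row to each centroid.
        return np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)


def oof_embeddings(X, y, n_fold=5, shuffle=True, random_state=0):
    """Hypothetical helper: out-of-fold training embeddings plus a full-data model."""
    splitter = StratifiedKFold(
        n_splits=n_fold, shuffle=shuffle, random_state=random_state if shuffle else None
    )
    train_emb = None
    for fit_idx, held_idx in splitter.split(X, y):
        # Fresh model per fold, trained on the other K - 1 folds only.
        fold_model = CentroidEmbedder().fit(X[fit_idx], y[fit_idx])
        emb = fold_model.embed(X[held_idx])
        if train_emb is None:
            train_emb = np.full((len(X), emb.shape[1]), np.nan)
        # Write held-out embeddings back into the original row order.
        train_emb[held_idx] = emb
    # Final model refit on all training rows; this is what embeds unseen data.
    final_model = CentroidEmbedder().fit(X, y)
    return train_emb, final_model


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 4))
y = np.repeat([0, 1, 2], 20)
X[y == 1] += 3.0  # shift class 1 so the centroids differ

train_emb, final_model = oof_embeddings(X, y, n_fold=5)  # plays the role of fit_transform
test_emb = final_model.embed(rng.normal(size=(10, 4)))   # plays the role of transform
print(train_emb.shape, test_emb.shape)  # (60, 3) (10, 3)
```

Note that the OOF training embeddings differ from what the final model would produce for the same rows, which is exactly why transform(X_train) is not a substitute for fit_transform.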
In practice this lifts downstream performance: the get_embeddings.py example compares a baseline linear model, vanilla TabPFN embeddings, and K-fold embeddings on the same data — the K-fold embeddings come out ahead for both classification accuracy and regression R².
Classifiers use StratifiedKFold and regressors use KFold. Set shuffle=True (with an optional random_state) to shuffle the split. n_fold=1 is invalid — use 0 for vanilla or >= 2 for cross-validation.
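These split rules can be summarized as a small helper. make_splitter is hypothetical, a sketch of the documented behavior rather than TabPFNEmbedding's internal code. One assumption worth flagging: recent scikit-learn versions raise a ValueError when random_state is set on KFold or StratifiedKFold while shuffle=False, so the sketch forwards the seed only when shuffling, consistent with the parameter table describing random_state as used when shuffle=True.

```python
from sklearn.model_selection import KFold, StratifiedKFold


def make_splitter(n_fold, is_classifier, shuffle=False, random_state=None):
    """Hypothetical helper mirroring the documented rules; None means vanilla mode."""
    if n_fold == 0:
        return None  # vanilla: one full-data model, no cross-validation
    if n_fold < 2:
        raise ValueError(f"n_fold must be 0 (vanilla) or >= 2 (K-fold), got {n_fold}")
    # Recent scikit-learn versions reject a random_state when shuffle=False,
    # so the seed is only forwarded when shuffling.
    seed = random_state if shuffle else None
    splitter_cls = StratifiedKFold if is_classifier else KFold
    return splitter_cls(n_splits=n_fold, shuffle=shuffle, random_state=seed)


print(make_splitter(5, is_classifier=True, shuffle=True, random_state=42))
print(make_splitter(3, is_classifier=False))
```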
2. Vanilla embeddings
With n_fold=0, a single model is trained on the entire training set and reused for both training and unseen data. This is cheaper (one fit instead of K + 1) and fine when you only need embeddings for unseen data via transform, but avoid it for training-set embeddings you plan to feed into a downstream model (see the leakage caveat above).
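A deliberately extreme toy makes the leakage concrete. OneNNEmbedder is hypothetical and is not how TabPFN embeds data: it embeds each row as the one-hot label of its nearest training row. The labels here are random noise, so no honest feature can predict them. Vanilla training embeddings still recover every label, because each training row is its own nearest neighbor, while out-of-fold embeddings should land near chance. TabPFN leaks far less blatantly, but the underlying mechanism (a model embedding rows whose labels it has already seen) is the same.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


class OneNNEmbedder:
    """Hypothetical, deliberately leaky stand-in (NOT TabPFN): embeds each row as
    the one-hot label of its nearest training row."""

    def __init__(self, n_classes):
        self.n_classes = n_classes

    def fit(self, X, y):
        self.X_, self.y_ = X, y
        return self

    def embed(self, X):
        # Squared Euclidean distance from every query row to every training row.
        d = ((X[:, None, :] - self.X_[None, :, :]) ** 2).sum(axis=2)
        return np.eye(self.n_classes)[self.y_[d.argmin(axis=1)]]


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = rng.integers(0, 2, size=200)  # pure-noise labels: nothing honest can predict them

# Vanilla (n_fold=0): the model embedding the training rows has already seen them.
vanilla = OneNNEmbedder(n_classes=2).fit(X, y).embed(X)

# Out-of-fold: each training row is embedded by a model that never saw it.
oof = np.zeros_like(vanilla)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for fit_idx, held_idx in cv.split(X, y):
    oof[held_idx] = OneNNEmbedder(n_classes=2).fit(X[fit_idx], y[fit_idx]).embed(X[held_idx])

vanilla_acc = (vanilla.argmax(axis=1) == y).mean()
oof_acc = (oof.argmax(axis=1) == y).mean()
print(f"vanilla embeddings recover y: {vanilla_acc:.2f}")  # 1.00: each row is its own neighbor
print(f"out-of-fold embeddings recover y: {oof_acc:.2f}")
```

A downstream model trained on the vanilla embeddings would fit the training labels perfectly and learn nothing that transfers to unseen data.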
Output shape. Both fit_transform and transform return a 3D array of shape (n_estimators, n_samples, embed_dim), one embedding matrix per ensemble member. This is not a drop-in 2D input for an sklearn Pipeline: select a single member (embeddings[0]) or aggregate across axis=0 before passing the result to a downstream estimator.

Using Embeddings as Features
A common pattern is to use TabPFN embeddings as features for a lightweight downstream model. Because the embeddings are 3D, select an ensemble member (embeddings[0]) to get a 2D feature matrix.
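A sketch of the plumbing, with random arrays standing in for real embeddings: swap in the outputs of fit_transform (training data) and transform (unseen data). The array sizes, LogisticRegression, and Ridge are illustrative choices, not recommendations from the library. Whichever reduction you pick, apply the same one to training and unseen embeddings so their feature columns line up.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression, Ridge

# Hypothetical stand-ins shaped like TabPFNEmbedding output:
# (n_estimators, n_samples, embed_dim). Replace with real fit_transform / transform results.
rng = np.random.default_rng(0)
n_estimators, n_train, n_test, embed_dim = 4, 100, 20, 16
train_emb = rng.normal(size=(n_estimators, n_train, embed_dim))
test_emb = rng.normal(size=(n_estimators, n_test, embed_dim))

# Option 1: pick a single ensemble member -> 2D (n_samples, embed_dim).
X_train_2d, X_test_2d = train_emb[0], test_emb[0]
# Option 2: average across ensemble members (axis=0) -> also 2D.
X_train_avg, X_test_avg = train_emb.mean(axis=0), test_emb.mean(axis=0)

# Classification: logistic regression on a single member's embeddings.
y_cls = rng.integers(0, 2, size=n_train)
clf = LogisticRegression(max_iter=1000).fit(X_train_2d, y_cls)
cls_pred = clf.predict(X_test_2d)

# Regression: ridge regression on the member-averaged embeddings.
y_reg = rng.normal(size=n_train)
reg = Ridge(alpha=1.0).fit(X_train_avg, y_reg)
reg_pred = reg.predict(X_test_avg)

print(X_train_2d.shape, X_train_avg.shape, cls_pred.shape, reg_pred.shape)
```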
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| n_fold | int | 0 | 0 disables CV (vanilla). >= 2 enables K-fold out-of-fold embeddings. 1 is invalid. |
| model | TabPFNClassifier \| TabPFNRegressor \| None | None | Pre-configured TabPFN estimator. When None, the task is inferred from y at fit time (with a warning). Passing it explicitly is recommended. |
| shuffle | bool | False | Whether to shuffle the K-fold split. |
| random_state | int \| None | None | Seed used by the K-fold split when shuffle=True. |
Fitted attributes: after fit, the estimator exposes model_ (the fitted full-data model) and train_embeddings_ (the training-set embeddings, out-of-fold when n_fold >= 2).
Migration. The old get_embeddings(X_train, y_train, X, data_source=...) method and the tabpfn_clf / tabpfn_reg constructor arguments are deprecated. Use model= together with fit_transform (training data; out-of-fold when n_fold >= 2) and transform (unseen data) instead.

Example Script
Full runnable example for classification and regression.
Google Colab Example
Check out our Google Colab for a demo.