Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,55 @@ and login (on another machine) using your access token, skipping the interactive
tabpfn_client.set_access_token(token)
```

## AWS SageMaker (BYOC)

If you've subscribed to the TabPFN AWS Marketplace listing and deployed the container to a SageMaker real-time endpoint, you can invoke it through `tabpfn_client.sagemaker` using a near-identical scikit-learn surface. There is no PriorLabs API token in this path — you authenticate to your own AWS account, and `predict` calls are billed by AWS SageMaker rather than against your TabPFN usage allowance.

Install with the optional `sagemaker` extra to pull in `boto3`:

```bash
pip install --upgrade 'tabpfn-client[sagemaker]'
```

Then point the estimator at your endpoint:

```python
from tabpfn_client.sagemaker import TabPFNClassifier, TabPFNRegressor
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

clf = TabPFNClassifier(
endpoint_name="your-sagemaker-endpoint-name",
region_name="us-east-1",
)
clf.fit(X_train, y_train)
clf.predict(X_test)
clf.predict_proba(X_test)
```

Notes:

- AWS credentials are resolved through the standard `boto3` credential chain (env vars, `~/.aws/credentials`, instance profile, SSO, etc.). Pass `boto_session=session` to use an explicit `boto3.Session`.
- `fit()` is local: TabPFN is in-context, so the estimator just keeps `X_train` / `y_train` and ships them on each `predict*` call. There is no separate training job.
- Set `use_kv_cache=True` to opt into the v3 KV-cache path on the server: the first round-trip uploads training data and captures a `model_id`; subsequent `predict*` calls reference that id and skip the training upload. This trades a stateful endpoint for lower per-call latency and payload size.
- Constructor kwargs mirror the public `tabpfn_client.TabPFNClassifier` / `TabPFNRegressor` so the same code is portable between the managed API and a SageMaker endpoint, modulo `endpoint_name` / `region_name`.

Thinking mode is supported on SageMaker by passing the same `thinking_mode` / `thinking_effort` / `thinking_timeout_s` / `thinking_metric` kwargs:

```python
clf = TabPFNClassifier(
endpoint_name="your-sagemaker-endpoint-name",
region_name="us-east-1",
thinking_mode=True,
thinking_effort="medium",
)
```

The first `predict*` call after `fit()` runs the autogluon-wrapped fit on the endpoint and can take from tens of seconds up to several minutes depending on `thinking_effort` and data size; the fitted model is cached on the endpoint and subsequent calls are fast. Caching is **required** when thinking is enabled (the client sets `use_kv_cache=True` automatically) — without it every prediction would redo the fit, which would exceed SageMaker's synchronous invoke window. Only `thinking_effort="medium"` is reliable within the real-time endpoint's ~60 s sync window for the *first* call; `"high"` may exceed it and is currently best-effort.

## Join Our Community

We're building the future of tabular machine learning and would love your involvement! Here's how you can participate and get help:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ dependencies = [
"pyarrow>=14.0.0,<=23.0.1",
]

[project.optional-dependencies]
sagemaker = ["boto3>=1.34"]

[project.urls]
documentation = "https://priorlabs.ai/docs"
source = "https://github.com/priorlabs/tabpfn-client"
Expand Down
16 changes: 16 additions & 0 deletions src/tabpfn_client/sagemaker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Prior Labs GmbH 2025.
# Licensed under the Apache License, Version 2.0
"""SageMaker BYOC client for TabPFN.

Users subscribed to the TabPFN AWS Marketplace listing deploy the BYOC
container to a SageMaker real-time endpoint and invoke it through this
submodule. The wire protocol is the inline JSON form accepted by the
container's /invocations route; auth is the standard boto3 credential chain.

from tabpfn_client.sagemaker import TabPFNClassifier, TabPFNRegressor
"""

from tabpfn_client.sagemaker.estimator import TabPFNClassifier, TabPFNRegressor


__all__ = ["TabPFNClassifier", "TabPFNRegressor"]
288 changes: 288 additions & 0 deletions src/tabpfn_client/sagemaker/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
# Copyright (c) Prior Labs GmbH 2025.
# Licensed under the Apache License, Version 2.0
"""scikit-learn estimators that invoke a TabPFN SageMaker BYOC endpoint.

The endpoint is the container defined in `dists/marketplaces/aws` in the
`tabpfn-server` repo. It accepts a single inline JSON body at POST
/invocations matching `prior.predictor.requests.PredictRequest`, and returns
`{"prediction": ..., "metadata": ..., "model_id": ...}`. The estimators here
build that body from the sklearn-style call surface, dispatch via
`boto3.client("sagemaker-runtime").invoke_endpoint`, and return the
prediction as a numpy array.

`fit()` is local-only: TabPFN is in-context, so we just keep X/y around and
ship them with each `predict*` call. The optional `use_kv_cache=True` path
opts into the server's V3 FIT_WITH_CACHE mode — the first round-trip uploads
training data and captures a `model_id`; subsequent predicts skip the upload
and reference that id.
"""

from __future__ import annotations

import json
from typing import Any, Dict, Literal, Optional

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted

try:
import boto3 # type: ignore[import-untyped]

_BOTO3_AVAILABLE = True
except ImportError: # pragma: no cover
boto3 = None # type: ignore[assignment]
_BOTO3_AVAILABLE = False


ThinkingEffort = Literal["medium", "high"]


def _require_boto3() -> None:
if not _BOTO3_AVAILABLE:
raise ImportError(
"boto3 is required for tabpfn_client.sagemaker. "
"Install with: pip install 'tabpfn-client[sagemaker]'"
)


def _to_jsonable(X: Any) -> list:
"""Coerce numpy / pandas inputs to plain Python lists for JSON."""
if isinstance(X, pd.DataFrame):
return X.values.tolist()
if isinstance(X, pd.Series):
return X.tolist()
return np.asarray(X).tolist()


class _SagemakerBase(BaseEstimator):
"""Shared invoke_endpoint plumbing for SageMaker TabPFN estimators.

Subclasses set `_TASK`. Constructor kwargs mirror the public
`tabpfn_client.TabPFNClassifier` so user code is portable; everything but
the SageMaker-specific bits is forwarded into `task_config.tabpfn_config`
on the wire. `model_path` is currently dropped server-side (the active
checkpoint is whatever was baked into the model artifact); we keep it on
the constructor for API parity.
"""

_TASK: str = "" # overridden by subclasses

def __init__(
self,
endpoint_name: str,
region_name: Optional[str] = None,
boto_session: Optional[Any] = None,
model_path: str = "auto",
n_estimators: int = 8,
softmax_temperature: float = 0.9,
balance_probabilities: bool = False,
average_before_softmax: bool = False,
ignore_pretraining_limits: bool = True,
inference_precision: Literal["autocast", "auto"] = "auto",
random_state: Optional[int] = 0,
inference_config: Optional[Dict[str, Any]] = None,
paper_version: bool = False,
thinking_mode: bool = False,
thinking_effort: Optional[ThinkingEffort] = None,
thinking_timeout_s: Optional[float] = None,
thinking_metric: Optional[str] = None,
use_kv_cache: bool = False,
):
self.endpoint_name = endpoint_name
self.region_name = region_name
self.boto_session = boto_session
self.model_path = model_path
self.n_estimators = n_estimators
self.softmax_temperature = softmax_temperature
self.balance_probabilities = balance_probabilities
self.average_before_softmax = average_before_softmax
self.ignore_pretraining_limits = ignore_pretraining_limits
self.inference_precision = inference_precision
self.random_state = random_state
self.inference_config = inference_config
self.paper_version = paper_version
self.thinking_mode = thinking_mode
self.thinking_effort = thinking_effort
self.thinking_timeout_s = thinking_timeout_s
self.thinking_metric = thinking_metric
self.use_kv_cache = use_kv_cache

@property
def _thinking_active(self) -> bool:
return self.thinking_mode or self.thinking_effort is not None

@property
def _effective_use_kv_cache(self) -> bool:
# Thinking implies caching: without it every predict redoes the
# autogluon HPO sweep. The server enforces this too; we mirror it
# client-side so the wire body always agrees.
return self.use_kv_cache or self._thinking_active

def _build_tabpfn_config(self) -> Dict[str, Any]:
cfg: Dict[str, Any] = {
"n_estimators": self.n_estimators,
"softmax_temperature": self.softmax_temperature,
"average_before_softmax": self.average_before_softmax,
"ignore_pretraining_limits": self.ignore_pretraining_limits,
"inference_precision": self.inference_precision,
"random_state": self.random_state,
"inference_config": self.inference_config,
"fit_mode": "fit_with_cache" if self._effective_use_kv_cache else "fit_preprocessors",
}
# paper_version is OSS TabPFN-only — server's tagged-union forbids
# extras. balance_probabilities lives only on ClassifierTabPFNConfig.
# thinking_* fields live at the top of the body, not under
# tabpfn_config — see `_invoke`.
if self._TASK == "classification":
cfg["balance_probabilities"] = self.balance_probabilities
if self.model_path not in ("auto", "default"):
cfg["model_path"] = self.model_path
return cfg

def _build_thinking_block(self) -> Dict[str, Any]:
"""Top-level wire fields for thinking-mode. Sibling of
`task_config`, mirroring gapi's `FitRequest` shape. Empty when
thinking isn't active so callers can splat unconditionally."""
if not self._thinking_active:
return {}
block: Dict[str, Any] = {
"thinking_effort": self.thinking_effort if self.thinking_effort is not None else "medium",
}
if self.thinking_timeout_s is not None:
block["thinking_timeout_s"] = self.thinking_timeout_s
if self.thinking_metric is not None:
block["thinking_metric"] = self.thinking_metric
return block

def _runtime_client(self):
# Cached on the instance: boto3 service-model load + credential
# resolution is non-trivial and we don't want it on every predict.
client = getattr(self, "_cached_client", None)
if client is not None:
return client
_require_boto3()
if self.boto_session is not None:
client = self.boto_session.client("sagemaker-runtime")
elif self.region_name is not None:
client = boto3.client("sagemaker-runtime", region_name=self.region_name)
else:
client = boto3.client("sagemaker-runtime")
self._cached_client = client
return client

def __getstate__(self) -> Dict[str, Any]:
# boto3 clients aren't pickleable. Strip the cache so the estimator
# stays compatible with sklearn's pickle-based parallel/grid utilities.
state = self.__dict__.copy()
state.pop("_cached_client", None)
return state

def fit(self, X: Any, y: Any) -> "_SagemakerBase":
# X must be 2D; only DataFrame/array. y can be 1D (Series/array) or 2D.
X_arr = X if isinstance(X, pd.DataFrame) else np.asarray(X)
y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)
if X_arr.shape[0] != y_arr.shape[0]:
raise ValueError(
f"X and y must have the same number of samples; "
f"got X={X_arr.shape}, y={y_arr.shape}"
)
self.X_train_ = X_arr
self.y_train_ = y_arr
self._cached_model_id: Optional[str] = None
if self._TASK == "classification":
self.classes_ = np.unique(y_arr)
return self
Comment on lines +182 to +196
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In scikit-learn, classifiers are expected to expose a classes_ attribute after being fitted. Currently, TabPFNClassifier does not set self.classes_ during fit(), which violates the scikit-learn estimator contract and will cause errors in downstream utilities (e.g., classification reports or evaluation metrics). Additionally, X is allowed to be a 1D pd.Series, which is incorrect since the feature matrix X must always be 2D.\n\nThis suggestion corrects the type check for X to exclude pd.Series and ensures self.classes_ is populated for classification tasks.

    def fit(self, X: Any, y: Any) -> "_SagemakerBase":\n        X_arr = X if isinstance(X, pd.DataFrame) else np.asarray(X)\n        y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)\n        if X_arr.shape[0] != y_arr.shape[0]:\n            raise ValueError(\n                f"X and y must have the same number of samples; "\n                f"got X={X_arr.shape}, y={y_arr.shape}"\n            )\n        self.X_train_ = X_arr\n        self.y_train_ = y_arr\n        self._cached_model_id = None\n        if self._TASK == "classification":\n            self.classes_ = np.unique(y_arr)\n        return self


def _invoke(
self,
X_test: Any,
output_type: str,
predict_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
check_is_fitted(self, ["X_train_", "y_train_"])
params: Dict[str, Any] = {"output_type": output_type}
if predict_params:
params.update(predict_params)
body: Dict[str, Any] = {
"task_config": {
"task": self._TASK,
"tabpfn_config": self._build_tabpfn_config(),
"predict_params": params,
},
"X_test": _to_jsonable(X_test),
**self._build_thinking_block(),
}
if self._effective_use_kv_cache and self._cached_model_id is not None:
body["context"] = {"model_id": self._cached_model_id}
else:
body["X_train"] = _to_jsonable(self.X_train_)
# `y_train` on the wire is 2D (n_samples, 1) per PredictRequest.
# np.asarray handles pd.Series too, so the single-path form covers
# ndarray / list / DataFrame / Series uniformly.
y_arr = np.asarray(self.y_train_)
if y_arr.ndim == 1:
y_arr = y_arr.reshape(-1, 1)
body["y_train"] = y_arr.tolist()
resp = self._runtime_client().invoke_endpoint(
EndpointName=self.endpoint_name,
ContentType="application/json",
Accept="application/json",
Body=json.dumps(body).encode("utf-8"),
)
# StreamingBody is file-like; json.load avoids buffering the full
# response in memory, which matters for output_type="full".
payload = json.load(resp["Body"])
if self._effective_use_kv_cache:
self._cached_model_id = payload.get("model_id") or self._cached_model_id
return payload


class TabPFNClassifier(_SagemakerBase, ClassifierMixin):
"""TabPFN classifier backed by a SageMaker real-time endpoint.

Example:
from tabpfn_client.sagemaker import TabPFNClassifier
clf = TabPFNClassifier(
endpoint_name="tabpfn-sm-alpha-v3-thinking-001",
region_name="us-east-1",
)
clf.fit(X_train, y_train)
clf.predict(X_test)
clf.predict_proba(X_test)
"""

_TASK = "classification"

def predict(self, X: Any) -> np.ndarray:
result = self._invoke(X, output_type="preds")
return np.asarray(result["prediction"])

def predict_proba(self, X: Any) -> np.ndarray:
result = self._invoke(X, output_type="probas")
return np.asarray(result["prediction"])


class TabPFNRegressor(_SagemakerBase, RegressorMixin):
"""TabPFN regressor backed by a SageMaker real-time endpoint.

`output_type` defaults to "mean"; pass "median", "mode", "full", or
"quantiles" for the alternative distributional outputs the server
exposes. When `output_type="quantiles"`, `quantiles` selects the cut
points (each in [0, 1]).
"""

_TASK = "regression"

def predict(
self,
X: Any,
output_type: str = "mean",
quantiles: Optional[list] = None,
) -> np.ndarray:
predict_params: Dict[str, Any] = {}
if quantiles is not None:
predict_params["quantiles"] = quantiles
result = self._invoke(X, output_type=output_type, predict_params=predict_params)
return np.asarray(result["prediction"])
Loading