From 828a6b68c716f57c0e469cccbab58873ace5850d Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Fri, 29 May 2026 11:11:49 +0200 Subject: [PATCH 1/4] Add SageMaker BYOC client (tabpfn_client.sagemaker) --- README.md | 36 ++++ pyproject.toml | 3 + src/tabpfn_client/sagemaker/__init__.py | 16 ++ src/tabpfn_client/sagemaker/estimator.py | 231 +++++++++++++++++++++++ 4 files changed, 286 insertions(+) create mode 100644 src/tabpfn_client/sagemaker/__init__.py create mode 100644 src/tabpfn_client/sagemaker/estimator.py diff --git a/README.md b/README.md index 0241f83..65139fe 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,42 @@ and login (on another machine) using your access token, skipping the interactive tabpfn_client.set_access_token(token) ``` +## AWS SageMaker (BYOC) + +If you've subscribed to the TabPFN AWS Marketplace listing and deployed the container to a SageMaker real-time endpoint, you can invoke it through `tabpfn_client.sagemaker` using a near-identical scikit-learn surface. There is no PriorLabs API token in this path — you authenticate to your own AWS account, and `predict` calls are billed by AWS SageMaker rather than against your TabPFN usage allowance. + +Install with the optional `sagemaker` extra to pull in `boto3`: + +```bash +pip install --upgrade 'tabpfn-client[sagemaker]' +``` + +Then point the estimator at your endpoint: + +```python +from tabpfn_client.sagemaker import TabPFNClassifier, TabPFNRegressor +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split + +X, y = load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42) + +clf = TabPFNClassifier( + endpoint_name="your-sagemaker-endpoint-name", + region_name="us-east-1", +) +clf.fit(X_train, y_train) +clf.predict(X_test) +clf.predict_proba(X_test) +``` + +Notes: + +- AWS credentials are resolved through the standard `boto3` credential chain (env vars, `~/.aws/credentials`, instance profile, SSO, etc.). Pass `boto_session=session` to use an explicit `boto3.Session`. +- `fit()` is local: TabPFN is in-context, so the estimator just keeps `X_train` / `y_train` and ships them on each `predict*` call. There is no separate training job. +- Set `use_kv_cache=True` to opt into the v3 KV-cache path on the server: the first round-trip uploads training data and captures a `model_id`; subsequent `predict*` calls reference that id and skip the training upload. This trades a stateful endpoint for lower per-call latency and payload size. +- Constructor kwargs mirror the public `tabpfn_client.TabPFNClassifier` / `TabPFNRegressor` so the same code is portable between the managed API and a SageMaker endpoint, modulo `endpoint_name` / `region_name`. + ## Join Our Community We're building the future of tabular machine learning and would love your involvement! Here's how you can participate and get help: diff --git a/pyproject.toml b/pyproject.toml index 73bf878..b3c39d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dependencies = [ "pyarrow>=14.0.0,<=23.0.1", ] +[project.optional-dependencies] +sagemaker = ["boto3>=1.34"] + [project.urls] documentation = "https://priorlabs.ai/docs" source = "https://github.com/priorlabs/tabpfn-client" diff --git a/src/tabpfn_client/sagemaker/__init__.py b/src/tabpfn_client/sagemaker/__init__.py new file mode 100644 index 0000000..fa4df07 --- /dev/null +++ b/src/tabpfn_client/sagemaker/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Prior Labs GmbH 2025. +# Licensed under the Apache License, Version 2.0 +"""SageMaker BYOC client for TabPFN. + +Users subscribed to the TabPFN AWS Marketplace listing deploy the BYOC +container to a SageMaker real-time endpoint and invoke it through this +submodule. The wire protocol is the inline JSON form accepted by the +container's /invocations route; auth is the standard boto3 credential chain. + + from tabpfn_client.sagemaker import TabPFNClassifier, TabPFNRegressor +""" + +from tabpfn_client.sagemaker.estimator import TabPFNClassifier, TabPFNRegressor + + +__all__ = ["TabPFNClassifier", "TabPFNRegressor"] diff --git a/src/tabpfn_client/sagemaker/estimator.py b/src/tabpfn_client/sagemaker/estimator.py new file mode 100644 index 0000000..84a1d69 --- /dev/null +++ b/src/tabpfn_client/sagemaker/estimator.py @@ -0,0 +1,231 @@ +# Copyright (c) Prior Labs GmbH 2025. +# Licensed under the Apache License, Version 2.0 +"""scikit-learn estimators that invoke a TabPFN SageMaker BYOC endpoint. + +The endpoint is the container defined in `dists/marketplaces/aws` in the +`tabpfn-server` repo. It accepts a single inline JSON body at POST +/invocations matching `prior.predictor.requests.PredictRequest`, and returns +`{"prediction": ..., "metadata": ..., "model_id": ...}`. The estimators here +build that body from the sklearn-style call surface, dispatch via +`boto3.client("sagemaker-runtime").invoke_endpoint`, and return the +prediction as a numpy array. + +`fit()` is local-only: TabPFN is in-context, so we just keep X/y around and +ship them with each `predict*` call. The optional `use_kv_cache=True` path +opts into the server's V3 FIT_WITH_CACHE mode — the first round-trip uploads +training data and captures a `model_id`; subsequent predicts skip the upload +and reference that id. +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, Literal, Optional + +import numpy as np +import pandas as pd +from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin +from sklearn.utils.validation import check_is_fitted + +try: + import boto3 # type: ignore[import-untyped] + + _BOTO3_AVAILABLE = True +except ImportError: # pragma: no cover + boto3 = None # type: ignore[assignment] + _BOTO3_AVAILABLE = False + + +ThinkingEffort = Literal["medium", "high"] + + +def _require_boto3() -> None: + if not _BOTO3_AVAILABLE: + raise ImportError( + "boto3 is required for tabpfn_client.sagemaker. " + "Install with: pip install 'tabpfn-client[sagemaker]'" + ) + + +def _to_jsonable(X: Any) -> list: + """Coerce numpy / pandas inputs to plain Python lists for JSON.""" + if isinstance(X, pd.DataFrame): + return X.values.tolist() + if isinstance(X, pd.Series): + return X.tolist() + return np.asarray(X).tolist() + + +class _SagemakerBase(BaseEstimator): + """Shared invoke_endpoint plumbing for SageMaker TabPFN estimators. + + Subclasses set `_TASK`. Constructor kwargs mirror the public + `tabpfn_client.TabPFNClassifier` so user code is portable; everything but + the SageMaker-specific bits is forwarded into `task_config.tabpfn_config` + on the wire. `model_path` is currently dropped server-side (the active + checkpoint is whatever was baked into the model artifact); we keep it on + the constructor for API parity. + """ + + _TASK: str = "" # overridden by subclasses + + def __init__( + self, + endpoint_name: str, + region_name: Optional[str] = None, + boto_session: Optional[Any] = None, + model_path: str = "auto", + n_estimators: int = 8, + softmax_temperature: float = 0.9, + balance_probabilities: bool = False, + average_before_softmax: bool = False, + ignore_pretraining_limits: bool = True, + inference_precision: Literal["autocast", "auto"] = "auto", + random_state: Optional[int] = 0, + inference_config: Optional[Dict[str, Any]] = None, + paper_version: bool = False, + thinking_mode: bool = False, + thinking_effort: Optional[ThinkingEffort] = None, + thinking_timeout_s: Optional[float] = None, + thinking_metric: Optional[str] = None, + use_kv_cache: bool = False, + ): + self.endpoint_name = endpoint_name + self.region_name = region_name + self.boto_session = boto_session + self.model_path = model_path + self.n_estimators = n_estimators + self.softmax_temperature = softmax_temperature + self.balance_probabilities = balance_probabilities + self.average_before_softmax = average_before_softmax + self.ignore_pretraining_limits = ignore_pretraining_limits + self.inference_precision = inference_precision + self.random_state = random_state + self.inference_config = inference_config + self.paper_version = paper_version + self.thinking_mode = thinking_mode + self.thinking_effort = thinking_effort + self.thinking_timeout_s = thinking_timeout_s + self.thinking_metric = thinking_metric + self.use_kv_cache = use_kv_cache + + def _build_tabpfn_config(self) -> Dict[str, Any]: + cfg: Dict[str, Any] = { + "n_estimators": self.n_estimators, + "softmax_temperature": self.softmax_temperature, + "average_before_softmax": self.average_before_softmax, + "ignore_pretraining_limits": self.ignore_pretraining_limits, + "inference_precision": self.inference_precision, + "random_state": self.random_state, + "inference_config": self.inference_config, + "fit_mode": "fit_with_cache" if self.use_kv_cache else "fit_preprocessors", + } + # paper_version is OSS TabPFN-only — server's tagged-union forbids + # extras. balance_probabilities lives only on ClassifierTabPFNConfig. + if self._TASK == "classification": + cfg["balance_probabilities"] = self.balance_probabilities + if self.thinking_mode or self.thinking_effort is not None: + cfg["thinking_mode"] = True + if self.thinking_effort is not None: + cfg["thinking_effort"] = self.thinking_effort + if self.thinking_timeout_s is not None: + cfg["thinking_timeout_s"] = self.thinking_timeout_s + if self.thinking_metric is not None: + cfg["thinking_metric"] = self.thinking_metric + if self.model_path not in ("auto", "default"): + cfg["model_path"] = self.model_path + return cfg + + def _runtime_client(self): + _require_boto3() + if self.boto_session is not None: + return self.boto_session.client("sagemaker-runtime") + if self.region_name is not None: + return boto3.client("sagemaker-runtime", region_name=self.region_name) + return boto3.client("sagemaker-runtime") + + def fit(self, X: Any, y: Any) -> "_SagemakerBase": + X_arr = X if isinstance(X, (pd.DataFrame, pd.Series)) else np.asarray(X) + y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y) + if X_arr.shape[0] != y_arr.shape[0]: + raise ValueError( + f"X and y must have the same number of samples; " + f"got X={X_arr.shape}, y={y_arr.shape}" + ) + self.X_train_ = X_arr + self.y_train_ = y_arr + self._cached_model_id: Optional[str] = None + return self + + def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]: + check_is_fitted(self, ["X_train_", "y_train_"]) + body: Dict[str, Any] = { + "task_config": { + "task": self._TASK, + "tabpfn_config": self._build_tabpfn_config(), + "predict_params": {"output_type": output_type}, + }, + "X_test": _to_jsonable(X_test), + } + if self.use_kv_cache and self._cached_model_id is not None: + body["context"] = {"model_id": self._cached_model_id} + else: + body["X_train"] = _to_jsonable(self.X_train_) + # `y_train` on the wire is 2D (n_samples, 1) per PredictRequest. + y = self.y_train_ + if isinstance(y, pd.Series): + body["y_train"] = y.to_frame().values.tolist() + else: + y_arr = np.asarray(y) + if y_arr.ndim == 1: + y_arr = y_arr.reshape(-1, 1) + body["y_train"] = y_arr.tolist() + resp = self._runtime_client().invoke_endpoint( + EndpointName=self.endpoint_name, + ContentType="application/json", + Accept="application/json", + Body=json.dumps(body).encode("utf-8"), + ) + payload = json.loads(resp["Body"].read()) + if self.use_kv_cache: + self._cached_model_id = payload.get("model_id") or self._cached_model_id + return payload + + +class TabPFNClassifier(_SagemakerBase, ClassifierMixin): + """TabPFN classifier backed by a SageMaker real-time endpoint. + + Example: + from tabpfn_client.sagemaker import TabPFNClassifier + clf = TabPFNClassifier( + endpoint_name="tabpfn-sm-alpha-v3-thinking-001", + region_name="us-east-1", + ) + clf.fit(X_train, y_train) + clf.predict(X_test) + clf.predict_proba(X_test) + """ + + _TASK = "classification" + + def predict(self, X: Any) -> np.ndarray: + result = self._invoke(X, output_type="preds") + return np.asarray(result["prediction"]) + + def predict_proba(self, X: Any) -> np.ndarray: + result = self._invoke(X, output_type="probas") + return np.asarray(result["prediction"]) + + +class TabPFNRegressor(_SagemakerBase, RegressorMixin): + """TabPFN regressor backed by a SageMaker real-time endpoint. + + `output_type` defaults to "mean"; pass "median", "mode", or "full" for + the alternative distributional outputs the server exposes. + """ + + _TASK = "regression" + + def predict(self, X: Any, output_type: str = "mean") -> np.ndarray: + result = self._invoke(X, output_type=output_type) + return np.asarray(result["prediction"]) From d9e30a920bb34ac67cdad3d9a94a6430a4c1a54e Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Fri, 29 May 2026 11:17:28 +0200 Subject: [PATCH 2/4] Cache boto3 client, expose classes_, tighten X type check --- src/tabpfn_client/sagemaker/estimator.py | 28 +++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/tabpfn_client/sagemaker/estimator.py b/src/tabpfn_client/sagemaker/estimator.py index 84a1d69..ba913c5 100644 --- a/src/tabpfn_client/sagemaker/estimator.py +++ b/src/tabpfn_client/sagemaker/estimator.py @@ -137,15 +137,31 @@ def _build_tabpfn_config(self) -> Dict[str, Any]: return cfg def _runtime_client(self): + # Cached on the instance: boto3 service-model load + credential + # resolution is non-trivial and we don't want it on every predict. + client = getattr(self, "_cached_client", None) + if client is not None: + return client _require_boto3() if self.boto_session is not None: - return self.boto_session.client("sagemaker-runtime") - if self.region_name is not None: - return boto3.client("sagemaker-runtime", region_name=self.region_name) - return boto3.client("sagemaker-runtime") + client = self.boto_session.client("sagemaker-runtime") + elif self.region_name is not None: + client = boto3.client("sagemaker-runtime", region_name=self.region_name) + else: + client = boto3.client("sagemaker-runtime") + self._cached_client = client + return client + + def __getstate__(self) -> Dict[str, Any]: + # boto3 clients aren't pickleable. Strip the cache so the estimator + # stays compatible with sklearn's pickle-based parallel/grid utilities. + state = self.__dict__.copy() + state.pop("_cached_client", None) + return state def fit(self, X: Any, y: Any) -> "_SagemakerBase": - X_arr = X if isinstance(X, (pd.DataFrame, pd.Series)) else np.asarray(X) + # X must be 2D; only DataFrame/array. y can be 1D (Series/array) or 2D. + X_arr = X if isinstance(X, pd.DataFrame) else np.asarray(X) y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y) if X_arr.shape[0] != y_arr.shape[0]: raise ValueError( @@ -155,6 +171,8 @@ def fit(self, X: Any, y: Any) -> "_SagemakerBase": self.X_train_ = X_arr self.y_train_ = y_arr self._cached_model_id: Optional[str] = None + if self._TASK == "classification": + self.classes_ = np.unique(y_arr) return self def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]: From c619b22ea020c6db77e355dd8f35d877a3b1c474 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Fri, 29 May 2026 11:25:42 +0200 Subject: [PATCH 3/4] Address gemini review: quantiles support, y_train serialization, json streaming --- src/tabpfn_client/sagemaker/estimator.py | 48 ++++++++++++++++-------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/tabpfn_client/sagemaker/estimator.py b/src/tabpfn_client/sagemaker/estimator.py index ba913c5..dccd71f 100644 --- a/src/tabpfn_client/sagemaker/estimator.py +++ b/src/tabpfn_client/sagemaker/estimator.py @@ -175,13 +175,21 @@ def fit(self, X: Any, y: Any) -> "_SagemakerBase": self.classes_ = np.unique(y_arr) return self - def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]: + def _invoke( + self, + X_test: Any, + output_type: str, + predict_params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: check_is_fitted(self, ["X_train_", "y_train_"]) + params: Dict[str, Any] = {"output_type": output_type} + if predict_params: + params.update(predict_params) body: Dict[str, Any] = { "task_config": { "task": self._TASK, "tabpfn_config": self._build_tabpfn_config(), - "predict_params": {"output_type": output_type}, + "predict_params": params, }, "X_test": _to_jsonable(X_test), } @@ -190,21 +198,21 @@ def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]: else: body["X_train"] = _to_jsonable(self.X_train_) # `y_train` on the wire is 2D (n_samples, 1) per PredictRequest. - y = self.y_train_ - if isinstance(y, pd.Series): - body["y_train"] = y.to_frame().values.tolist() - else: - y_arr = np.asarray(y) - if y_arr.ndim == 1: - y_arr = y_arr.reshape(-1, 1) - body["y_train"] = y_arr.tolist() + # np.asarray handles pd.Series too, so the single-path form covers + # ndarray / list / DataFrame / Series uniformly. + y_arr = np.asarray(self.y_train_) + if y_arr.ndim == 1: + y_arr = y_arr.reshape(-1, 1) + body["y_train"] = y_arr.tolist() resp = self._runtime_client().invoke_endpoint( EndpointName=self.endpoint_name, ContentType="application/json", Accept="application/json", Body=json.dumps(body).encode("utf-8"), ) - payload = json.loads(resp["Body"].read()) + # StreamingBody is file-like; json.load avoids buffering the full + # response in memory, which matters for output_type="full". + payload = json.load(resp["Body"]) if self.use_kv_cache: self._cached_model_id = payload.get("model_id") or self._cached_model_id return payload @@ -238,12 +246,22 @@ def predict_proba(self, X: Any) -> np.ndarray: class TabPFNRegressor(_SagemakerBase, RegressorMixin): """TabPFN regressor backed by a SageMaker real-time endpoint. - `output_type` defaults to "mean"; pass "median", "mode", or "full" for - the alternative distributional outputs the server exposes. + `output_type` defaults to "mean"; pass "median", "mode", "full", or + "quantiles" for the alternative distributional outputs the server + exposes. When `output_type="quantiles"`, `quantiles` selects the cut + points (each in [0, 1]). """ _TASK = "regression" - def predict(self, X: Any, output_type: str = "mean") -> np.ndarray: - result = self._invoke(X, output_type=output_type) + def predict( + self, + X: Any, + output_type: str = "mean", + quantiles: Optional[list] = None, + ) -> np.ndarray: + predict_params: Dict[str, Any] = {} + if quantiles is not None: + predict_params["quantiles"] = quantiles + result = self._invoke(X, output_type=output_type, predict_params=predict_params) return np.asarray(result["prediction"]) From 955a357e6d533dc4aea0a1709bb99eeffcfa1e46 Mon Sep 17 00:00:00 2001 From: Georg Grab Date: Fri, 29 May 2026 14:42:05 +0200 Subject: [PATCH 4/4] Wire thinking_* at body top level; auto-enable kv cache when thinking --- README.md | 13 +++++++ src/tabpfn_client/sagemaker/estimator.py | 43 ++++++++++++++++++------ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 65139fe..4cc1f24 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,19 @@ Notes: - Set `use_kv_cache=True` to opt into the v3 KV-cache path on the server: the first round-trip uploads training data and captures a `model_id`; subsequent `predict*` calls reference that id and skip the training upload. This trades a stateful endpoint for lower per-call latency and payload size. - Constructor kwargs mirror the public `tabpfn_client.TabPFNClassifier` / `TabPFNRegressor` so the same code is portable between the managed API and a SageMaker endpoint, modulo `endpoint_name` / `region_name`. +Thinking mode is supported on SageMaker by passing the same `thinking_mode` / `thinking_effort` / `thinking_timeout_s` / `thinking_metric` kwargs: + +```python +clf = TabPFNClassifier( + endpoint_name="your-sagemaker-endpoint-name", + region_name="us-east-1", + thinking_mode=True, + thinking_effort="medium", +) +``` + +The first `predict*` call after `fit()` runs the autogluon-wrapped fit on the endpoint and can take from tens of seconds up to several minutes depending on `thinking_effort` and data size; the fitted model is cached on the endpoint and subsequent calls are fast. Caching is **required** when thinking is enabled (the client sets `use_kv_cache=True` automatically) — without it every prediction would redo the fit, which would exceed SageMaker's synchronous invoke window. Only `thinking_effort="medium"` is reliable within the real-time endpoint's ~60 s sync window for the *first* call; `"high"` may exceed it and is currently best-effort. + ## Join Our Community We're building the future of tabular machine learning and would love your involvement! Here's how you can participate and get help: diff --git a/src/tabpfn_client/sagemaker/estimator.py b/src/tabpfn_client/sagemaker/estimator.py index dccd71f..5431abd 100644 --- a/src/tabpfn_client/sagemaker/estimator.py +++ b/src/tabpfn_client/sagemaker/estimator.py @@ -109,6 +109,17 @@ def __init__( self.thinking_metric = thinking_metric self.use_kv_cache = use_kv_cache + @property + def _thinking_active(self) -> bool: + return self.thinking_mode or self.thinking_effort is not None + + @property + def _effective_use_kv_cache(self) -> bool: + # Thinking implies caching: without it every predict redoes the + # autogluon HPO sweep. The server enforces this too; we mirror it + # client-side so the wire body always agrees. + return self.use_kv_cache or self._thinking_active + def _build_tabpfn_config(self) -> Dict[str, Any]: cfg: Dict[str, Any] = { "n_estimators": self.n_estimators, @@ -118,24 +129,33 @@ def _build_tabpfn_config(self) -> Dict[str, Any]: "inference_precision": self.inference_precision, "random_state": self.random_state, "inference_config": self.inference_config, - "fit_mode": "fit_with_cache" if self.use_kv_cache else "fit_preprocessors", + "fit_mode": "fit_with_cache" if self._effective_use_kv_cache else "fit_preprocessors", } # paper_version is OSS TabPFN-only — server's tagged-union forbids # extras. balance_probabilities lives only on ClassifierTabPFNConfig. + # thinking_* fields live at the top of the body, not under + # tabpfn_config — see `_invoke`. if self._TASK == "classification": cfg["balance_probabilities"] = self.balance_probabilities - if self.thinking_mode or self.thinking_effort is not None: - cfg["thinking_mode"] = True - if self.thinking_effort is not None: - cfg["thinking_effort"] = self.thinking_effort - if self.thinking_timeout_s is not None: - cfg["thinking_timeout_s"] = self.thinking_timeout_s - if self.thinking_metric is not None: - cfg["thinking_metric"] = self.thinking_metric if self.model_path not in ("auto", "default"): cfg["model_path"] = self.model_path return cfg + def _build_thinking_block(self) -> Dict[str, Any]: + """Top-level wire fields for thinking-mode. Sibling of + `task_config`, mirroring gapi's `FitRequest` shape. Empty when + thinking isn't active so callers can splat unconditionally.""" + if not self._thinking_active: + return {} + block: Dict[str, Any] = { + "thinking_effort": self.thinking_effort if self.thinking_effort is not None else "medium", + } + if self.thinking_timeout_s is not None: + block["thinking_timeout_s"] = self.thinking_timeout_s + if self.thinking_metric is not None: + block["thinking_metric"] = self.thinking_metric + return block + def _runtime_client(self): # Cached on the instance: boto3 service-model load + credential # resolution is non-trivial and we don't want it on every predict. @@ -192,8 +212,9 @@ def _invoke( "predict_params": params, }, "X_test": _to_jsonable(X_test), + **self._build_thinking_block(), } - if self.use_kv_cache and self._cached_model_id is not None: + if self._effective_use_kv_cache and self._cached_model_id is not None: body["context"] = {"model_id": self._cached_model_id} else: body["X_train"] = _to_jsonable(self.X_train_) @@ -213,7 +234,7 @@ def _invoke( # StreamingBody is file-like; json.load avoids buffering the full # response in memory, which matters for output_type="full". payload = json.load(resp["Body"]) - if self.use_kv_cache: + if self._effective_use_kv_cache: self._cached_model_id = payload.get("model_id") or self._cached_model_id return payload