Add SageMaker BYOC client (tabpfn_client.sagemaker) #286

Draft
ggprior wants to merge 4 commits into main from georg/sagemaker-v3-thinking

Conversation

ggprior (Contributor) commented May 29, 2026

Adds a new tabpfn_client.sagemaker submodule providing scikit-learn-style TabPFNClassifier and TabPFNRegressor estimators. Both proxy requests through boto3.client("sagemaker-runtime").invoke_endpoint to a TabPFN SageMaker BYOC (bring-your-own-container) endpoint.

@ggprior ggprior requested a review from a team as a code owner May 29, 2026 09:12
@ggprior ggprior requested review from simo-prior and removed request for a team May 29, 2026 09:12
gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request adds support for AWS SageMaker (BYOC) endpoints by introducing TabPFNClassifier and TabPFNRegressor estimators under a new sagemaker submodule, along with documentation and dependency updates. The review feedback focuses on improving scikit-learn compatibility (such as setting classes_ and ensuring pickleability), optimizing performance by caching the boto3 client and streaming JSON parsing, and adding support for quantile regression parameters.

Comment on lines +147 to +158

    def fit(self, X: Any, y: Any) -> "_SagemakerBase":
        X_arr = X if isinstance(X, (pd.DataFrame, pd.Series)) else np.asarray(X)
        y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples; "
                f"got X={X_arr.shape}, y={y_arr.shape}"
            )
        self.X_train_ = X_arr
        self.y_train_ = y_arr
        self._cached_model_id: Optional[str] = None
        return self

high

In scikit-learn, classifiers are expected to expose a classes_ attribute after being fitted. Currently, TabPFNClassifier does not set self.classes_ during fit(), which violates the scikit-learn estimator contract and will cause errors in downstream utilities (e.g., classification reports or evaluation metrics). Additionally, X is allowed to be a 1D pd.Series, which is incorrect since the feature matrix X must always be 2D.

This suggestion corrects the type check for X to exclude pd.Series and ensures self.classes_ is populated for classification tasks.

    def fit(self, X: Any, y: Any) -> "_SagemakerBase":
        X_arr = X if isinstance(X, pd.DataFrame) else np.asarray(X)
        y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples; "
                f"got X={X_arr.shape}, y={y_arr.shape}"
            )
        self.X_train_ = X_arr
        self.y_train_ = y_arr
        self._cached_model_id = None
        if self._TASK == "classification":
            self.classes_ = np.unique(y_arr)
        return self
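To illustrate why classes_ matters, here is a minimal sketch using a toy class (TinyClassifier and predict_from_proba are hypothetical names, not part of this PR). It assumes only NumPy and shows how the sorted labels from np.unique let probability columns be mapped back to the original labels, plus an explicit 2D check on X.

```python
import numpy as np


class TinyClassifier:
    """Toy stand-in for a fitted classifier; not the PR's TabPFNClassifier."""

    def fit(self, X, y):
        X_arr = np.asarray(X)
        y_arr = np.asarray(y)
        if X_arr.ndim != 2:
            raise ValueError(f"X must be 2D; got shape {X_arr.shape}")
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError("X and y must have the same number of samples")
        # np.unique returns the sorted distinct labels; scikit-learn uses
        # this order for the columns of predict_proba.
        self.classes_ = np.unique(y_arr)
        return self

    def predict_from_proba(self, proba):
        # Map each row's most probable column index back to a label.
        return self.classes_[np.argmax(np.asarray(proba), axis=1)]


clf = TinyClassifier().fit([[0.1], [0.4], [0.9]], ["spam", "ham", "spam"])
labels = clf.predict_from_proba([[0.8, 0.2], [0.3, 0.7]])
```

Without classes_, a caller holding only the probability matrix has no reliable way to tell which column belongs to which label.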

Comment on lines +139 to +145

    def _runtime_client(self):
        _require_boto3()
        if self.boto_session is not None:
            return self.boto_session.client("sagemaker-runtime")
        if self.region_name is not None:
            return boto3.client("sagemaker-runtime", region_name=self.region_name)
        return boto3.client("sagemaker-runtime")

high

Creating a new boto3 client on every invocation introduces significant latency overhead due to loading service models and configuring credentials. Caching the client on the estimator instance avoids this overhead.

To ensure the estimator remains pickleable (which is critical for scikit-learn compatibility, e.g., in GridSearchCV or multiprocessing), we also implement __getstate__ to exclude the non-pickleable client object from serialization.

    def _runtime_client(self):
        _require_boto3()
        if not hasattr(self, "_client") or self._client is None:
            if self.boto_session is not None:
                self._client = self.boto_session.client("sagemaker-runtime")
            elif self.region_name is not None:
                self._client = boto3.client("sagemaker-runtime", region_name=self.region_name)
            else:
                self._client = boto3.client("sagemaker-runtime")
        return self._client

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_client", None)
        return state
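The cache-plus-__getstate__ pattern can be checked in isolation without AWS credentials. The sketch below is an assumption-laden stand-in: FakeRuntimeClient and CachedClientEstimator are hypothetical names, and a threading.Lock is used only to make the fake client unpicklable in the same way a real client object would be.

```python
import pickle
import threading


class FakeRuntimeClient:
    """Hypothetical stand-in for a boto3 client; the lock makes it unpicklable."""

    def __init__(self):
        self._lock = threading.Lock()


class CachedClientEstimator:
    def __init__(self, endpoint_name):
        self.endpoint_name = endpoint_name
        self.clients_created = 0

    def _runtime_client(self):
        # Build the client once and reuse it on later calls.
        if getattr(self, "_client", None) is None:
            self._client = FakeRuntimeClient()
            self.clients_created += 1
        return self._client

    def __getstate__(self):
        # Drop the cached client so pickle never sees it.
        state = self.__dict__.copy()
        state.pop("_client", None)
        return state


est = CachedClientEstimator("my-endpoint")
first = est._runtime_client()
second = est._runtime_client()
restored = pickle.loads(pickle.dumps(est))
```

The restored copy has no _client attribute, so it lazily builds a fresh client on its next call. If the base class inherits from scikit-learn's BaseEstimator, which defines its own __getstate__, it may be preferable to start from super().__getstate__() rather than self.__dict__.copy().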

Comment on lines +160 to +169

    def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]:
        check_is_fitted(self, ["X_train_", "y_train_"])
        body: Dict[str, Any] = {
            "task_config": {
                "task": self._TASK,
                "tabpfn_config": self._build_tabpfn_config(),
                "predict_params": {"output_type": output_type},
            },
            "X_test": _to_jsonable(X_test),
        }

high

The current implementation of _invoke only forwards output_type in predict_params. This breaks support for the "quantiles" output type in TabPFNRegressor, as there is no way to pass the list of desired quantiles to the SageMaker endpoint.

This suggestion updates _invoke to accept an optional predict_params dictionary, allowing subclasses to pass extra parameters like quantiles to the endpoint.

    def _invoke(
        self,
        X_test: Any,
        output_type: str,
        predict_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        check_is_fitted(self, ["X_train_", "y_train_"])
        params = {"output_type": output_type}
        if predict_params is not None:
            params.update(predict_params)
        body: Dict[str, Any] = {
            "task_config": {
                "task": self._TASK,
                "tabpfn_config": self._build_tabpfn_config(),
                "predict_params": params,
            },
            "X_test": _to_jsonable(X_test),
        }
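As a rough illustration of the resulting request body, the standalone sketch below mirrors the merge logic in a free function. build_body is a hypothetical helper, the field names are taken from the snippet above, and the exact schema the endpoint accepts is an assumption.

```python
import json
from typing import Any, Dict, Optional


def build_body(
    task: str,
    X_test: Any,
    output_type: str,
    predict_params: Optional[Dict[str, Any]] = None,
    tabpfn_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    # Start from output_type, then layer any extra parameters on top.
    params: Dict[str, Any] = {"output_type": output_type}
    if predict_params is not None:
        params.update(predict_params)
    return {
        "task_config": {
            "task": task,
            "tabpfn_config": tabpfn_config or {},
            "predict_params": params,
        },
        "X_test": X_test,
    }


body = build_body(
    "regression",
    [[1.0, 2.0]],
    "quantiles",
    predict_params={"quantiles": [0.1, 0.5, 0.9]},
)
encoded = json.dumps(body).encode("utf-8")
```

One design point to keep in mind: because update() runs after output_type is set, a predict_params dict that contains its own "output_type" key would silently override the explicit argument.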

Comment on lines +229 to +231

    def predict(self, X: Any, output_type: str = "mean") -> np.ndarray:
        result = self._invoke(X, output_type=output_type)
        return np.asarray(result["prediction"])

high

Expose the quantiles parameter in TabPFNRegressor.predict to maintain API parity with the managed TabPFNRegressor and support quantile regression on SageMaker.

    def predict(
        self,
        X: Any,
        output_type: str = "mean",
        quantiles: Optional[list[float]] = None,
    ) -> np.ndarray:
        predict_params = {}
        if quantiles is not None:
            predict_params["quantiles"] = quantiles
        result = self._invoke(X, output_type=output_type, predict_params=predict_params)
        return np.asarray(result["prediction"])

Comment on lines +174 to +182

            # `y_train` on the wire is 2D (n_samples, 1) per PredictRequest.
            y = self.y_train_
            if isinstance(y, pd.Series):
                body["y_train"] = y.to_frame().values.tolist()
            else:
                y_arr = np.asarray(y)
                if y_arr.ndim == 1:
                    y_arr = y_arr.reshape(-1, 1)
                body["y_train"] = y_arr.tolist()

medium

The serialization logic for y_train can be simplified. Using np.asarray and reshaping if 1D is extremely robust and works identically for pd.Series, pd.DataFrame, np.ndarray, and standard lists, eliminating the need for a separate isinstance(y, pd.Series) branch.

            y_arr = np.asarray(self.y_train_)
            if y_arr.ndim == 1:
                y_arr = y_arr.reshape(-1, 1)
            body["y_train"] = y_arr.tolist()
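A quick way to sanity-check the simplified branch is to run it on a few input shapes. The sketch below wraps the same three lines in a hypothetical helper, y_train_to_wire, and exercises it with a plain list and an already-2D array (pandas inputs are not covered here).

```python
import numpy as np


def y_train_to_wire(y):
    # Coerce to an array, then force a single-column 2D layout.
    y_arr = np.asarray(y)
    if y_arr.ndim == 1:
        y_arr = y_arr.reshape(-1, 1)
    return y_arr.tolist()


from_list = y_train_to_wire([0, 1, 1])
from_column = y_train_to_wire(np.array([[0], [1], [1]]))
```

Both calls yield the same nested list of shape (n_samples, 1), which is the layout the original comment describes for y_train.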

            Accept="application/json",
            Body=json.dumps(body).encode("utf-8"),
        )
        payload = json.loads(resp["Body"].read())

medium

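Instead of calling .read() on the response body and then json.loads(), pass the StreamingBody straight to json.load(). This is more concise and idiomatic. Note that json.load() itself calls .read() on the object it receives, so the full body is still buffered in memory; the gain is readability rather than memory use.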

Suggested change

    - payload = json.loads(resp["Body"].read())
    + payload = json.load(resp["Body"])
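The two forms can be compared on any file-like object. In the sketch below, io.BytesIO stands in for the StreamingBody (an assumption made for illustration; no AWS call is involved). One caveat worth checking: in CPython, json.load() reads the whole stream via .read() before parsing, so both forms hold the full body in memory.

```python
import io
import json

# Bytes as they might arrive from an endpoint; the payload is made up.
raw = json.dumps({"prediction": [0.25, 0.75]}).encode("utf-8")

# Explicit read, then parse.
via_loads = json.loads(io.BytesIO(raw).read())

# Hand the file-like object to json.load, which reads it internally.
via_load = json.load(io.BytesIO(raw))
```

Both variables end up holding the same parsed dictionary, so the choice between them is mainly stylistic.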

@ggprior ggprior removed the request for review from simo-prior May 29, 2026 09:17
@ggprior ggprior marked this pull request as draft May 29, 2026 09:17