Add SageMaker BYOC client (tabpfn_client.sagemaker) #286
Code Review
This pull request adds support for AWS SageMaker (BYOC) endpoints by introducing TabPFNClassifier and TabPFNRegressor estimators under a new sagemaker submodule, along with documentation and dependency updates. The review feedback focuses on improving scikit-learn compatibility (such as setting classes_ and ensuring pickleability), optimizing performance by caching the boto3 client and streaming JSON parsing, and adding support for quantile regression parameters.
    def fit(self, X: Any, y: Any) -> "_SagemakerBase":
        X_arr = X if isinstance(X, (pd.DataFrame, pd.Series)) else np.asarray(X)
        y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples; "
                f"got X={X_arr.shape}, y={y_arr.shape}"
            )
        self.X_train_ = X_arr
        self.y_train_ = y_arr
        self._cached_model_id: Optional[str] = None
        return self
In scikit-learn, classifiers are expected to expose a classes_ attribute after being fitted. Currently, TabPFNClassifier does not set self.classes_ during fit(), which violates the scikit-learn estimator contract and will cause errors in downstream utilities (e.g., classification reports or evaluation metrics). Additionally, X is allowed to be a 1D pd.Series, which is incorrect since the feature matrix X must always be 2D.

This suggestion corrects the type check for X to exclude pd.Series and ensures self.classes_ is populated for classification tasks.
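As a rough illustration of the contract described in this comment, the sketch below uses a hypothetical stand-in class (`ToyClassifier`, which is not part of the PR) to show how `np.unique` produces the sorted label set that scikit-learn utilities look for in `classes_`. The explicit `ndim` check is an assumption about how a 1D `X` might be rejected early; the PR itself may handle this differently.

```python
import numpy as np


class ToyClassifier:
    """Hypothetical stand-in; mirrors only the fit-time bookkeeping discussed above."""

    def fit(self, X, y):
        X_arr = np.asarray(X)
        y_arr = np.asarray(y)
        # Assumption for illustration: reject anything that is not a 2D feature matrix.
        if X_arr.ndim != 2:
            raise ValueError(f"X must be 2D; got shape {X_arr.shape}")
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError("X and y must have the same number of samples")
        # np.unique returns the distinct labels in sorted order.
        self.classes_ = np.unique(y_arr)
        return self


clf = ToyClassifier().fit([[0.1], [0.2], [0.3]], ["dog", "cat", "dog"])
print(clf.classes_)  # ['cat' 'dog']
```

Because the labels are sorted, column i of a probability matrix can be mapped back to `classes_[i]`, which is the ordering scikit-learn utilities generally assume.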
    def fit(self, X: Any, y: Any) -> "_SagemakerBase":
        X_arr = X if isinstance(X, pd.DataFrame) else np.asarray(X)
        y_arr = y if isinstance(y, (pd.DataFrame, pd.Series)) else np.asarray(y)
        if X_arr.shape[0] != y_arr.shape[0]:
            raise ValueError(
                f"X and y must have the same number of samples; "
                f"got X={X_arr.shape}, y={y_arr.shape}"
            )
        self.X_train_ = X_arr
        self.y_train_ = y_arr
        self._cached_model_id = None
        if self._TASK == "classification":
            self.classes_ = np.unique(y_arr)
        return self

    def _runtime_client(self):
        _require_boto3()
        if self.boto_session is not None:
            return self.boto_session.client("sagemaker-runtime")
        if self.region_name is not None:
            return boto3.client("sagemaker-runtime", region_name=self.region_name)
        return boto3.client("sagemaker-runtime")
Creating a new boto3 client on every invocation introduces significant latency overhead due to loading service models and configuring credentials. Caching the client on the estimator instance avoids this overhead.

To ensure the estimator remains pickleable (which is critical for scikit-learn compatibility, e.g., in GridSearchCV or multiprocessing), we also implement __getstate__ to exclude the non-pickleable client object from serialization.
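The caching-plus-pickling pattern described here can be tried without AWS access. The following is a minimal sketch under stated assumptions: `CachedClientEstimator` is a hypothetical class, and a `threading.Lock` stands in for the boto3 client because it, too, cannot be pickled.

```python
import pickle
import threading


class CachedClientEstimator:
    """Hypothetical stand-in for an estimator that lazily caches a client."""

    def __init__(self, region_name=None):
        self.region_name = region_name

    def _runtime_client(self):
        # A threading.Lock plays the role of the client: it is created once,
        # reused afterwards, and cannot be pickled.
        if getattr(self, "_client", None) is None:
            self._client = threading.Lock()
        return self._client

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_client", None)  # drop the cached client before pickling
        return state


est = CachedClientEstimator(region_name="eu-west-1")
first = est._runtime_client()
same = est._runtime_client()  # second call returns the cached object

clone = pickle.loads(pickle.dumps(est))  # succeeds because _client is excluded
print(first is same, clone.region_name, "_client" in clone.__dict__)
```

The unpickled copy has no `_client` attribute, so its first call to `_runtime_client()` builds a fresh one. If the real base class inherits from scikit-learn's `BaseEstimator`, delegating to `super().__getstate__()` instead of copying `__dict__` directly may be worth considering, since that base class defines its own `__getstate__`.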
    def _runtime_client(self):
        _require_boto3()
        if not hasattr(self, "_client") or self._client is None:
            if self.boto_session is not None:
                self._client = self.boto_session.client("sagemaker-runtime")
            elif self.region_name is not None:
                self._client = boto3.client("sagemaker-runtime", region_name=self.region_name)
            else:
                self._client = boto3.client("sagemaker-runtime")
        return self._client

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_client", None)
        return state

    def _invoke(self, X_test: Any, output_type: str) -> Dict[str, Any]:
        check_is_fitted(self, ["X_train_", "y_train_"])
        body: Dict[str, Any] = {
            "task_config": {
                "task": self._TASK,
                "tabpfn_config": self._build_tabpfn_config(),
                "predict_params": {"output_type": output_type},
            },
            "X_test": _to_jsonable(X_test),
        }
The current implementation of _invoke only forwards output_type in predict_params. This breaks support for the "quantiles" output type in TabPFNRegressor, as there is no way to pass the list of desired quantiles to the SageMaker endpoint.

This suggestion updates _invoke to accept an optional predict_params dictionary, allowing subclasses to pass extra parameters like quantiles to the endpoint.
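The merge step this comment proposes can be sketched in isolation. `build_predict_params` is a hypothetical helper, not a function from the PR, and the `"quantiles"` key is taken from the review's suggestion rather than from a confirmed endpoint schema.

```python
import json
from typing import Any, Dict, Optional


def build_predict_params(
    output_type: str, extra: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Hypothetical helper: start from output_type, then layer on extra params."""
    params: Dict[str, Any] = {"output_type": output_type}
    if extra:
        params.update(extra)
    return params


params = build_predict_params("quantiles", {"quantiles": [0.1, 0.5, 0.9]})
print(json.dumps(params))
# {"output_type": "quantiles", "quantiles": [0.1, 0.5, 0.9]}
```

Note that with a plain `update`, a caller-supplied `output_type` key in the extra dictionary would override the positional argument; whether that is desirable is a design choice for the PR author.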
    def _invoke(
        self,
        X_test: Any,
        output_type: str,
        predict_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        check_is_fitted(self, ["X_train_", "y_train_"])
        params = {"output_type": output_type}
        if predict_params is not None:
            params.update(predict_params)
        body: Dict[str, Any] = {
            "task_config": {
                "task": self._TASK,
                "tabpfn_config": self._build_tabpfn_config(),
                "predict_params": params,
            },
            "X_test": _to_jsonable(X_test),
        }

    def predict(self, X: Any, output_type: str = "mean") -> np.ndarray:
        result = self._invoke(X, output_type=output_type)
        return np.asarray(result["prediction"])
Expose the quantiles parameter in TabPFNRegressor.predict to maintain API parity with the managed TabPFNRegressor and support quantile regression on SageMaker.
    def predict(
        self,
        X: Any,
        output_type: str = "mean",
        quantiles: Optional[list[float]] = None,
    ) -> np.ndarray:
        predict_params = {}
        if quantiles is not None:
            predict_params["quantiles"] = quantiles
        result = self._invoke(X, output_type=output_type, predict_params=predict_params)
        return np.asarray(result["prediction"])

        # `y_train` on the wire is 2D (n_samples, 1) per PredictRequest.
        y = self.y_train_
        if isinstance(y, pd.Series):
            body["y_train"] = y.to_frame().values.tolist()
        else:
            y_arr = np.asarray(y)
            if y_arr.ndim == 1:
                y_arr = y_arr.reshape(-1, 1)
            body["y_train"] = y_arr.tolist()
The serialization logic for y_train can be simplified. Using np.asarray and reshaping if 1D is extremely robust and works identically for pd.Series, pd.DataFrame, np.ndarray, and standard lists, eliminating the need for a separate isinstance(y, pd.Series) branch.
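A quick check of the simplified path, using a hypothetical helper `to_wire_y` and NumPy inputs only (the pandas cases mentioned above go through the same `np.asarray` call but are not exercised here):

```python
import numpy as np


def to_wire_y(y):
    """Hypothetical helper: coerce targets to a nested (n_samples, 1) list."""
    y_arr = np.asarray(y)
    if y_arr.ndim == 1:
        y_arr = y_arr.reshape(-1, 1)
    return y_arr.tolist()


print(to_wire_y([1, 0, 1]))              # [[1], [0], [1]]
print(to_wire_y(np.array([[1], [0]])))   # [[1], [0]]
```

A 1D input is reshaped into a single column, while an input that is already 2D passes through unchanged.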
        y_arr = np.asarray(self.y_train_)
        if y_arr.ndim == 1:
            y_arr = y_arr.reshape(-1, 1)
        body["y_train"] = y_arr.tolist()

            Accept="application/json",
            Body=json.dumps(body).encode("utf-8"),
        )
        payload = json.loads(resp["Body"].read())
Adds a new tabpfn_client.sagemaker submodule providing scikit-learn-style TabPFNClassifier and TabPFNRegressor that proxy through boto3.client("sagemaker-runtime").invoke_endpoint to a TabPFN SageMaker BYOC endpoint