From c328f7262bb92fe31e6a6a02f75569569c920e02 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 09:37:26 +0000 Subject: [PATCH 1/7] Add support for cache_model_artifact --- .../cloud/backend/sagemaker_backend.py | 31 +-- src/autogluon/cloud/model/foundation_model.py | 187 +++++++++++++++--- src/autogluon/cloud/model/registry.py | 122 ++++++------ .../sagemaker_scripts/timeseries_fm_serve.py | 14 +- .../general/test_foundation_model.py | 148 ++++++++++++++ 5 files changed, 386 insertions(+), 116 deletions(-) create mode 100644 tests/unittests/general/test_foundation_model.py diff --git a/src/autogluon/cloud/backend/sagemaker_backend.py b/src/autogluon/cloud/backend/sagemaker_backend.py index b3c8cb1..c473fe1 100644 --- a/src/autogluon/cloud/backend/sagemaker_backend.py +++ b/src/autogluon/cloud/backend/sagemaker_backend.py @@ -356,7 +356,8 @@ def deploy( wait: bool = True, model_kwargs: Optional[Dict] = None, deploy_kwargs: Optional[Dict] = None, - serve_config: Optional[Dict[str, Any]] = None, + fm_serve_config: Optional[Dict[str, Any]] = None, + repack: bool = True, ) -> None: """ Deploy a predictor as a SageMaker endpoint, which can be used to do real-time inference later. @@ -397,8 +398,8 @@ def deploy( deploy_kwargs: Any extra arguments needed to pass to deploy. Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy for all options - serve_config: Optional[Dict[str, Any]], default = None - Configuration dict passed to the serve script via the AG_SERVE_CONFIG env var. + fm_serve_config: Optional[Dict[str, Any]], default = None + Configuration dict passed to the FM serve script via the AG_FM_SERVE_CONFIG env var. """ assert self.endpoint is None, ( "There is an endpoint already attached. Either detach it with `detach` or clean it up with `cleanup_deployment`" @@ -444,19 +445,21 @@ def deploy( ) entry_point = self._serve_script_path - # Pick model class: - # - No artifact → create minimal tarball with serve script (FM deploy) - # - Artifact from different source or custom entry point → Repack (inject script into tarball) - # - Artifact from fit job, default entry point → NonRepack (script already in tarball) + # Pick model class. The question is whether the tarball already contains + # the entry_point script — if yes, NonRepack uses it as-is; if no, Repack + # injects the script at deploy time. if predictor_path is None: predictor_path = self._create_serve_script_tarball(entry_point, endpoint_name) model_cls = AutoGluonNonRepackInferenceModel + elif not repack: + model_cls = AutoGluonNonRepackInferenceModel else: - fit_output = self._fit_job.get_output_path() if self._fit_job is not None else None - if predictor_path != fit_output or user_entry_point is not None: - model_cls = AutoGluonRepackInferenceModel - else: - model_cls = AutoGluonNonRepackInferenceModel + is_default_fit_output = ( + self._fit_job is not None + and predictor_path == self._fit_job.get_output_path() + and user_entry_point is None + ) + model_cls = AutoGluonNonRepackInferenceModel if is_default_fit_output else AutoGluonRepackInferenceModel # Assemble env vars and deploy predictor_cls = self._realtime_predictor_cls @@ -483,8 +486,8 @@ def deploy( else: model_kwargs_env = {SAGEMAKER_MODEL_SERVER_WORKERS: "1"} - if serve_config is not None: - model_kwargs_env["AG_SERVE_CONFIG"] = json.dumps(serve_config) + if fm_serve_config is not None: + model_kwargs_env["AG_FM_SERVE_CONFIG"] = json.dumps(fm_serve_config) model = model_cls( model_data=predictor_path, diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 2f384ff..8d7a5f0 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -2,6 +2,9 @@ from __future__ import annotations +import json +import logging +import tarfile import tempfile from abc import abstractmethod from pathlib import Path @@ -9,13 +12,35 @@ import pandas as pd +from autogluon.common.utils.s3_utils import s3_path_to_bucket_prefix + from ..backend.backend_factory import BackendFactory from ..backend.constant import SAGEMAKER, TABULAR_SAGEMAKER, TIMESERIES_SAGEMAKER from ..endpoint.timeseries_endpoint import TimeSeriesEndpoint from ..scripts.script_manager import ScriptManager from ..utils.aws_utils import resolve_cloud_output_path +from ..version import __version__ from .registry import get_model_config +logger = logging.getLogger(__name__) + +# SageMaker extracts model.tar.gz to /opt/ml/model in the container. +_CONTAINER_WEIGHTS_DIR = "/opt/ml/model/weights" + +_AG_CLOUD_VERSION_METADATA_KEY = "autogluon-cloud-version" + + +def _s3_head_or_none(s3_client: Any, bucket: str, key: str) -> Optional[Dict[str, Any]]: + """Return ``head_object`` response if the key exists, ``None`` for 404. Other errors propagate.""" + from botocore.exceptions import ClientError + + try: + return s3_client.head_object(Bucket=bucket, Key=key) + except ClientError as e: + if e.response.get("Error", {}).get("Code") in ("404", "NoSuchKey", "NotFound"): + return None + raise + class FoundationModel: """ @@ -37,7 +62,7 @@ def __new__(cls, model_id: str, **kwargs) -> "FoundationModel": if cls is not FoundationModel: return super().__new__(cls) config = get_model_config(model_id) - task = config["task"] + task = config.task if task == "forecasting": return super().__new__(TimeSeriesFoundationModel) elif task in ("classification", "regression"): @@ -47,9 +72,11 @@ def __new__(cls, model_id: str, **kwargs) -> "FoundationModel": def __init__( self, model_id: str, + *, + hyperparameters: Optional[Dict[str, Any]] = None, + model_artifact_uri: Optional[str] = None, backend: Literal["sagemaker"] = "sagemaker", cloud_output_path: Optional[str] = None, - hyperparameters: Optional[Dict[str, Any]] = None, role: Optional[str] = None, ): """ @@ -57,6 +84,12 @@ def __init__( ---------- model_id ID of the foundation model from the model registry. + hyperparameters + Default hyperparameters applied to inference and (when supported) training. + model_artifact_uri + S3 URI of a pre-bundled ``model.tar.gz`` produced by + :meth:`cache_model_artifact`. When set, deploys skip the runtime + HuggingFace download and load weights from the bundled artifact. backend Cloud backend to use. cloud_output_path @@ -68,14 +101,13 @@ def __init__( * ``None`` (default) — use the bucket saved in ``~/.autogluon/cloud.yaml`` (set by :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`) and append a timestamped subfolder. Raises if no bucket is configured. - hyperparameters - Default hyperparameters applied to inference and (when supported) training. role ARN of the SageMaker execution role used to run training and inference jobs. If ``None``, falls back to ``role_arn`` in ``~/.autogluon/cloud.yaml`` (set by :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`), and finally to ``sagemaker.get_execution_role()``. """ self.model_id = model_id + self.model_artifact_uri = model_artifact_uri self.cloud_output_path = resolve_cloud_output_path(cloud_output_path, backend_name=backend) self._config = get_model_config(model_id) self._hyperparameter_overrides = hyperparameters or {} @@ -98,9 +130,14 @@ def __init__( def _get_hyperparameters( self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Merge registry defaults → constructor overrides → call-site overrides.""" - config_key = "inference_hyperparameters" if context == "inference" else "training_hyperparameters" - return self._config.get(config_key, {}) | self._hyperparameter_overrides | (overrides or {}) + """Merge registry defaults → constructor overrides → call-site overrides, + defaulting ``model_path`` to ``model_source_uri`` if not set.""" + registry_defaults = ( + self._config.inference_hyperparameters if context == "inference" else self._config.training_hyperparameters + ) + merged = registry_defaults | self._hyperparameter_overrides | (overrides or {}) + merged.setdefault("model_path", self._config.model_source_uri) + return merged @abstractmethod def _build_predictor_init_args(self, **user_kwargs) -> Dict[str, Any]: @@ -148,25 +185,29 @@ def _deploy_backend( ) -> None: """Shared deploy logic. Subclasses call this then wrap the endpoint.""" if instance_type is None: - instance_type = self._config["deploy_instance_type"] + instance_type = self._config.deploy_instance_type - serve_config = { - "model_name": self._config["model_name"], - "hyperparameters": self._get_hyperparameters("inference", hyperparameters), + merged_hp = self._get_hyperparameters("inference", hyperparameters) + if self.model_artifact_uri is not None: + merged_hp["model_path"] = _CONTAINER_WEIGHTS_DIR + fm_serve_config = { + "ag_model_key": self._config.ag_model_key, + "hyperparameters": merged_hp, } model_kwargs = backend_kwargs.pop("model_kwargs", {}) model_kwargs["entry_point"] = self._serve_script_path self._backend.deploy( - predictor_path=None, + predictor_path=self.model_artifact_uri, endpoint_name=endpoint_name, framework_version=framework_version, instance_type=instance_type, custom_image_uri=custom_image_uri, wait=wait, model_kwargs=model_kwargs, - serve_config=serve_config, + fm_serve_config=fm_serve_config, + repack=self.model_artifact_uri is None, **backend_kwargs, ) assert self._backend.endpoint is not None @@ -206,30 +247,121 @@ def fit( :meta private: """ - if not self._config.get("fine_tunable", False): + if not self._config.fine_tunable: raise ValueError(f"Model '{self.model_id}' does not support fine-tuning.") raise NotImplementedError - def cache_model_artifact(self, s3_path: str) -> str: + def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> "FoundationModel": """ - Pre-cache model weights to S3 (for VPC-deployed endpoints). + Download model weights from HuggingFace, bundle them with the FM serve script + into a SageMaker-compatible ``model.tar.gz``, and upload to S3. - Launches a small job that downloads weights from HuggingFace - and uploads them to S3. + Lets :meth:`deploy` skip the runtime HuggingFace download — required for + network-isolated endpoints (e.g. SageMaker Serverless Inference). Returns a + new :class:`FoundationModel` with ``model_artifact_uri`` set to the uploaded + tarball. + + Destination key: ``{cache_path}/{model_id}/model.tar.gz``. If it already + exists, upload is skipped unless ``overwrite=True``; a warning is logged + when the cached artifact's autogluon-cloud version differs from the + current version. Parameters ---------- - s3_path - S3 path where the model weights should be cached. + cache_path + S3 prefix under which the artifact will be uploaded. Multiple foundation + models can share one prefix. + overwrite + If True, re-upload even when the destination key exists. Returns ------- - str - S3 path to the cached artifact. - - :meta private: + FoundationModel + A new instance with ``model_artifact_uri`` populated. The original is unchanged. """ - raise NotImplementedError + try: + from huggingface_hub import snapshot_download + except ImportError as e: + raise ImportError( + "cache_model_artifact requires `huggingface_hub`. Install with: pip install huggingface_hub" + ) from e + + if not cache_path.startswith("s3://"): + raise ValueError(f"cache_path must be an s3:// URI, got: {cache_path!r}") + + source_uri = self._config.model_source_uri + cache_key = f"{cache_path.rstrip('/')}/{self.model_id}/model.tar.gz" + bucket, key = s3_path_to_bucket_prefix(cache_key) + s3 = self._backend.sagemaker_session.boto_session.client("s3") + + head = None if overwrite else _s3_head_or_none(s3, bucket, key) + if head is not None: + cached_version = head["Metadata"].get(_AG_CLOUD_VERSION_METADATA_KEY) + if cached_version != __version__: + logger.warning( + f"Cached artifact at {cache_key} was bundled with autogluon-cloud " + f"{cached_version!r}, current is {__version__!r}. " + f"Pass overwrite=True to refresh." + ) + else: + logger.info(f"Cached artifact already exists at {cache_key}; skipping upload") + else: + with tempfile.TemporaryDirectory(prefix="ag_fm_cache_") as tmp: + tmp_path = Path(tmp) + weights_dir = tmp_path / "weights" + logger.info(f"Downloading {source_uri} from HuggingFace to {weights_dir}") + snapshot_download(repo_id=source_uri, local_dir=str(weights_dir)) + + code_dir = tmp_path / "code" + code_dir.mkdir() + serve_script = Path(self._serve_script_path) + (code_dir / serve_script.name).write_bytes(serve_script.read_bytes()) + + tarball = tmp_path / "model.tar.gz" + logger.info(f"Bundling weights + serve script into {tarball}") + with tarfile.open(tarball, "w:gz") as tar: + tar.add(weights_dir, arcname="weights") + tar.add(code_dir, arcname="code") + logger.info(f"Uploading to {cache_key}") + s3.upload_file( + str(tarball), + bucket, + key, + ExtraArgs={"Metadata": {_AG_CLOUD_VERSION_METADATA_KEY: __version__}}, + ) + + return self.__class__( + model_id=self.model_id, + hyperparameters=self._hyperparameter_overrides or None, + model_artifact_uri=cache_key, + cloud_output_path=self.cloud_output_path, + role=self._backend.role_arn, + ) + + def to_dict(self) -> Dict[str, Any]: + """Serialize the model identity. Runtime context (``role``, ``cloud_output_path``) + is excluded so configs can be shared across users.""" + out: Dict[str, Any] = {"model_id": self.model_id} + if self._hyperparameter_overrides: + out["hyperparameters"] = self._hyperparameter_overrides + if self.model_artifact_uri: + out["model_artifact_uri"] = self.model_artifact_uri + return out + + def to_json(self) -> str: + """Serialize :meth:`to_dict` output as a JSON string.""" + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, config: Dict[str, Any], **runtime_context: Any) -> "FoundationModel": + """Restore from :meth:`to_dict` output. Pass ``role`` / ``cloud_output_path`` + as ``runtime_context``.""" + return cls(**config, **runtime_context) + + @classmethod + def from_json(cls, s: str, **runtime_context: Any) -> "FoundationModel": + """Restore from a :meth:`to_json` string.""" + return cls.from_dict(json.loads(s), **runtime_context) class TimeSeriesFoundationModel(FoundationModel): @@ -291,10 +423,9 @@ def deploy( return TimeSeriesEndpoint(self._backend.endpoint) def _build_predictor_fit_args(self, hyperparameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - model_name = self._config["model_name"] merged_hp = self._get_hyperparameters("inference", hyperparameters) return { - "hyperparameters": {model_name: merged_hp}, + "hyperparameters": {self._config.ag_model_key: merged_hp}, "skip_model_selection": True, } @@ -383,7 +514,7 @@ def predict( Optional[pd.DataFrame] """ if instance_type is None: - instance_type = self._config["predict_instance_type"] + instance_type = self._config.predict_instance_type predictor_init_args = self._build_predictor_init_args( target=target, diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 2703bd3..1e5bb50 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -3,82 +3,70 @@ Maps model_id to AG-compatible configuration for deploy / predict. """ -from typing import Any, Dict, Literal, TypedDict +from dataclasses import dataclass, field +from typing import Any, Dict, Literal -class FoundationModelConfig(TypedDict): +@dataclass(frozen=True) +class FoundationModelConfig: task: Literal["forecasting", "classification", "regression"] - model_name: str # AG model class name (e.g. "Chronos", "Chronos2", "Mitra") - inference_hyperparameters: Dict[str, Any] # defaults for deploy() and predict() - training_hyperparameters: Dict[str, Any] # defaults for fit() + ag_model_key: str # key in the AG hyperparameters dict (e.g. "Chronos", "Chronos2", "Mitra") + model_source_uri: str # where weights are downloaded from (e.g. "amazon/chronos-2") predict_instance_type: str # batch predict deploy_instance_type: str # real-time endpoint fit_instance_type: str # fine-tuning - fine_tunable: bool # whether .fit() is supported + inference_hyperparameters: Dict[str, Any] = field(default_factory=dict) # defaults for deploy() and predict() + training_hyperparameters: Dict[str, Any] = field(default_factory=dict) # defaults for fit() + fine_tunable: bool = False # whether .fit() is supported -FOUNDATION_MODEL_REGISTRY: dict[str, FoundationModelConfig] = { - "chronos-bolt-tiny": { - "task": "forecasting", - "model_name": "Chronos", - "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, - "training_hyperparameters": {"model_path": "amazon/chronos-bolt-tiny"}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": False, - }, - "chronos-bolt-small": { - "task": "forecasting", - "model_name": "Chronos", - "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, - "training_hyperparameters": {"model_path": "amazon/chronos-bolt-small"}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": False, - }, - "chronos-bolt-base": { - "task": "forecasting", - "model_name": "Chronos", - "inference_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, - "training_hyperparameters": {"model_path": "amazon/chronos-bolt-base"}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": False, - }, - "chronos-2": { - "task": "forecasting", - "model_name": "Chronos2", - "inference_hyperparameters": {"model_path": "amazon/chronos-2"}, - "training_hyperparameters": {"model_path": "amazon/chronos-2", "fine_tune": True}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": True, - }, +_DEFAULT_INSTANCE_TYPES = { + "predict_instance_type": "ml.m5.2xlarge", + "deploy_instance_type": "ml.g5.xlarge", + "fit_instance_type": "ml.g5.xlarge", +} + + +FOUNDATION_MODEL_REGISTRY: Dict[str, FoundationModelConfig] = { + "chronos-bolt-tiny": FoundationModelConfig( + task="forecasting", + ag_model_key="Chronos", + model_source_uri="amazon/chronos-bolt-tiny", + **_DEFAULT_INSTANCE_TYPES, + ), + "chronos-bolt-small": FoundationModelConfig( + task="forecasting", + ag_model_key="Chronos", + model_source_uri="amazon/chronos-bolt-small", + **_DEFAULT_INSTANCE_TYPES, + ), + "chronos-bolt-base": FoundationModelConfig( + task="forecasting", + ag_model_key="Chronos", + model_source_uri="amazon/chronos-bolt-base", + **_DEFAULT_INSTANCE_TYPES, + ), + "chronos-2": FoundationModelConfig( + task="forecasting", + ag_model_key="Chronos2", + model_source_uri="amazon/chronos-2", + training_hyperparameters={"fine_tune": True}, + fine_tunable=True, + **_DEFAULT_INSTANCE_TYPES, + ), # TODO: Replace dummy configs with real values - "mitra-classification": { - "task": "classification", - "model_name": "Mitra", - "inference_hyperparameters": {"model_path": "TODO"}, - "training_hyperparameters": {"model_path": "TODO"}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": False, - }, - "mitra-regression": { - "task": "regression", - "model_name": "Mitra", - "inference_hyperparameters": {"model_path": "TODO"}, - "training_hyperparameters": {"model_path": "TODO"}, - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", - "fine_tunable": False, - }, + "mitra-classification": FoundationModelConfig( + task="classification", + ag_model_key="Mitra", + model_source_uri="TODO", + **_DEFAULT_INSTANCE_TYPES, + ), + "mitra-regression": FoundationModelConfig( + task="regression", + ag_model_key="Mitra", + model_source_uri="TODO", + **_DEFAULT_INSTANCE_TYPES, + ), } diff --git a/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py b/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py index 613b14b..0811377 100644 --- a/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py +++ b/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py @@ -1,7 +1,7 @@ """Serve script for time series foundation models (Chronos, etc.) on SageMaker endpoints. -Config comes from the AG_SERVE_CONFIG env var (set by the backend at deploy time): - {"model_name": "Chronos", "hyperparameters": {"model_path": "amazon/chronos-bolt-base", ...}} +Config comes from the AG_FM_FM_SERVE_CONFIG env var (set by the backend at deploy time): + {"ag_model_key": "Chronos", "hyperparameters": {"model_path": "amazon/chronos-bolt-base", ...}} """ import json @@ -14,19 +14,19 @@ from autogluon.timeseries import TimeSeriesDataFrame from autogluon.timeseries.models import ModelRegistry -_SERVE_CONFIG = json.loads(os.environ.get("AG_SERVE_CONFIG", "{}")) +_FM_SERVE_CONFIG = json.loads(os.environ.get("AG_FM_FM_SERVE_CONFIG", "{}")) _SUPPORTED_INPUT_CONTENT_TYPES = {"application/x-autogluon", "application/json"} def model_fn(model_dir): """Instantiate the foundation model and load weights into memory.""" - model_name = _SERVE_CONFIG["model_name"] - hyperparameters = _SERVE_CONFIG.get("hyperparameters", {}) + ag_model_key = _FM_SERVE_CONFIG["ag_model_key"] + hyperparameters = _FM_SERVE_CONFIG.get("hyperparameters", {}) - model_cls = ModelRegistry.get_model_class(model_name) + model_cls = ModelRegistry.get_model_class(ag_model_key) # freq and prediction_length are overridden per-request in transform_fn model = model_cls( - path=model_name, + path=ag_model_key, freq=None, prediction_length=1, hyperparameters=hyperparameters, diff --git a/tests/unittests/general/test_foundation_model.py b/tests/unittests/general/test_foundation_model.py new file mode 100644 index 0000000..c7bf54c --- /dev/null +++ b/tests/unittests/general/test_foundation_model.py @@ -0,0 +1,148 @@ +"""Unit tests for FoundationModel: serialization, hyperparameter resolution, and +deploy-time wiring of model_artifact_uri / model_path.""" + +from unittest import mock + +import pytest + +from autogluon.cloud.model import FoundationModel + + +@pytest.fixture(autouse=True) +def _stub_aws(monkeypatch): + """Avoid touching AWS / config files during construction.""" + monkeypatch.setattr( + "autogluon.cloud.model.foundation_model.resolve_cloud_output_path", + lambda path, backend_name: path or "s3://stub/output", + ) + monkeypatch.setattr( + "autogluon.cloud.backend.backend_factory.BackendFactory.get_backend", + lambda **kwargs: mock.MagicMock(role_arn="arn:aws:iam::0:role/stub"), + ) + + +def test_to_dict_minimal_emits_only_model_id(): + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + assert fm.to_dict() == {"model_id": "chronos-2"} + + +def test_to_dict_includes_overrides_and_artifact_when_set(): + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://b", + hyperparameters={"context_length": 256}, + model_artifact_uri="s3://b/cache/chronos-2/abc/model.tar.gz", + ) + assert fm.to_dict() == { + "model_id": "chronos-2", + "hyperparameters": {"context_length": 256}, + "model_artifact_uri": "s3://b/cache/chronos-2/abc/model.tar.gz", + } + + +def test_to_dict_excludes_runtime_context(): + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://my-bucket/runs/", + role="arn:aws:iam::0:role/runtime", + ) + d = fm.to_dict() + assert "role" not in d + assert "cloud_output_path" not in d + assert "backend" not in d + + +def test_from_dict_round_trip(): + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://b", + hyperparameters={"context_length": 256}, + model_artifact_uri="s3://b/cache/chronos-2/abc/model.tar.gz", + ) + fm2 = FoundationModel.from_dict(fm.to_dict(), cloud_output_path="s3://other") + assert fm2.to_dict() == fm.to_dict() + + +def test_from_json_round_trip(): + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + fm2 = FoundationModel.from_json(fm.to_json(), cloud_output_path="s3://b") + assert fm2.to_dict() == fm.to_dict() + + +def test_inference_hyperparameters_default_model_path_to_source_uri(): + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + hp = fm._get_hyperparameters("inference") + assert hp["model_path"] == "amazon/chronos-2" + + +def test_user_hyperparameter_override_wins_over_default_model_path(): + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://b", + hyperparameters={"model_path": "my-org/my-finetune"}, + ) + hp = fm._get_hyperparameters("inference") + assert hp["model_path"] == "my-org/my-finetune" + + +def test_deploy_passes_artifact_uri_and_overrides_model_path_to_container_dir(): + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://b", + model_artifact_uri="s3://b/cache/chronos-2/model.tar.gz", + ) + fm._backend.endpoint = mock.MagicMock() # _deploy_backend asserts this is set after the call + fm._deploy_backend() + + call = fm._backend.deploy.call_args + assert call.kwargs["predictor_path"] == "s3://b/cache/chronos-2/model.tar.gz" + assert call.kwargs["repack"] is False + serve_cfg = call.kwargs["fm_serve_config"] + assert serve_cfg["hyperparameters"]["model_path"] == "/opt/ml/model/weights" + + +def test_deploy_without_artifact_passes_none_predictor_path_and_source_uri(): + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + fm._backend.endpoint = mock.MagicMock() + fm._deploy_backend() + + call = fm._backend.deploy.call_args + assert call.kwargs["predictor_path"] is None + assert call.kwargs["repack"] is True + serve_cfg = call.kwargs["fm_serve_config"] + assert serve_cfg["hyperparameters"]["model_path"] == "amazon/chronos-2" + + +def test_cache_model_artifact_rejects_non_s3_path(): + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + with pytest.raises(ValueError, match="s3://"): + fm.cache_model_artifact("/local/path") + + +def test_sagemaker_backend_uses_nonrepack_when_repack_is_false(): + """A pre-bundled cached artifact should bypass the SDK's download/repack/re-upload path.""" + from autogluon.cloud.backend.sagemaker_backend import SagemakerBackend + + sb = "autogluon.cloud.backend.sagemaker_backend" + with ( + mock.patch(f"{sb}.setup_sagemaker_session", return_value=mock.MagicMock(boto_region_name="us-east-1")), + mock.patch(f"{sb}.resolve_execution_role", return_value="arn:aws:iam::000000000000:role/t"), + mock.patch(f"{sb}.AutoGluonNonRepackInferenceModel") as nonrepack_cls, + mock.patch(f"{sb}.AutoGluonRepackInferenceModel") as repack_cls, + mock.patch.object(SagemakerBackend, "_upload_predictor", side_effect=lambda p, _: p), + ): + backend = SagemakerBackend( + local_output_path="/tmp/t", + cloud_output_path="s3://bucket/run", + predictor_type="timeseries", + ) + backend._fit_job = None + backend.deploy( + predictor_path="s3://bucket/cache/chronos-2/model.tar.gz", + endpoint_name="ep", + model_kwargs={"entry_point": "stub.py"}, + repack=False, + ) + + nonrepack_cls.assert_called_once() + repack_cls.assert_not_called() From 8115577813332fd8ea071c413ab56bef54d051e7 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 10:24:49 +0000 Subject: [PATCH 2/7] Make hf hub required dep --- setup.py | 1 + src/autogluon/cloud/model/foundation_model.py | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 30b18b1..822bd25 100644 --- a/setup.py +++ b/setup.py @@ -122,6 +122,7 @@ def default_setup_args(*, version): # CLI dependencies (autogluon-cloud command) "click>=8.0,<9", "rich>=13.0,<15", + "huggingface_hub>=0.20,<2", ] extras_require = dict() diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 8d7a5f0..4b86c42 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -279,12 +279,7 @@ def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> " FoundationModel A new instance with ``model_artifact_uri`` populated. The original is unchanged. """ - try: - from huggingface_hub import snapshot_download - except ImportError as e: - raise ImportError( - "cache_model_artifact requires `huggingface_hub`. Install with: pip install huggingface_hub" - ) from e + from huggingface_hub import snapshot_download if not cache_path.startswith("s3://"): raise ValueError(f"cache_path must be an s3:// URI, got: {cache_path!r}") From a1c94e053827352819f581ba237d4974b181c24d Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 11:38:52 +0000 Subject: [PATCH 3/7] Fix env var --- src/autogluon/cloud/model/foundation_model.py | 7 +++---- .../scripts/sagemaker_scripts/timeseries_fm_serve.py | 4 ++-- tests/unittests/general/test_foundation_model.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 4b86c42..81cfab8 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -293,13 +293,12 @@ def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> " if head is not None: cached_version = head["Metadata"].get(_AG_CLOUD_VERSION_METADATA_KEY) if cached_version != __version__: - logger.warning( + raise RuntimeError( f"Cached artifact at {cache_key} was bundled with autogluon-cloud " f"{cached_version!r}, current is {__version__!r}. " - f"Pass overwrite=True to refresh." + f"Pass overwrite=True to re-bundle and re-upload." ) - else: - logger.info(f"Cached artifact already exists at {cache_key}; skipping upload") + logger.info(f"Cached artifact already exists at {cache_key}; skipping upload") else: with tempfile.TemporaryDirectory(prefix="ag_fm_cache_") as tmp: tmp_path = Path(tmp) diff --git a/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py b/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py index 0811377..813045a 100644 --- a/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py +++ b/src/autogluon/cloud/scripts/sagemaker_scripts/timeseries_fm_serve.py @@ -1,6 +1,6 @@ """Serve script for time series foundation models (Chronos, etc.) on SageMaker endpoints. -Config comes from the AG_FM_FM_SERVE_CONFIG env var (set by the backend at deploy time): +Config comes from the AG_FM_SERVE_CONFIG env var (set by the backend at deploy time): {"ag_model_key": "Chronos", "hyperparameters": {"model_path": "amazon/chronos-bolt-base", ...}} """ @@ -14,7 +14,7 @@ from autogluon.timeseries import TimeSeriesDataFrame from autogluon.timeseries.models import ModelRegistry -_FM_SERVE_CONFIG = json.loads(os.environ.get("AG_FM_FM_SERVE_CONFIG", "{}")) +_FM_SERVE_CONFIG = json.loads(os.environ.get("AG_FM_SERVE_CONFIG", "{}")) _SUPPORTED_INPUT_CONTENT_TYPES = {"application/x-autogluon", "application/json"} diff --git a/tests/unittests/general/test_foundation_model.py b/tests/unittests/general/test_foundation_model.py index c7bf54c..082f25f 100644 --- a/tests/unittests/general/test_foundation_model.py +++ b/tests/unittests/general/test_foundation_model.py @@ -119,6 +119,18 @@ def test_cache_model_artifact_rejects_non_s3_path(): fm.cache_model_artifact("/local/path") +def test_cache_model_artifact_raises_on_stale_version_without_overwrite(): + """Returning a model pointing at a tarball bundled by a different autogluon-cloud + version surfaces as a confusing endpoint failure later. Force the user to opt in.""" + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + s3 = mock.MagicMock() + s3.head_object.return_value = {"Metadata": {"autogluon-cloud-version": "0.0.0-stale"}} + fm._backend.sagemaker_session.boto_session.client.return_value = s3 + + with pytest.raises(RuntimeError, match="overwrite=True"): + fm.cache_model_artifact("s3://b/cache") + + def test_sagemaker_backend_uses_nonrepack_when_repack_is_false(): """A pre-bundled cached artifact should bypass the SDK's download/repack/re-upload path.""" from autogluon.cloud.backend.sagemaker_backend import SagemakerBackend From 59cf0c95e763b82d6428945608d3d0f12c0a471f Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 11:53:06 +0000 Subject: [PATCH 4/7] Address review comments --- .../cloud/backend/sagemaker_backend.py | 6 +++ src/autogluon/cloud/model/foundation_model.py | 26 +++++----- src/autogluon/cloud/model/registry.py | 47 +++++-------------- .../general/test_foundation_model.py | 4 +- 4 files changed, 34 insertions(+), 49 deletions(-) diff --git a/src/autogluon/cloud/backend/sagemaker_backend.py b/src/autogluon/cloud/backend/sagemaker_backend.py index c473fe1..fa9162a 100644 --- a/src/autogluon/cloud/backend/sagemaker_backend.py +++ b/src/autogluon/cloud/backend/sagemaker_backend.py @@ -400,6 +400,12 @@ def deploy( Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy for all options fm_serve_config: Optional[Dict[str, Any]], default = None Configuration dict passed to the FM serve script via the AG_FM_SERVE_CONFIG env var. + repack: bool, default = True + Whether the SageMaker SDK should download ``predictor_path``, inject the + entry-point script, and re-upload it. Set to False when ``predictor_path`` + already contains the serve script (e.g. an artifact bundled by + :meth:`FoundationModel.cache_model_artifact`) to skip the round-trip. + Ignored when ``predictor_path`` is None. """ assert self.endpoint is None, ( "There is an endpoint already attached. Either detach it with `detach` or clean it up with `cleanup_deployment`" diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 81cfab8..149fb21 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -73,25 +73,17 @@ def __init__( self, model_id: str, *, + cloud_output_path: Optional[str] = None, + role: Optional[str] = None, hyperparameters: Optional[Dict[str, Any]] = None, model_artifact_uri: Optional[str] = None, backend: Literal["sagemaker"] = "sagemaker", - cloud_output_path: Optional[str] = None, - role: Optional[str] = None, ): """ Parameters ---------- model_id ID of the foundation model from the model registry. - hyperparameters - Default hyperparameters applied to inference and (when supported) training. - model_artifact_uri - S3 URI of a pre-bundled ``model.tar.gz`` produced by - :meth:`cache_model_artifact`. When set, deploys skip the runtime - HuggingFace download and load weights from the bundled artifact. - backend - Cloud backend to use. cloud_output_path S3 location where intermediate artifacts are stored. Accepts: @@ -105,6 +97,14 @@ def __init__( ARN of the SageMaker execution role used to run training and inference jobs. If ``None``, falls back to ``role_arn`` in ``~/.autogluon/cloud.yaml`` (set by :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`), and finally to ``sagemaker.get_execution_role()``. + hyperparameters + Default hyperparameters applied to inference and (when supported) training. + model_artifact_uri + S3 URI of a pre-bundled ``model.tar.gz`` produced by + :meth:`cache_model_artifact`. When set, deploys skip the runtime + HuggingFace download and load weights from the bundled artifact. + backend + Cloud backend to use. """ self.model_id = model_id self.model_artifact_uri = model_artifact_uri @@ -262,9 +262,9 @@ def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> " tarball. Destination key: ``{cache_path}/{model_id}/model.tar.gz``. If it already - exists, upload is skipped unless ``overwrite=True``; a warning is logged - when the cached artifact's autogluon-cloud version differs from the - current version. + exists, upload is skipped unless ``overwrite=True``; a stale-cache mismatch + between the bundled artifact's autogluon-cloud version and the current + version raises ``RuntimeError`` and prompts the caller to re-bundle. Parameters ---------- diff --git a/src/autogluon/cloud/model/registry.py b/src/autogluon/cloud/model/registry.py index 1e5bb50..895794c 100644 --- a/src/autogluon/cloud/model/registry.py +++ b/src/autogluon/cloud/model/registry.py @@ -11,61 +11,40 @@ class FoundationModelConfig: task: Literal["forecasting", "classification", "regression"] ag_model_key: str # key in the AG hyperparameters dict (e.g. "Chronos", "Chronos2", "Mitra") - model_source_uri: str # where weights are downloaded from (e.g. "amazon/chronos-2") - predict_instance_type: str # batch predict - deploy_instance_type: str # real-time endpoint - fit_instance_type: str # fine-tuning + model_source_uri: str # where weights are downloaded from (e.g. "autogluon/chronos-2") + predict_instance_type: str = "ml.m5.2xlarge" # batch predict + deploy_instance_type: str = "ml.g5.xlarge" # real-time endpoint + fit_instance_type: str = "ml.g5.xlarge" # fine-tuning inference_hyperparameters: Dict[str, Any] = field(default_factory=dict) # defaults for deploy() and predict() training_hyperparameters: Dict[str, Any] = field(default_factory=dict) # defaults for fit() fine_tunable: bool = False # whether .fit() is supported -_DEFAULT_INSTANCE_TYPES = { - "predict_instance_type": "ml.m5.2xlarge", - "deploy_instance_type": "ml.g5.xlarge", - "fit_instance_type": "ml.g5.xlarge", -} - - FOUNDATION_MODEL_REGISTRY: Dict[str, FoundationModelConfig] = { "chronos-bolt-tiny": FoundationModelConfig( task="forecasting", ag_model_key="Chronos", - model_source_uri="amazon/chronos-bolt-tiny", - **_DEFAULT_INSTANCE_TYPES, + model_source_uri="autogluon/chronos-bolt-tiny", ), "chronos-bolt-small": FoundationModelConfig( task="forecasting", ag_model_key="Chronos", - model_source_uri="amazon/chronos-bolt-small", - **_DEFAULT_INSTANCE_TYPES, + model_source_uri="autogluon/chronos-bolt-small", ), "chronos-bolt-base": FoundationModelConfig( task="forecasting", ag_model_key="Chronos", - model_source_uri="amazon/chronos-bolt-base", - **_DEFAULT_INSTANCE_TYPES, + model_source_uri="autogluon/chronos-bolt-base", ), - "chronos-2": FoundationModelConfig( + "chronos-2-small": FoundationModelConfig( task="forecasting", ag_model_key="Chronos2", - model_source_uri="amazon/chronos-2", - training_hyperparameters={"fine_tune": True}, - fine_tunable=True, - **_DEFAULT_INSTANCE_TYPES, + model_source_uri="autogluon/chronos-2-small", ), - # TODO: Replace dummy configs with real values - "mitra-classification": FoundationModelConfig( - task="classification", - ag_model_key="Mitra", - model_source_uri="TODO", - **_DEFAULT_INSTANCE_TYPES, - ), - "mitra-regression": FoundationModelConfig( - task="regression", - ag_model_key="Mitra", - model_source_uri="TODO", - **_DEFAULT_INSTANCE_TYPES, + "chronos-2": FoundationModelConfig( + task="forecasting", + ag_model_key="Chronos2", + model_source_uri="autogluon/chronos-2", ), } diff --git a/tests/unittests/general/test_foundation_model.py b/tests/unittests/general/test_foundation_model.py index 082f25f..01905ea 100644 --- a/tests/unittests/general/test_foundation_model.py +++ b/tests/unittests/general/test_foundation_model.py @@ -72,7 +72,7 @@ def test_from_json_round_trip(): def test_inference_hyperparameters_default_model_path_to_source_uri(): fm = FoundationModel("chronos-2", cloud_output_path="s3://b") hp = fm._get_hyperparameters("inference") - assert hp["model_path"] == "amazon/chronos-2" + assert hp["model_path"] == "autogluon/chronos-2" def test_user_hyperparameter_override_wins_over_default_model_path(): @@ -110,7 +110,7 @@ def test_deploy_without_artifact_passes_none_predictor_path_and_source_uri(): assert call.kwargs["predictor_path"] is None assert call.kwargs["repack"] is True serve_cfg = call.kwargs["fm_serve_config"] - assert serve_cfg["hyperparameters"]["model_path"] == "amazon/chronos-2" + assert serve_cfg["hyperparameters"]["model_path"] == "autogluon/chronos-2" def test_cache_model_artifact_rejects_non_s3_path(): From e12ffbbee6b3cb046ce18335041600b6c534f65e Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 11:57:46 +0000 Subject: [PATCH 5/7] Fix docstrings --- .../cloud/backend/sagemaker_backend.py | 9 ++-- src/autogluon/cloud/model/foundation_model.py | 54 +++++++++---------- .../general/test_foundation_model.py | 8 +-- 3 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/autogluon/cloud/backend/sagemaker_backend.py b/src/autogluon/cloud/backend/sagemaker_backend.py index fa9162a..559427c 100644 --- a/src/autogluon/cloud/backend/sagemaker_backend.py +++ b/src/autogluon/cloud/backend/sagemaker_backend.py @@ -401,11 +401,10 @@ def deploy( fm_serve_config: Optional[Dict[str, Any]], default = None Configuration dict passed to the FM serve script via the AG_FM_SERVE_CONFIG env var. repack: bool, default = True - Whether the SageMaker SDK should download ``predictor_path``, inject the - entry-point script, and re-upload it. Set to False when ``predictor_path`` - already contains the serve script (e.g. an artifact bundled by - :meth:`FoundationModel.cache_model_artifact`) to skip the round-trip. - Ignored when ``predictor_path`` is None. + Whether the SageMaker SDK should download ``predictor_path``, inject the entry-point script, and re-upload + it. Set to False when ``predictor_path`` already contains the serve script (e.g. an artifact bundled by + :meth:`FoundationModel.cache_model_artifact`) to skip the round-trip. Ignored when ``predictor_path`` is + None. """ assert self.endpoint is None, ( "There is an endpoint already attached. Either detach it with `detach` or clean it up with `cleanup_deployment`" diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 149fb21..5920cd0 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -88,11 +88,11 @@ def __init__( S3 location where intermediate artifacts are stored. Accepts: * ``s3://bucket`` — a unique timestamped subfolder ``ag-`` is appended. - * ``s3://bucket/prefix`` — used verbatim. Re-running with the same prefix - will overwrite previously written artifacts. - * ``None`` (default) — use the bucket saved in ``~/.autogluon/cloud.yaml`` (set - by :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`) and - append a timestamped subfolder. Raises if no bucket is configured. + * ``s3://bucket/prefix`` — used verbatim. Re-running with the same prefix will overwrite previously written + artifacts. + * ``None`` (default) — use the bucket saved in ``~/.autogluon/cloud.yaml`` (set by + :func:`autogluon.cloud.bootstrap` / :func:`autogluon.cloud.register`) and append a timestamped subfolder. + Raises if no bucket is configured. role ARN of the SageMaker execution role used to run training and inference jobs. If ``None``, falls back to ``role_arn`` in ``~/.autogluon/cloud.yaml`` (set by :func:`autogluon.cloud.bootstrap` / @@ -100,9 +100,8 @@ def __init__( hyperparameters Default hyperparameters applied to inference and (when supported) training. model_artifact_uri - S3 URI of a pre-bundled ``model.tar.gz`` produced by - :meth:`cache_model_artifact`. When set, deploys skip the runtime - HuggingFace download and load weights from the bundled artifact. + S3 URI of a pre-bundled ``model.tar.gz`` produced by :meth:`cache_model_artifact`. When set, deploys skip + the runtime HuggingFace download and load weights from the bundled artifact. backend Cloud backend to use. """ @@ -130,11 +129,12 @@ def __init__( def _get_hyperparameters( self, context: Literal["inference", "training"], overrides: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: - """Merge registry defaults → constructor overrides → call-site overrides, - defaulting ``model_path`` to ``model_source_uri`` if not set.""" - registry_defaults = ( - self._config.inference_hyperparameters if context == "inference" else self._config.training_hyperparameters - ) + """Merge registry defaults → constructor overrides → call-site overrides, defaulting ``model_path`` to + ``model_source_uri`` if not set.""" + if context == "inference": + registry_defaults = self._config.inference_hyperparameters + else: + registry_defaults = self._config.training_hyperparameters merged = registry_defaults | self._hyperparameter_overrides | (overrides or {}) merged.setdefault("model_path", self._config.model_source_uri) return merged @@ -253,24 +253,21 @@ def fit( def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> "FoundationModel": """ - Download model weights from HuggingFace, bundle them with the FM serve script - into a SageMaker-compatible ``model.tar.gz``, and upload to S3. + Download model weights from HuggingFace, bundle them with the FM serve script into a SageMaker-compatible + ``model.tar.gz``, and upload to S3. - Lets :meth:`deploy` skip the runtime HuggingFace download — required for - network-isolated endpoints (e.g. SageMaker Serverless Inference). Returns a - new :class:`FoundationModel` with ``model_artifact_uri`` set to the uploaded - tarball. + Lets :meth:`deploy` skip the runtime HuggingFace download — required for network-isolated endpoints (e.g. + SageMaker Serverless Inference). Returns a new :class:`FoundationModel` with ``model_artifact_uri`` set to the + uploaded tarball. - Destination key: ``{cache_path}/{model_id}/model.tar.gz``. If it already - exists, upload is skipped unless ``overwrite=True``; a stale-cache mismatch - between the bundled artifact's autogluon-cloud version and the current - version raises ``RuntimeError`` and prompts the caller to re-bundle. + Destination key: ``{cache_path}/{model_id}/model.tar.gz``. If it already exists, upload is skipped unless + ``overwrite=True``; a stale-cache mismatch between the bundled artifact's autogluon-cloud version and the + current version raises ``RuntimeError`` and prompts the caller to re-bundle. Parameters ---------- cache_path - S3 prefix under which the artifact will be uploaded. Multiple foundation - models can share one prefix. + S3 prefix under which the artifact will be uploaded. Multiple foundation models can share one prefix. overwrite If True, re-upload even when the destination key exists. @@ -333,8 +330,8 @@ def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> " ) def to_dict(self) -> Dict[str, Any]: - """Serialize the model identity. Runtime context (``role``, ``cloud_output_path``) - is excluded so configs can be shared across users.""" + """Serialize the model identity. Runtime context (``role``, ``cloud_output_path``) is excluded so configs can + be shared across users.""" out: Dict[str, Any] = {"model_id": self.model_id} if self._hyperparameter_overrides: out["hyperparameters"] = self._hyperparameter_overrides @@ -348,8 +345,7 @@ def to_json(self) -> str: @classmethod def from_dict(cls, config: Dict[str, Any], **runtime_context: Any) -> "FoundationModel": - """Restore from :meth:`to_dict` output. Pass ``role`` / ``cloud_output_path`` - as ``runtime_context``.""" + """Restore from :meth:`to_dict` output. Pass ``role`` / ``cloud_output_path`` as ``runtime_context``.""" return cls(**config, **runtime_context) @classmethod diff --git a/tests/unittests/general/test_foundation_model.py b/tests/unittests/general/test_foundation_model.py index 01905ea..804a102 100644 --- a/tests/unittests/general/test_foundation_model.py +++ b/tests/unittests/general/test_foundation_model.py @@ -1,5 +1,5 @@ -"""Unit tests for FoundationModel: serialization, hyperparameter resolution, and -deploy-time wiring of model_artifact_uri / model_path.""" +"""Unit tests for FoundationModel: serialization, hyperparameter resolution, and deploy-time wiring of +model_artifact_uri / model_path.""" from unittest import mock @@ -120,8 +120,8 @@ def test_cache_model_artifact_rejects_non_s3_path(): def test_cache_model_artifact_raises_on_stale_version_without_overwrite(): - """Returning a model pointing at a tarball bundled by a different autogluon-cloud - version surfaces as a confusing endpoint failure later. Force the user to opt in.""" + """Returning a model pointing at a tarball bundled by a different autogluon-cloud version surfaces as a confusing + endpoint failure later. Force the user to opt in.""" fm = FoundationModel("chronos-2", cloud_output_path="s3://b") s3 = mock.MagicMock() s3.head_object.return_value = {"Metadata": {"autogluon-cloud-version": "0.0.0-stale"}} From 16627d597e16ed13fd64a2e9bd86b7c1774fec7d Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 12:01:22 +0000 Subject: [PATCH 6/7] Address review comments --- src/autogluon/cloud/model/foundation_model.py | 13 ++++++- .../general/test_foundation_model.py | 38 ++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index 5920cd0..b77afdc 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -189,6 +189,15 @@ def _deploy_backend( merged_hp = self._get_hyperparameters("inference", hyperparameters) if self.model_artifact_uri is not None: + user_model_path = (hyperparameters or {}).get("model_path") or self._hyperparameter_overrides.get( + "model_path" + ) + if user_model_path is not None: + raise ValueError( + "Cannot set hyperparameters['model_path'] when model_artifact_uri is in use — the bundled artifact " + f"determines the in-container weights path ({_CONTAINER_WEIGHTS_DIR}). Drop model_path, or call " + "deploy() on a FoundationModel without model_artifact_uri." + ) merged_hp["model_path"] = _CONTAINER_WEIGHTS_DIR fm_serve_config = { "ag_model_key": self._config.ag_model_key, @@ -198,6 +207,8 @@ def _deploy_backend( model_kwargs = backend_kwargs.pop("model_kwargs", {}) model_kwargs["entry_point"] = self._serve_script_path + # FM deploys never want SDK repack: predictor_path is either None (script-only tarball is built locally) or a + # pre-bundled cache artifact that already contains the serve script. self._backend.deploy( predictor_path=self.model_artifact_uri, endpoint_name=endpoint_name, @@ -207,7 +218,7 @@ def _deploy_backend( wait=wait, model_kwargs=model_kwargs, fm_serve_config=fm_serve_config, - repack=self.model_artifact_uri is None, + repack=False, **backend_kwargs, ) assert self._backend.endpoint is not None diff --git a/tests/unittests/general/test_foundation_model.py b/tests/unittests/general/test_foundation_model.py index 804a102..53a253e 100644 --- a/tests/unittests/general/test_foundation_model.py +++ b/tests/unittests/general/test_foundation_model.py @@ -1,6 +1,7 @@ """Unit tests for FoundationModel: serialization, hyperparameter resolution, and deploy-time wiring of model_artifact_uri / model_path.""" +from pathlib import Path from unittest import mock import pytest @@ -108,17 +109,52 @@ def test_deploy_without_artifact_passes_none_predictor_path_and_source_uri(): call = fm._backend.deploy.call_args assert call.kwargs["predictor_path"] is None - assert call.kwargs["repack"] is True + assert call.kwargs["repack"] is False serve_cfg = call.kwargs["fm_serve_config"] assert serve_cfg["hyperparameters"]["model_path"] == "autogluon/chronos-2" +def test_deploy_rejects_user_model_path_when_artifact_uri_set(): + """User-supplied model_path is incoherent with model_artifact_uri (the bundled tarball dictates the in-container + path). Raise rather than silently overwrite.""" + fm = FoundationModel( + "chronos-2", + cloud_output_path="s3://b", + model_artifact_uri="s3://b/cache/chronos-2/model.tar.gz", + ) + fm._backend.endpoint = mock.MagicMock() + with pytest.raises(ValueError, match="model_artifact_uri"): + fm._deploy_backend(hyperparameters={"model_path": "my-org/something-else"}) + + def test_cache_model_artifact_rejects_non_s3_path(): fm = FoundationModel("chronos-2", cloud_output_path="s3://b") with pytest.raises(ValueError, match="s3://"): fm.cache_model_artifact("/local/path") +def test_cache_model_artifact_uploads_with_version_metadata(monkeypatch): + """On cache miss, upload_file runs with the version metadata key — that's the cache-invalidation contract.""" + from autogluon.cloud.version import __version__ + + fm = FoundationModel("chronos-2", cloud_output_path="s3://b") + s3 = mock.MagicMock() + fm._backend.sagemaker_session.boto_session.client.return_value = s3 + monkeypatch.setattr("autogluon.cloud.model.foundation_model._s3_head_or_none", lambda *_: None) + monkeypatch.setattr("autogluon.cloud.model.foundation_model.tarfile", mock.MagicMock()) + monkeypatch.setattr("huggingface_hub.snapshot_download", mock.MagicMock()) + monkeypatch.setattr(FoundationModel, "_serve_script_path", "/tmp/nonexistent-stub.py") + monkeypatch.setattr(Path, "read_bytes", lambda self: b"") + monkeypatch.setattr(Path, "write_bytes", lambda self, data: None) + + new_fm = fm.cache_model_artifact("s3://b/cache") + + assert new_fm.model_artifact_uri == "s3://b/cache/chronos-2/model.tar.gz" + s3.upload_file.assert_called_once() + metadata = s3.upload_file.call_args.kwargs["ExtraArgs"]["Metadata"] + assert metadata == {"autogluon-cloud-version": __version__} + + def test_cache_model_artifact_raises_on_stale_version_without_overwrite(): """Returning a model pointing at a tarball bundled by a different autogluon-cloud version surfaces as a confusing endpoint failure later. Force the user to opt in.""" From 5b5fc5943d95337af35e487d1ed4fc90ac3d57a6 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Tue, 2 Jun 2026 12:38:15 +0000 Subject: [PATCH 7/7] Fix serve utils packaging --- .github/workflow_scripts/test_cloud.sh | 2 +- src/autogluon/cloud/model/foundation_model.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflow_scripts/test_cloud.sh b/.github/workflow_scripts/test_cloud.sh index 89460d0..654ddc4 100755 --- a/.github/workflow_scripts/test_cloud.sh +++ b/.github/workflow_scripts/test_cloud.sh @@ -17,4 +17,4 @@ fi install_cloud_test -python3 -m pytest -n 2 --junitxml=results.xml tests/unittests/$MODULE/ --framework_version $AG_VERSION +python3 -m pytest -n 4 --junitxml=results.xml tests/unittests/$MODULE/ --framework_version $AG_VERSION diff --git a/src/autogluon/cloud/model/foundation_model.py b/src/autogluon/cloud/model/foundation_model.py index b77afdc..4429240 100644 --- a/src/autogluon/cloud/model/foundation_model.py +++ b/src/autogluon/cloud/model/foundation_model.py @@ -314,16 +314,16 @@ def cache_model_artifact(self, cache_path: str, *, overwrite: bool = False) -> " logger.info(f"Downloading {source_uri} from HuggingFace to {weights_dir}") snapshot_download(repo_id=source_uri, local_dir=str(weights_dir)) - code_dir = tmp_path / "code" - code_dir.mkdir() + # Mirror the layout produced by SagemakerBackend._create_serve_script_tarball: + # entry-point script + serving_utils/ under code/, so the cached endpoint can + # `from serving_utils.timeseries import ...` exactly like a fresh deploy. serve_script = Path(self._serve_script_path) - (code_dir / serve_script.name).write_bytes(serve_script.read_bytes()) - tarball = tmp_path / "model.tar.gz" logger.info(f"Bundling weights + serve script into {tarball}") with tarfile.open(tarball, "w:gz") as tar: tar.add(weights_dir, arcname="weights") - tar.add(code_dir, arcname="code") + tar.add(serve_script, arcname=f"code/{serve_script.name}") + tar.add(ScriptManager.SAGEMAKER_SERVING_UTILS_DIR, arcname="code/serving_utils") logger.info(f"Uploading to {cache_key}") s3.upload_file( str(tarball),