Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflow_scripts/test_cloud.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ fi

install_cloud_test

-python3 -m pytest -n 2 --junitxml=results.xml tests/unittests/$MODULE/ --framework_version $AG_VERSION
+python3 -m pytest -n 4 --junitxml=results.xml tests/unittests/$MODULE/ --framework_version $AG_VERSION
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def default_setup_args(*, version):
# CLI dependencies (autogluon-cloud command)
"click>=8.0,<9",
"rich>=13.0,<15",
+"huggingface_hub>=0.20,<2",
]

extras_require = dict()
Expand Down
36 changes: 22 additions & 14 deletions src/autogluon/cloud/backend/sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def deploy(
wait: bool = True,
model_kwargs: Optional[Dict] = None,
deploy_kwargs: Optional[Dict] = None,
-serve_config: Optional[Dict[str, Any]] = None,
+fm_serve_config: Optional[Dict[str, Any]] = None,
+repack: bool = True,
) -> None:
"""
Deploy a predictor as a SageMaker endpoint, which can be used to do real-time inference later.
Expand Down Expand Up @@ -397,8 +398,13 @@ def deploy(
deploy_kwargs:
Any extra arguments needed to pass to deploy.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy for all options
-serve_config: Optional[Dict[str, Any]], default = None
-Configuration dict passed to the serve script via the AG_SERVE_CONFIG env var.
+fm_serve_config: Optional[Dict[str, Any]], default = None
+Configuration dict passed to the FM serve script via the AG_FM_SERVE_CONFIG env var.
+repack: bool, default = True
+Whether the SageMaker SDK should download ``predictor_path``, inject the entry-point script, and re-upload
+it. Set to False when ``predictor_path`` already contains the serve script (e.g. an artifact bundled by
+:meth:`FoundationModel.cache_model_artifact`) to skip the round-trip. Ignored when ``predictor_path`` is
+None.
"""
assert self.endpoint is None, (
"There is an endpoint already attached. Either detach it with `detach` or clean it up with `cleanup_deployment`"
Expand Down Expand Up @@ -444,19 +450,21 @@ def deploy(
)
entry_point = self._serve_script_path

-# Pick model class:
-# - No artifact → create minimal tarball with serve script (FM deploy)
-# - Artifact from different source or custom entry point → Repack (inject script into tarball)
-# - Artifact from fit job, default entry point → NonRepack (script already in tarball)
+# Pick model class. The question is whether the tarball already contains
+# the entry_point script — if yes, NonRepack uses it as-is; if no, Repack
+# injects the script at deploy time.
if predictor_path is None:
predictor_path = self._create_serve_script_tarball(entry_point, endpoint_name)
model_cls = AutoGluonNonRepackInferenceModel
+elif not repack:
+model_cls = AutoGluonNonRepackInferenceModel
else:
-fit_output = self._fit_job.get_output_path() if self._fit_job is not None else None
-if predictor_path != fit_output or user_entry_point is not None:
-model_cls = AutoGluonRepackInferenceModel
-else:
-model_cls = AutoGluonNonRepackInferenceModel
+is_default_fit_output = (
+self._fit_job is not None
+and predictor_path == self._fit_job.get_output_path()
+and user_entry_point is None
+)
+model_cls = AutoGluonNonRepackInferenceModel if is_default_fit_output else AutoGluonRepackInferenceModel

# Assemble env vars and deploy
predictor_cls = self._realtime_predictor_cls
Expand All @@ -483,8 +491,8 @@ def deploy(
else:
model_kwargs_env = {SAGEMAKER_MODEL_SERVER_WORKERS: "1"}

-if serve_config is not None:
-model_kwargs_env["AG_SERVE_CONFIG"] = json.dumps(serve_config)
+if fm_serve_config is not None:
+model_kwargs_env["AG_FM_SERVE_CONFIG"] = json.dumps(fm_serve_config)

model = model_cls(
model_data=predictor_path,
Expand Down
Loading
Loading