Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

## Train and Deploy AutoGluon in the Cloud

[![Latest Release](https://img.shields.io/github/v/release/autogluon/autogluon-cloud)](https://github.com/autogluon/autogluon-cloud/releases)
[![PyPI](https://img.shields.io/pypi/v/autogluon.cloud.svg)](https://pypi.org/project/autogluon.cloud/)
[![Python Versions](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.12-blue)](https://pypi.org/project/autogluon.cloud/)
[![GitHub license](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](./LICENSE)
[![Continuous Integration](https://github.com/autogluon/autogluon-cloud/actions/workflows/continuous_integration.yml/badge.svg)](https://github.com/autogluon/autogluon-cloud/actions/workflows/continuous_integration.yml)
Expand All @@ -14,12 +14,12 @@

</div>

AutoGluon-Cloud makes it easy to run [AutoGluon](https://auto.gluon.ai/stable/index.html) in the cloud. With a few lines of code, you can train models and run inference on [Amazon SageMaker](https://aws.amazon.com/sagemaker/) without managing infrastructure or installing AutoGluon's heavy dependencies on your local machine.
[AutoGluon](https://auto.gluon.ai/stable/index.html) is an open-source AutoML library that trains state-of-the-art ML models on tabular, time-series, and multimodal data with just a few lines of code. AutoGluon-Cloud takes that same API and runs it on AWS — train models and serve predictions on [Amazon SageMaker](https://aws.amazon.com/sagemaker/) without managing infrastructure or setting up a heavyweight ML environment on your local machine.

It supports two workflows:

- **Train AutoGluon predictors in the cloud** — the same `fit → deploy → predict` workflow as local AutoGluon, with all the heavy lifting offloaded to SageMaker.
- **Run pretrained foundation models** — deploy state-of-the-art pretrained models like Chronos-2 for zero-shot inference, with no training required.
- **[Train your own predictor](https://auto.gluon.ai/cloud/stable/tutorials/predictor-tabular.html)** — the same `fit → deploy → predict` workflow as local AutoGluon, with all the heavy lifting offloaded to SageMaker.
- **[Run pretrained foundation models](https://auto.gluon.ai/cloud/stable/tutorials/foundation-model-timeseries.html)** — deploy state-of-the-art pretrained models like [Chronos-2](https://huggingface.co/amazon/chronos-2) for zero-shot inference, with no training required.

## 💾 Installation & setup

Expand Down
32 changes: 26 additions & 6 deletions src/autogluon/cloud/model/foundation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,14 @@ class FoundationModel:
"""
Pretrained foundation model inference on AWS.

Factory: FoundationModel("chronos-bolt-base", ...) returns the appropriate
task-specific subclass (TimeSeriesFoundationModel, TabularFoundationModel).
Factory: ``FoundationModel(model_id, ...)`` dispatches on the model's task and returns the
appropriate task-specific subclass (:class:`TimeSeriesFoundationModel`, ``TabularFoundationModel``).
Most users instantiate the subclass directly instead.

Examples
--------
>>> model = FoundationModel("chronos-bolt-base")
>>> predictions = model.predict(data, prediction_length=12)
>>> model = FoundationModel("chronos-2") # returns a TimeSeriesFoundationModel
>>> predictions = model.predict(data, prediction_length=24)
"""

_backend_map: Dict[str, str] = {}
Expand Down Expand Up @@ -84,7 +85,9 @@ def __init__(
Parameters
----------
model_id
ID of the foundation model from the model registry.
ID of the foundation model from the model registry. See
`Available models <https://auto.gluon.ai/cloud/stable/tutorials/foundation-model-timeseries.html#available-models>`_
in the foundation model tutorial for the list of supported values.
cloud_output_path
S3 location where intermediate artifacts are stored. Accepts:

Expand Down Expand Up @@ -373,7 +376,24 @@ def from_json(cls, s: str, **runtime_context: Any) -> "FoundationModel":


class TimeSeriesFoundationModel(FoundationModel):
"""Foundation model for time series forecasting (Chronos, etc.)."""
"""Pretrained time series foundation model for zero-shot forecasting on AWS SageMaker.

Wraps pretrained models like `Chronos-2 <https://huggingface.co/autogluon/chronos-2>`_ and
Chronos-Bolt and runs prediction as a managed SageMaker job, with no training required. See
`the foundation model tutorial <https://auto.gluon.ai/cloud/stable/tutorials/foundation-model-timeseries.html>`_
for the supported ``model_id`` values and a full walkthrough.

Predictions can be produced in three modes:

* **Batch** — :meth:`predict` runs a one-off SageMaker training job and writes forecasts to S3.
Best for one-shot inference.
* **Real-time** — :meth:`deploy` provisions a real-time endpoint; call
:meth:`TimeSeriesEndpoint.predict` for low-latency inference, then
:meth:`TimeSeriesEndpoint.delete_endpoint` to tear it down.
* **Serverless** — :meth:`deploy` with ``inference_mode="serverless"`` provisions a SageMaker
Serverless Inference endpoint that scales to zero. Requires a cached model artifact (see
:meth:`cache_model_artifact`).
"""

_backend_map = {SAGEMAKER: TIMESERIES_SAGEMAKER}
_predictor_type = "timeseries"
Expand Down
20 changes: 10 additions & 10 deletions src/autogluon/cloud/predictor/cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def fit(
For SageMaker backend, valid keys are:
1. autogluon_sagemaker_estimator_kwargs
Any extra arguments needed to initialize AutoGluonSagemakerEstimator
Please refer to https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/training/estimators.html#sagemaker.estimator.Estimator for all options
2. fit_kwargs
Any extra arguments needed to pass to fit.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator.fit for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/training/estimators.html#sagemaker.estimator.Estimator.fit for all options
For RayAWS backend, valid keys are:
1. custom_config: Optional[Union[str, Dict[str, Any]]] = None,
The custom cluster configuration.
Expand Down Expand Up @@ -441,10 +441,10 @@ def deploy(
For SageMaker backend, valid keys are:
1. model_kwargs: dict, default = dict()
Any extra arguments needed to initialize Sagemaker Model
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#model for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/model.html#model for all options
2. deploy_kwargs
Any extra arguments needed to pass to deploy.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/model.html#sagemaker.model.Model.deploy for all options
"""
if inference_mode == "serverless" and instance_type is not None:
raise ValueError("`instance_type` must not be set when `inference_mode='serverless'`.")
Expand Down Expand Up @@ -621,14 +621,14 @@ def predict(
If `persist` is `False`, file would first be downloaded to this path and then removed.
4. model_kwargs: dict, default = dict()
Any extra arguments needed to initialize Sagemaker Model
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#model for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/model.html#model for all options
5. transformer_kwargs: dict
Any extra arguments needed to pass to transformer.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
6. transform_kwargs:
Any extra arguments needed to pass to transform.
Please refer to
https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.

Returns
-------
Expand Down Expand Up @@ -716,14 +716,14 @@ def predict_proba(
If `persist` is `False`, file would first be downloaded to this path and then removed.
4. model_kwargs: dict, default = dict()
Any extra arguments needed to initialize Sagemaker Model
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#model for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/model.html#model for all options
5. transformer_kwargs: dict
Any extra arguments needed to pass to transformer.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
6. transform_kwargs:
Any extra arguments needed to pass to transform.
Please refer to
https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.

Returns
-------
Expand Down
30 changes: 22 additions & 8 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ def fit(
For SageMaker backend, valid keys are:
1. autogluon_sagemaker_estimator_kwargs
Any extra arguments needed to initialize AutoGluonSagemakerEstimator
Please refer to https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/training/estimators.html#sagemaker.estimator.Estimator for all options
2. fit_kwargs
Any extra arguments needed to pass to fit.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.Estimator.fit for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/training/estimators.html#sagemaker.estimator.Estimator.fit for all options

Returns
-------
Expand Down Expand Up @@ -286,14 +286,14 @@ def predict(
If `persist` is `False`, file would first be downloaded to this path and then removed.
4. model_kwargs: dict, default = dict()
Any extra arguments needed to initialize Sagemaker Model
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#model for all options
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/model.html#model for all options
5. transformer_kwargs: dict
Any extra arguments needed to pass to transformer.
Please refer to https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
Please refer to https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer for all options.
6. transform_kwargs:
Any extra arguments needed to pass to transform.
Please refer to
https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
https://sagemaker.readthedocs.io/en/v2/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
"""
if backend_kwargs is None:
backend_kwargs = {}
Expand Down Expand Up @@ -374,9 +374,23 @@ def fit_predict(
Name of the column with the unique identifier of each time series (item).
timestamp_column: str, default = "timestamp"
Name of the column with the observation timestamps.
framework_version, job_name, instance_type, instance_count, volume_size, custom_image_uri, wait,
backend_kwargs:
Same semantics as ``fit()``.
framework_version: str, default = `latest`
Training container version of autogluon. If `latest`, will use the latest available container version.
If `custom_image_uri` is set, this argument will be ignored.
job_name: str, default = None
Name of the launched training job. If None, CloudPredictor will create one with prefix ag-cloudpredictor.
instance_type: str, default = 'ml.m5.2xlarge'
Instance type the predictor will be trained on with SageMaker.
instance_count: int, default = 1
Number of instances used to fit the predictor.
volume_size: int, default = 100
Size in GB of the EBS volume to use for storing input data during training.
custom_image_uri: Optional[str], default = None
Custom container image URI. If set, ``framework_version`` is ignored.
wait: bool, default = True
Whether the call should wait until the job completes.
backend_kwargs: Optional[dict], default = None
Backend-specific arguments. Same keys as ``fit()``.
predictions_path: Optional[str]
S3 URL where predictions will be written by the training container (e.g.
``s3://my-bucket/runs/2024-05-01/predictions.csv``). The container's SageMaker execution role must
Expand Down
Loading