Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/autogluon/cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging

from autogluon.common.utils.log_utils import _add_stream_handler

from .cloud_setup import bootstrap, register, status, teardown
from .model.foundation_model import TimeSeriesFoundationModel
from .predictor import MultiModalCloudPredictor, TabularCloudPredictor, TimeSeriesCloudPredictor

_add_stream_handler()
logging.getLogger(__name__).setLevel(logging.INFO)

__all__ = [
"MultiModalCloudPredictor",
Expand Down
24 changes: 16 additions & 8 deletions src/autogluon/cloud/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Optional

import boto3
Expand Down Expand Up @@ -48,6 +50,18 @@ def _abort_on_error(fn, *args, **kwargs):
raise click.ClickException(str(e)) from e


@contextmanager
def _quiet_logger(name: str, level: int = logging.WARNING):
"""Temporarily raise a logger's level so its INFO output doesn't fight the rich spinner."""
logger = logging.getLogger(name)
prior = logger.level
logger.setLevel(level)
try:
yield
finally:
logger.setLevel(prior)


@click.group()
@click.version_option(package_name="autogluon.cloud")
def cli() -> None:
Expand Down Expand Up @@ -98,14 +112,8 @@ def bootstrap(
if not yes and not Confirm.ask("Proceed?", default=True):
raise click.Abort()

with _console.status(f"Deploying stack '{effective_stack}'...", spinner="dots"):
_abort_on_error(
_bootstrap,
backend=backend,
stack_name=effective_stack,
session=session,
verbose=False,
)
with _quiet_logger("autogluon.cloud"), _console.status(f"Deploying stack '{effective_stack}'...", spinner="dots"):
_abort_on_error(_bootstrap, backend=backend, stack_name=effective_stack, session=session)

config = load_config()
if config and backend in config.backends:
Expand Down
35 changes: 14 additions & 21 deletions src/autogluon/cloud/cloud_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from importlib import resources
from typing import Dict, Literal, Optional
Expand All @@ -31,6 +32,8 @@

__all__ = ["bootstrap", "register", "status", "teardown", "StatusReport"]

logger = logging.getLogger(__name__)


@dataclass
class StatusReport:
Expand All @@ -50,7 +53,6 @@ def bootstrap(
backend: BackendName = "sagemaker",
stack_name: Optional[str] = None,
session: Optional[boto3.Session] = None,
verbose: bool = True,
) -> None:
"""Deploy the CloudFormation stack and persist resource identifiers.

Expand All @@ -70,8 +72,6 @@ def bootstrap(
session
A ``boto3.Session`` to use for AWS calls. If ``None``, a default session is constructed from the standard
credential chain (env vars, ``~/.aws/credentials``, SSO, instance profile).
verbose
If ``True`` (default), print progress messages to stdout.
"""
if backend not in SUPPORTED_BACKENDS:
raise ValueError(f"Unsupported backend {backend!r}. Choose from {SUPPORTED_BACKENDS}.")
Expand All @@ -85,19 +85,16 @@ def bootstrap(
)
stack_name = stack_name or f"ag-cloud-{backend.replace('_', '-')}"

if verbose:
print(f"Deploying CloudFormation stack {stack_name!r} (account {account}, region {region}, ~1 minute)...")
logger.info(f"Deploying CloudFormation stack {stack_name!r} (account {account}, region {region}, ~1 minute)...")
role_arn, bucket = _provision_stack(session, stack_name=stack_name, backend=backend)
if verbose:
print(f"Stack {stack_name!r} deployed.")
logger.info(f"Stack {stack_name!r} deployed.")

register(
role=role_arn,
bucket=bucket,
region=region,
backend=backend,
stack_name=stack_name,
verbose=verbose,
)


Expand All @@ -108,7 +105,6 @@ def register(
region: str,
backend: BackendName = "sagemaker",
stack_name: Optional[str] = None,
verbose: bool = True,
) -> None:
"""Persist resource identifiers to ``~/.autogluon/cloud.yaml`` under the given backend key.

Expand All @@ -133,8 +129,6 @@ def register(
Optional CloudFormation stack name. If you deployed the resources via your own CFN stack and want
:func:`teardown` to be able to delete it later, pass the name here. Defaults to ``None``, meaning teardown
will only remove the config entry, not touch AWS.
verbose
If ``True`` (default), print progress messages to stdout.
"""
if backend not in SUPPORTED_BACKENDS:
raise ValueError(f"Unsupported backend {backend!r}. Choose from {SUPPORTED_BACKENDS}.")
Expand All @@ -146,8 +140,7 @@ def register(
stack_name=stack_name,
)
save_config(config)
if verbose:
print(f"Saved AutoGluon-Cloud config for backend {backend!r} to {get_config_path()}")
logger.info(f"Saved AutoGluon-Cloud config for backend {backend!r} to {get_config_path()}")


def status(
Expand Down Expand Up @@ -210,36 +203,36 @@ def teardown(
"""
config = load_config()
if config is None or not config.backends:
print("No AutoGluon-Cloud config found — nothing to tear down.")
logger.warning("No AutoGluon-Cloud config found — nothing to tear down.")
return

if backend is not None and backend not in config.backends:
print(f"Backend {backend!r} not in config. Available: {sorted(config.backends)}")
logger.warning(f"Backend {backend!r} not in config. Available: {sorted(config.backends)}")
return

targets = [backend] if backend is not None else list(config.backends)
for name in targets:
backend_config = config.backends[name]
if backend_config.stack_name is None:
print(f"[{name}] no stack to delete.")
logger.info(f"[{name}] no stack to delete.")
else:
sess, account = _verified_session(session or boto3.Session(region_name=backend_config.region))
print(
logger.info(
f"[{name}] Deleting CloudFormation stack {backend_config.stack_name!r} "
f"(account {account}, region {backend_config.region}, ~1 minute)..."
)
cfn = sess.client("cloudformation")
cfn.delete_stack(StackName=backend_config.stack_name)
cfn.get_waiter("stack_delete_complete").wait(StackName=backend_config.stack_name)
print(f"[{name}] Stack {backend_config.stack_name!r} deleted.")
logger.info(f"[{name}] Stack {backend_config.stack_name!r} deleted.")
del config.backends[name]

if config.backends:
save_config(config)
print(f"Removed {targets} from config; remaining backends: {sorted(config.backends)}.")
logger.info(f"Removed {targets} from config; remaining backends: {sorted(config.backends)}.")
else:
delete_config()
print("Removed config file.")
logger.info("Removed config file.")


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -280,7 +273,7 @@ def _provision_stack(session: boto3.Session, *, stack_name: str, backend: Backen
if e.response["Error"]["Code"] != "AlreadyExistsException":
raise
stack_existed = True
print(f"Stack {stack_name!r} already exists — reusing it.")
logger.warning(f"Stack {stack_name!r} already exists — reusing it.")

if not stack_existed:
cfn.get_waiter("stack_create_complete").wait(StackName=stack_name)
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/general/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_bootstrap_with_yes_skips_confirm_and_calls_python_api(runner, monkeypat
"backend": "sagemaker",
"stack_name": "my-stack",
"session": pytest.approx(calls["bootstrap"][0]["session"]),
"verbose": False,
}
]

Expand Down
55 changes: 34 additions & 21 deletions tests/unittests/general/test_cloud_setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the ``autogluon.cloud`` setup API."""

import logging

import pytest

from autogluon.cloud import bootstrap, register, status, teardown
Expand All @@ -18,6 +20,13 @@ def isolated_config_dir(tmp_path, monkeypatch):
yield tmp_path


@pytest.fixture(autouse=True)
def _propagate_autogluon_logger(monkeypatch):
# autogluon.common's _add_stream_handler() sets propagate=False on the `autogluon` logger,
# which prevents caplog from seeing records. Flip it on for the duration of each test.
monkeypatch.setattr(logging.getLogger("autogluon"), "propagate", True)


def _register_default(backend="sagemaker"):
register(
role="arn:aws:iam::111122223333:role/x",
Expand All @@ -32,8 +41,9 @@ def _register_default(backend="sagemaker"):
# ---------------------------------------------------------------------------


def test_register_writes_file(capsys):
_register_default()
def test_register_writes_file(caplog):
with caplog.at_level("INFO", logger="autogluon.cloud.cloud_setup"):
_register_default()

cfg = load_config()
assert "sagemaker" in cfg.backends
Expand All @@ -42,7 +52,7 @@ def test_register_writes_file(capsys):
assert sage.bucket == "b1"
assert sage.region == "us-east-1"
assert sage.stack_name is None
assert "Saved AutoGluon-Cloud config for backend 'sagemaker'" in capsys.readouterr().out
assert any("Saved AutoGluon-Cloud config for backend 'sagemaker'" in r.message for r in caplog.records)


def test_register_overwrites_same_backend():
Expand Down Expand Up @@ -97,7 +107,7 @@ def test_register_records_stack_name_when_given():
# ---------------------------------------------------------------------------


def test_bootstrap_calls_cfn_then_registers(monkeypatch, capsys):
def test_bootstrap_calls_cfn_then_registers(monkeypatch, caplog):
class FakeSession:
region_name = "us-east-1"

Expand All @@ -113,17 +123,18 @@ def client(self, service):
lambda session, stack_name, backend: ("arn:aws:iam::123:role/r", "ag-cloud-bucket"),
)

bootstrap(backend="sagemaker", stack_name="my-stack")
with caplog.at_level("INFO", logger="autogluon.cloud.cloud_setup"):
bootstrap(backend="sagemaker", stack_name="my-stack")

cfg = load_config()
assert cfg.backends["sagemaker"].role_arn == "arn:aws:iam::123:role/r"
assert cfg.backends["sagemaker"].bucket == "ag-cloud-bucket"
assert cfg.backends["sagemaker"].stack_name == "my-stack"

out = capsys.readouterr().out
assert "Deploying CloudFormation stack 'my-stack'" in out
assert "account 123456789012" in out
assert "deployed" in out
messages = " ".join(r.message for r in caplog.records)
assert "Deploying CloudFormation stack 'my-stack'" in messages
assert "account 123456789012" in messages
assert "deployed" in messages


def test_bootstrap_returns_none(monkeypatch):
Expand Down Expand Up @@ -232,19 +243,21 @@ def client(self, service):
# ---------------------------------------------------------------------------


def test_teardown_without_config_is_noop(capsys):
assert teardown() is None
assert "nothing to tear down" in capsys.readouterr().out
def test_teardown_without_config_is_noop(caplog):
with caplog.at_level("WARNING", logger="autogluon.cloud.cloud_setup"):
assert teardown() is None
assert any("nothing to tear down" in r.message for r in caplog.records)


def test_teardown_no_stacks_just_removes_config(capsys):
def test_teardown_no_stacks_just_removes_config(caplog):
"""Backends registered without stack_name → only the config is removed."""
_register_default()
teardown()
with caplog.at_level("INFO", logger="autogluon.cloud.cloud_setup"):
teardown()
assert load_config() is None
out = capsys.readouterr().out
assert "no stack to delete" in out
assert "Removed config" in out
messages = " ".join(r.message for r in caplog.records)
assert "no stack to delete" in messages
assert "Removed config" in messages


def test_teardown_with_stack_deletes_each_backend(monkeypatch):
Expand Down Expand Up @@ -310,9 +323,9 @@ def test_teardown_specific_backend_keeps_others():
assert set(cfg.backends) == {"ray_aws"}


def test_teardown_unknown_backend_is_friendly(capsys):
def test_teardown_unknown_backend_is_friendly(caplog):
_register_default(backend="sagemaker")
teardown(backend="ray_aws") # not registered
out = capsys.readouterr().out
assert "not in config" in out
with caplog.at_level("WARNING", logger="autogluon.cloud.cloud_setup"):
teardown(backend="ray_aws") # not registered
assert any("not in config" in r.message for r in caplog.records)
assert load_config() is not None # nothing was removed
Loading