Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This file is used to list changes made in each version of the aws-parallelcluste

**CHANGES**
- Direct users to slurm_resume log to see EC2 error codes if no instances are launched.
- Emit metric `ClustermgtdHeartbeat` to signal clustermgtd heartbeat.

3.14.1
------
Expand Down
93 changes: 93 additions & 0 deletions src/slurm_plugin/cloudwatch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import logging
from datetime import datetime, timezone
from typing import Dict, List, Optional

import boto3
from botocore.config import Config

logger = logging.getLogger(__name__)

METRICS_NAMESPACE = "ParallelCluster"
METRICS_DIMENSION_CLUSTER_NAME = "ClusterName"
METRICS_DIMENSION_INSTANCE_ID = "InstanceId"


class CloudWatchMetricsPublisher:
"""Class for publishing metrics to CloudWatch."""

def __init__(self, region: str, cluster_name: str, instance_id: str, boto3_config: Config):
"""
Initialize CloudWatchMetricsPublisher.

Args:
region: AWS region
cluster_name: Name of the ParallelCluster cluster
instance_id: EC2 instance ID to include in metric dimensions
boto3_config: Boto3 configuration for retries and proxies
"""
self._region = region
self._cluster_name = cluster_name
self._instance_id = instance_id
self._boto3_config = boto3_config
self._cloudwatch_client = None

@property
def cloudwatch_client(self):
"""Lazy initialization of CloudWatch client."""
if self._cloudwatch_client is None:
self._cloudwatch_client = boto3.client("cloudwatch", region_name=self._region, config=self._boto3_config)
return self._cloudwatch_client

def put_metric(
self,
metric_name: str,
value: float,
unit: str = "Count",
additional_dimensions: Optional[List[Dict[str, str]]] = None,
):
"""
Publish a metric to CloudWatch.

Automatically sets timestamp and includes ClusterName as a dimension.

Args:
metric_name: Name of the metric to publish
value: Metric value
unit: CloudWatch unit (default: "Count")
additional_dimensions: Optional list of additional dimensions [{"Name": "...", "Value": "..."}]
"""
dimensions = [
{"Name": METRICS_DIMENSION_CLUSTER_NAME, "Value": self._cluster_name},
{"Name": METRICS_DIMENSION_INSTANCE_ID, "Value": self._instance_id},
]
if additional_dimensions:
dimensions.extend(additional_dimensions)

try:
self.cloudwatch_client.put_metric_data(
Namespace=METRICS_NAMESPACE,
MetricData=[
{
"MetricName": metric_name,
"Dimensions": dimensions,
"Timestamp": datetime.now(tz=timezone.utc),
"Value": value,
"Unit": unit,
}
],
)
logger.debug("Published metric %s with value %s", metric_name, value)
except Exception as e:
logger.error("Failed to publish metric %s: %s", metric_name, e)
17 changes: 17 additions & 0 deletions src/slurm_plugin/clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from common.utils import check_command_output, read_json, sleep_remaining_loop_time, time_is_up, wait_remaining_time
from retrying import retry
from slurm_plugin.capacity_block_manager import CapacityBlockManager
from slurm_plugin.cloudwatch_utils import CloudWatchMetricsPublisher
from slurm_plugin.cluster_event_publisher import ClusterEventPublisher
from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.console_logger import ConsoleLogger
Expand All @@ -60,6 +61,7 @@
LOOP_TIME = 60
CONSOLE_OUTPUT_WAIT_TIME = 5 * 60
MAXIMUM_TASK_BACKLOG = 100
CW_METRICS_HEARTBEAT = "ClustermgtdHeartbeat"
log = logging.getLogger(__name__)
compute_logger = log.getChild("console_output")
event_logger = log.getChild("events")
Expand Down Expand Up @@ -401,6 +403,7 @@ def __init__(self, config):
self._event_publisher = None
self._partition_nodelist_mapping_instance = None
self._capacity_block_manager = None
self._metrics_publisher = None
self.set_config(config)

def set_config(self, config: ClustermgtdConfig):
Expand All @@ -426,6 +429,7 @@ def set_config(self, config: ClustermgtdConfig):
self._instance_manager = self._initialize_instance_manager(config)
self._console_logger = self._initialize_console_logger(config)
self._capacity_block_manager = self._initialize_capacity_block_manager(config)
self._metrics_publisher = self._initialize_metrics_publisher(config)

def shutdown(self):
if self._task_executor:
Expand Down Expand Up @@ -480,6 +484,16 @@ def _initialize_capacity_block_manager(config):
region=config.region, fleet_config=config.fleet_config, boto3_config=config.boto3_config
)

@staticmethod
def _initialize_metrics_publisher(config):
"""Initialize CloudWatch metrics publisher."""
return CloudWatchMetricsPublisher(
region=config.region,
cluster_name=config.cluster_name,
instance_id=config.head_node_instance_id,
boto3_config=config.boto3_config,
)

def _update_compute_fleet_status(self, status):
log.info("Updating compute fleet status from %s to %s", self._compute_fleet_status, status)
self._compute_fleet_status_manager.update_status(status)
Expand Down Expand Up @@ -574,6 +588,9 @@ def manage_cluster(self):
# Write clustermgtd heartbeat to file
self._write_timestamp_to_file()

# Publish heartbeat metric to CloudWatch
self._metrics_publisher.put_metric(metric_name=CW_METRICS_HEARTBEAT, value=1)

def _write_timestamp_to_file(self):
"""Write timestamp into shared file so compute nodes can determine if head node is online."""
# Make clustermgtd heartbeat readable to all users
Expand Down
175 changes: 175 additions & 0 deletions tests/slurm_plugin/test_cloudwatch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with
# the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import logging
from datetime import datetime, timezone
from unittest.mock import MagicMock

import pytest
from assertpy import assert_that
from botocore.config import Config
from botocore.exceptions import ClientError
from slurm_plugin.cloudwatch_utils import METRICS_NAMESPACE, CloudWatchMetricsPublisher


class TestCloudWatchMetricsPublisher:
"""Tests for CloudWatchMetricsPublisher class."""

@pytest.fixture
def boto3_config(self):
return Config(retries={"max_attempts": 1, "mode": "standard"})

@pytest.fixture
def metrics_publisher(self, boto3_config):
return CloudWatchMetricsPublisher(
region="us-east-1",
cluster_name="test-cluster",
instance_id="i-1234567890abcdef0",
boto3_config=boto3_config,
)

def test_init(self, metrics_publisher, boto3_config):
"""Test CloudWatchMetricsPublisher initialization."""
assert_that(metrics_publisher._region).is_equal_to("us-east-1")
assert_that(metrics_publisher._cluster_name).is_equal_to("test-cluster")
assert_that(metrics_publisher._boto3_config).is_equal_to(boto3_config)
assert_that(metrics_publisher._instance_id).is_equal_to("i-1234567890abcdef0")
assert_that(metrics_publisher._cloudwatch_client).is_none()

def test_cloudwatch_client_lazy_initialization(self, metrics_publisher, mocker):
"""Test that CloudWatch client is lazily initialized."""
mock_client = MagicMock()
mock_boto3 = mocker.patch("slurm_plugin.cloudwatch_utils.boto3")
mock_boto3.client.return_value = mock_client

# First access should create the client
client = metrics_publisher.cloudwatch_client
assert_that(client).is_equal_to(mock_client)
mock_boto3.client.assert_called_once_with(
"cloudwatch",
region_name="us-east-1",
config=metrics_publisher._boto3_config,
)

# Second access should return the cached client
mock_boto3.client.reset_mock()
client2 = metrics_publisher.cloudwatch_client
assert_that(client2).is_equal_to(mock_client)
mock_boto3.client.assert_not_called()

@pytest.mark.parametrize(
"metric_name, value, unit, additional_dimensions, expected_dimensions",
[
pytest.param(
"TestMetric",
42,
"Count",
None,
[
{"Name": "ClusterName", "Value": "test-cluster"},
{"Name": "InstanceId", "Value": "i-1234567890abcdef0"},
],
id="basic",
),
pytest.param(
"HeadNodeDaemonHeartbeat",
1,
"Count",
[{"Name": "DaemonName", "Value": "clustermgtd"}],
[
{"Name": "ClusterName", "Value": "test-cluster"},
{"Name": "InstanceId", "Value": "i-1234567890abcdef0"},
{"Name": "DaemonName", "Value": "clustermgtd"},
],
id="with_additional_dimension",
),
pytest.param(
"LatencyMetric",
150.5,
"Milliseconds",
None,
[
{"Name": "ClusterName", "Value": "test-cluster"},
{"Name": "InstanceId", "Value": "i-1234567890abcdef0"},
],
id="with_custom_unit",
),
pytest.param(
"CustomMetric",
100,
"Count",
[
{"Name": "DaemonName", "Value": "clustermgtd"},
{"Name": "NodeType", "Value": "HeadNode"},
],
[
{"Name": "ClusterName", "Value": "test-cluster"},
{"Name": "InstanceId", "Value": "i-1234567890abcdef0"},
{"Name": "DaemonName", "Value": "clustermgtd"},
{"Name": "NodeType", "Value": "HeadNode"},
],
id="with_multiple_additional_dimensions",
),
],
)
def test_put_metric(
self,
metrics_publisher,
mocker,
metric_name: str,
value: float,
unit: str,
additional_dimensions: list,
expected_dimensions: list,
):
"""Test put_metric with various parameter combinations."""
mock_client = MagicMock()
metrics_publisher._cloudwatch_client = mock_client
mock_datetime = mocker.patch("slurm_plugin.cloudwatch_utils.datetime")
fixed_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
mock_datetime.now.return_value = fixed_time

metrics_publisher.put_metric(
metric_name=metric_name,
value=value,
unit=unit,
additional_dimensions=additional_dimensions,
)

mock_client.put_metric_data.assert_called_once_with(
Namespace=METRICS_NAMESPACE,
MetricData=[
{
"MetricName": metric_name,
"Dimensions": expected_dimensions,
"Timestamp": fixed_time,
"Value": value,
"Unit": unit,
}
],
)

def test_put_metric_handles_exception(self, metrics_publisher, caplog):
"""Test that put_metric handles exceptions gracefully."""
mock_client = MagicMock()
mock_client.put_metric_data.side_effect = ClientError(
{"Error": {"Code": "WHATEVER_CODE", "Message": "WHATEVER_MESSAGE"}},
"PutMetricData",
)
metrics_publisher._cloudwatch_client = mock_client

with caplog.at_level(logging.WARNING):
# Should not raise exception
metrics_publisher.put_metric(metric_name="WHATEVER_METRIC_NAME", value=1)

assert_that(caplog.text).matches(
r"Failed to publish metric WHATEVER_METRIC_NAME:.*WHATEVER_CODE.*WHATEVER_MESSAGE"
)
10 changes: 10 additions & 0 deletions tests/slurm_plugin/test_clustermgtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,14 @@ def test_manage_cluster(
fleet_config={},
)
mocker.patch("time.sleep")
cloudwatch_metrics_publisher_mock = mocker.patch("slurm_plugin.clustermgtd.CloudWatchMetricsPublisher")
cluster_manager = ClusterManager(mock_sync_config)
cloudwatch_metrics_publisher_mock.assert_called_once_with(
region="us-east-2",
cluster_name="hit-test",
instance_id="i-instance-id",
boto3_config=mock_sync_config.boto3_config,
)
cluster_manager._current_time = "current_time"
cluster_manager._static_nodes_in_replacement = {}
# Set up function mocks
Expand All @@ -1670,6 +1677,7 @@ def test_manage_cluster(
get_ec2_instances_mock = mocker.patch.object(
ClusterManager, "_get_ec2_instances", autospec=True, return_value=mock_cluster_instances
)
metrics_publisher_mock = cloudwatch_metrics_publisher_mock.return_value
get_node_info_with_retry_mock = mocker.patch.object(
ClusterManager,
"_get_node_info_with_retry",
Expand All @@ -1683,6 +1691,7 @@ def test_manage_cluster(
# Assert function calls
initialize_instance_manager_mock.assert_called_once()
write_timestamp_to_file_mock.assert_called_once()
metrics_publisher_mock.put_metric.assert_called_once_with(metric_name="ClustermgtdHeartbeat", value=1)
compute_fleet_status_manager_mock.get_status.assert_called_once()
if disable_cluster_management:
perform_health_check_actions_mock.assert_not_called()
Expand Down Expand Up @@ -2255,6 +2264,7 @@ def test_manage_cluster_boto3(
boto3_stubber("ec2", mocked_boto3_request)
mocker.patch("slurm_plugin.clustermgtd.datetime").now.return_value = datetime(2020, 1, 2, 0, 0, 0)
mocker.patch("slurm_plugin.clustermgtd.read_json", side_effect=[FLEET_CONFIG, LAUNCH_OVERRIDES, LAUNCH_OVERRIDES])
mocker.patch("slurm_plugin.clustermgtd.CloudWatchMetricsPublisher")
sync_config = ClustermgtdConfig(test_datadir / config_file)
sync_config.launch_overrides = {"dynamic": {"c5.xlarge": {"InstanceType": "t2.micro"}}}
cluster_manager = ClusterManager(sync_config)
Expand Down