diff --git a/CHANGELOG.md b/CHANGELOG.md index 54899e0b..d9d22a69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This file is used to list changes made in each version of the aws-parallelcluste **CHANGES** - Direct users to slurm_resume log to see EC2 error codes if no instances are launched. +- Emit metric `ClustermgtdHeartbeat` to signal clustermgtd heartbeat. 3.14.1 ------ diff --git a/src/slurm_plugin/cloudwatch_utils.py b/src/slurm_plugin/cloudwatch_utils.py new file mode 100644 index 00000000..80041b4d --- /dev/null +++ b/src/slurm_plugin/cloudwatch_utils.py @@ -0,0 +1,93 @@ +# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from datetime import datetime, timezone +from typing import Dict, List, Optional + +import boto3 +from botocore.config import Config + +logger = logging.getLogger(__name__) + +METRICS_NAMESPACE = "ParallelCluster" +METRICS_DIMENSION_CLUSTER_NAME = "ClusterName" +METRICS_DIMENSION_INSTANCE_ID = "InstanceId" + + +class CloudWatchMetricsPublisher: + """Class for publishing metrics to CloudWatch.""" + + def __init__(self, region: str, cluster_name: str, instance_id: str, boto3_config: Config): + """ + Initialize CloudWatchMetricsPublisher. + + Args: + region: AWS region + cluster_name: Name of the ParallelCluster cluster + instance_id: EC2 instance ID to include in metric dimensions + boto3_config: Boto3 configuration for retries and proxies + """ + self._region = region + self._cluster_name = cluster_name + self._instance_id = instance_id + self._boto3_config = boto3_config + self._cloudwatch_client = None + + @property + def cloudwatch_client(self): + """Lazy initialization of CloudWatch client.""" + if self._cloudwatch_client is None: + self._cloudwatch_client = boto3.client("cloudwatch", region_name=self._region, config=self._boto3_config) + return self._cloudwatch_client + + def put_metric( + self, + metric_name: str, + value: float, + unit: str = "Count", + additional_dimensions: Optional[List[Dict[str, str]]] = None, + ): + """ + Publish a metric to CloudWatch. + + Automatically sets timestamp and includes ClusterName as a dimension. + + Args: + metric_name: Name of the metric to publish + value: Metric value + unit: CloudWatch unit (default: "Count") + additional_dimensions: Optional list of additional dimensions [{"Name": "...", "Value": "..."}] + """ + dimensions = [ + {"Name": METRICS_DIMENSION_CLUSTER_NAME, "Value": self._cluster_name}, + {"Name": METRICS_DIMENSION_INSTANCE_ID, "Value": self._instance_id}, + ] + if additional_dimensions: + dimensions.extend(additional_dimensions) + + try: + self.cloudwatch_client.put_metric_data( + Namespace=METRICS_NAMESPACE, + MetricData=[ + { + "MetricName": metric_name, + "Dimensions": dimensions, + "Timestamp": datetime.now(tz=timezone.utc), + "Value": value, + "Unit": unit, + } + ], + ) + logger.debug("Published metric %s with value %s", metric_name, value) + except Exception as e: + logger.error("Failed to publish metric %s: %s", metric_name, e) diff --git a/src/slurm_plugin/clustermgtd.py b/src/slurm_plugin/clustermgtd.py index e9a217bb..5fabdd7e 100644 --- a/src/slurm_plugin/clustermgtd.py +++ b/src/slurm_plugin/clustermgtd.py @@ -41,6 +41,7 @@ from common.utils import check_command_output, read_json, sleep_remaining_loop_time, time_is_up, wait_remaining_time from retrying import retry from slurm_plugin.capacity_block_manager import CapacityBlockManager +from slurm_plugin.cloudwatch_utils import CloudWatchMetricsPublisher from slurm_plugin.cluster_event_publisher import ClusterEventPublisher from slurm_plugin.common import TIMESTAMP_FORMAT, ScalingStrategy, log_exception, print_with_count from slurm_plugin.console_logger import ConsoleLogger @@ -60,6 +61,7 @@ LOOP_TIME = 60 CONSOLE_OUTPUT_WAIT_TIME = 5 * 60 MAXIMUM_TASK_BACKLOG = 100 +CW_METRICS_HEARTBEAT = "ClustermgtdHeartbeat" log = logging.getLogger(__name__) compute_logger = log.getChild("console_output") event_logger = log.getChild("events") @@ -401,6 +403,7 @@ def __init__(self, config): self._event_publisher = None self._partition_nodelist_mapping_instance = None self._capacity_block_manager = None + self._metrics_publisher = None self.set_config(config) def set_config(self, config: ClustermgtdConfig): @@ -426,6 +429,7 @@ def set_config(self, config: ClustermgtdConfig): self._instance_manager = self._initialize_instance_manager(config) self._console_logger = self._initialize_console_logger(config) self._capacity_block_manager = self._initialize_capacity_block_manager(config) + self._metrics_publisher = self._initialize_metrics_publisher(config) def shutdown(self): if self._task_executor: @@ -480,6 +484,16 @@ def _initialize_capacity_block_manager(config): region=config.region, fleet_config=config.fleet_config, boto3_config=config.boto3_config ) + @staticmethod + def _initialize_metrics_publisher(config): + """Initialize CloudWatch metrics publisher.""" + return CloudWatchMetricsPublisher( + region=config.region, + cluster_name=config.cluster_name, + instance_id=config.head_node_instance_id, + boto3_config=config.boto3_config, + ) + def _update_compute_fleet_status(self, status): log.info("Updating compute fleet status from %s to %s", self._compute_fleet_status, status) self._compute_fleet_status_manager.update_status(status) @@ -574,6 +588,9 @@ def manage_cluster(self): # Write clustermgtd heartbeat to file self._write_timestamp_to_file() + # Publish heartbeat metric to CloudWatch + self._metrics_publisher.put_metric(metric_name=CW_METRICS_HEARTBEAT, value=1) + def _write_timestamp_to_file(self): """Write timestamp into shared file so compute nodes can determine if head node is online.""" # Make clustermgtd heartbeat readable to all users diff --git a/tests/slurm_plugin/test_cloudwatch_utils.py b/tests/slurm_plugin/test_cloudwatch_utils.py new file mode 100644 index 00000000..054a449a --- /dev/null +++ b/tests/slurm_plugin/test_cloudwatch_utils.py @@ -0,0 +1,175 @@ +# Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with +# the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from datetime import datetime, timezone +from unittest.mock import MagicMock + +import pytest +from assertpy import assert_that +from botocore.config import Config +from botocore.exceptions import ClientError +from slurm_plugin.cloudwatch_utils import METRICS_NAMESPACE, CloudWatchMetricsPublisher + + +class TestCloudWatchMetricsPublisher: + """Tests for CloudWatchMetricsPublisher class.""" + + @pytest.fixture + def boto3_config(self): + return Config(retries={"max_attempts": 1, "mode": "standard"}) + + @pytest.fixture + def metrics_publisher(self, boto3_config): + return CloudWatchMetricsPublisher( + region="us-east-1", + cluster_name="test-cluster", + instance_id="i-1234567890abcdef0", + boto3_config=boto3_config, + ) + + def test_init(self, metrics_publisher, boto3_config): + """Test CloudWatchMetricsPublisher initialization.""" + assert_that(metrics_publisher._region).is_equal_to("us-east-1") + assert_that(metrics_publisher._cluster_name).is_equal_to("test-cluster") + assert_that(metrics_publisher._boto3_config).is_equal_to(boto3_config) + assert_that(metrics_publisher._instance_id).is_equal_to("i-1234567890abcdef0") + assert_that(metrics_publisher._cloudwatch_client).is_none() + + def test_cloudwatch_client_lazy_initialization(self, metrics_publisher, mocker): + """Test that CloudWatch client is lazily initialized.""" + mock_client = MagicMock() + mock_boto3 = mocker.patch("slurm_plugin.cloudwatch_utils.boto3") + mock_boto3.client.return_value = mock_client + + # First access should create the client + client = metrics_publisher.cloudwatch_client + assert_that(client).is_equal_to(mock_client) + mock_boto3.client.assert_called_once_with( + "cloudwatch", + region_name="us-east-1", + config=metrics_publisher._boto3_config, + ) + + # Second access should return the cached client + mock_boto3.client.reset_mock() + client2 = metrics_publisher.cloudwatch_client + assert_that(client2).is_equal_to(mock_client) + mock_boto3.client.assert_not_called() + + @pytest.mark.parametrize( + "metric_name, value, unit, additional_dimensions, expected_dimensions", + [ + pytest.param( + "TestMetric", + 42, + "Count", + None, + [ + {"Name": "ClusterName", "Value": "test-cluster"}, + {"Name": "InstanceId", "Value": "i-1234567890abcdef0"}, + ], + id="basic", + ), + pytest.param( + "HeadNodeDaemonHeartbeat", + 1, + "Count", + [{"Name": "DaemonName", "Value": "clustermgtd"}], + [ + {"Name": "ClusterName", "Value": "test-cluster"}, + {"Name": "InstanceId", "Value": "i-1234567890abcdef0"}, + {"Name": "DaemonName", "Value": "clustermgtd"}, + ], + id="with_additional_dimension", + ), + pytest.param( + "LatencyMetric", + 150.5, + "Milliseconds", + None, + [ + {"Name": "ClusterName", "Value": "test-cluster"}, + {"Name": "InstanceId", "Value": "i-1234567890abcdef0"}, + ], + id="with_custom_unit", + ), + pytest.param( + "CustomMetric", + 100, + "Count", + [ + {"Name": "DaemonName", "Value": "clustermgtd"}, + {"Name": "NodeType", "Value": "HeadNode"}, + ], + [ + {"Name": "ClusterName", "Value": "test-cluster"}, + {"Name": "InstanceId", "Value": "i-1234567890abcdef0"}, + {"Name": "DaemonName", "Value": "clustermgtd"}, + {"Name": "NodeType", "Value": "HeadNode"}, + ], + id="with_multiple_additional_dimensions", + ), + ], + ) + def test_put_metric( + self, + metrics_publisher, + mocker, + metric_name: str, + value: float, + unit: str, + additional_dimensions: list, + expected_dimensions: list, + ): + """Test put_metric with various parameter combinations.""" + mock_client = MagicMock() + metrics_publisher._cloudwatch_client = mock_client + mock_datetime = mocker.patch("slurm_plugin.cloudwatch_utils.datetime") + fixed_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + mock_datetime.now.return_value = fixed_time + + metrics_publisher.put_metric( + metric_name=metric_name, + value=value, + unit=unit, + additional_dimensions=additional_dimensions, + ) + + mock_client.put_metric_data.assert_called_once_with( + Namespace=METRICS_NAMESPACE, + MetricData=[ + { + "MetricName": metric_name, + "Dimensions": expected_dimensions, + "Timestamp": fixed_time, + "Value": value, + "Unit": unit, + } + ], + ) + + def test_put_metric_handles_exception(self, metrics_publisher, caplog): + """Test that put_metric handles exceptions gracefully.""" + mock_client = MagicMock() + mock_client.put_metric_data.side_effect = ClientError( + {"Error": {"Code": "WHATEVER_CODE", "Message": "WHATEVER_MESSAGE"}}, + "PutMetricData", + ) + metrics_publisher._cloudwatch_client = mock_client + + with caplog.at_level(logging.WARNING): + # Should not raise exception + metrics_publisher.put_metric(metric_name="WHATEVER_METRIC_NAME", value=1) + + assert_that(caplog.text).matches( + r"Failed to publish metric WHATEVER_METRIC_NAME:.*WHATEVER_CODE.*WHATEVER_MESSAGE" + ) diff --git a/tests/slurm_plugin/test_clustermgtd.py b/tests/slurm_plugin/test_clustermgtd.py index 36d4783a..3428ec59 100644 --- a/tests/slurm_plugin/test_clustermgtd.py +++ b/tests/slurm_plugin/test_clustermgtd.py @@ -1643,7 +1643,14 @@ def test_manage_cluster( fleet_config={}, ) mocker.patch("time.sleep") + cloudwatch_metrics_publisher_mock = mocker.patch("slurm_plugin.clustermgtd.CloudWatchMetricsPublisher") cluster_manager = ClusterManager(mock_sync_config) + cloudwatch_metrics_publisher_mock.assert_called_once_with( + region="us-east-2", + cluster_name="hit-test", + instance_id="i-instance-id", + boto3_config=mock_sync_config.boto3_config, + ) cluster_manager._current_time = "current_time" cluster_manager._static_nodes_in_replacement = {} # Set up function mocks @@ -1670,6 +1677,7 @@ def test_manage_cluster( get_ec2_instances_mock = mocker.patch.object( ClusterManager, "_get_ec2_instances", autospec=True, return_value=mock_cluster_instances ) + metrics_publisher_mock = cloudwatch_metrics_publisher_mock.return_value get_node_info_with_retry_mock = mocker.patch.object( ClusterManager, "_get_node_info_with_retry", @@ -1683,6 +1691,7 @@ def test_manage_cluster( # Assert function calls initialize_instance_manager_mock.assert_called_once() write_timestamp_to_file_mock.assert_called_once() + metrics_publisher_mock.put_metric.assert_called_once_with(metric_name="ClustermgtdHeartbeat", value=1) compute_fleet_status_manager_mock.get_status.assert_called_once() if disable_cluster_management: perform_health_check_actions_mock.assert_not_called() @@ -2255,6 +2264,7 @@ def test_manage_cluster_boto3( boto3_stubber("ec2", mocked_boto3_request) mocker.patch("slurm_plugin.clustermgtd.datetime").now.return_value = datetime(2020, 1, 2, 0, 0, 0) mocker.patch("slurm_plugin.clustermgtd.read_json", side_effect=[FLEET_CONFIG, LAUNCH_OVERRIDES, LAUNCH_OVERRIDES]) + mocker.patch("slurm_plugin.clustermgtd.CloudWatchMetricsPublisher") sync_config = ClustermgtdConfig(test_datadir / config_file) sync_config.launch_overrides = {"dynamic": {"c5.xlarge": {"InstanceType": "t2.micro"}}} cluster_manager = ClusterManager(sync_config)