From 0c670c5b5d8db477a65ec73e29f6d89be2e47f20 Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sun, 28 Jun 2026 15:36:19 +0200 Subject: [PATCH] Fix docker scheduler state aggregation --- .../run/torchx_backend/schedulers/docker.py | 11 ++-- .../torchx_backend/schedulers/test_docker.py | 51 ++++++++++++------- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/nemo_run/run/torchx_backend/schedulers/docker.py b/nemo_run/run/torchx_backend/schedulers/docker.py index 4f68920c..1ce71abb 100644 --- a/nemo_run/run/torchx_backend/schedulers/docker.py +++ b/nemo_run/run/torchx_backend/schedulers/docker.py @@ -189,12 +189,11 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]: states.append(state) state = AppState.UNKNOWN - if any(is_terminal(state) for state in states): - if any(state == AppState.SUCCEEDED for state in states): - state = AppState.SUCCEEDED - else: - state = AppState.FAILED - elif len(states) > 0: + if any(state == AppState.FAILED for state in states): + state = AppState.FAILED + elif len(states) == len(req.containers) and all(state == AppState.SUCCEEDED for state in states): + state = AppState.SUCCEEDED + elif any(not is_terminal(state) for state in states): state = next(state for state in states if not is_terminal(state)) return DescribeAppResponse( diff --git a/test/run/torchx_backend/schedulers/test_docker.py b/test/run/torchx_backend/schedulers/test_docker.py index 551d8a60..329dac35 100644 --- a/test/run/torchx_backend/schedulers/test_docker.py +++ b/test/run/torchx_backend/schedulers/test_docker.py @@ -187,37 +187,54 @@ def test_describe_failed(docker_scheduler, docker_executor): assert len(response.roles) == 1 -@pytest.mark.xfail -def test_describe_failure_not_detected(docker_scheduler, docker_executor): +@pytest.mark.parametrize( + ("container_states", "expected_state"), + [ + ([AppState.SUCCEEDED, AppState.FAILED], AppState.FAILED), + ([AppState.SUCCEEDED, AppState.RUNNING], AppState.RUNNING), + ([AppState.SUCCEEDED, AppState.SUCCEEDED], AppState.SUCCEEDED), + ], +) +def test_describe_aggregates_container_states( + docker_scheduler, docker_executor, container_states, expected_state +): with ( mock.patch.object(DockerJobRequest, "load") as mock_load, mock.patch.object(DockerContainer, "get_container") as mock_get_container, mock.patch.object(PersistentDockerScheduler, "_get_app_state") as mock_get_app_state, + mock.patch.object( + PersistentDockerScheduler, "_docker_client", new_callable=mock.PropertyMock + ) as mock_docker_client, ): - container = DockerContainer( - name="test_role", - command=["test"], - executor=docker_executor, - extra_env={}, - ) + mock_docker_client.return_value = mock.Mock() + containers = [ + DockerContainer( + name="test_role", + command=["test"], + executor=docker_executor, + extra_env={}, + ), + DockerContainer( + name="test_role_2", + command=["test"], + executor=docker_executor, + extra_env={}, + ), + ] req = DockerJobRequest( id="test_session___test_role___test_container_id", executor=docker_executor, - containers=[container], + containers=containers, ) mock_load.return_value = req - mock_get_container.return_value = container - mock_get_app_state.return_value = None - status_file = os.path.join(req.executor.job_dir, f"status_{req.containers[0].name}.out") - - with open(status_file, "w") as f: - f.write(json.dumps({"exit_code": 1})) + mock_get_container.side_effect = containers + mock_get_app_state.side_effect = container_states response = docker_scheduler.describe(req.id) assert response is not None assert response.app_id == req.id - assert "SUCCEEDED" in str(response.state) - assert len(response.roles) == 1 + assert response.state == expected_state + assert len(response.roles) == 2 def test_save_and_get_job_dirs():