diff --git a/openwisp_controller/connection/tasks.py b/openwisp_controller/connection/tasks.py index d75bbde20..38a7ffaeb 100644 --- a/openwisp_controller/connection/tasks.py +++ b/openwisp_controller/connection/tasks.py @@ -2,7 +2,7 @@ import time import swapper -from celery import current_app, shared_task +from celery import current_app, current_task, shared_task from celery.exceptions import SoftTimeLimitExceeded from django.core.exceptions import ObjectDoesNotExist from django.utils.translation import gettext_lazy as _ @@ -20,12 +20,16 @@ def _is_update_in_progress(device_id): active = current_app.control.inspect().active() if not active: return False - # check if there's any other running task before adding it - for task_list in active.values(): - for task in task_list: - if task["name"] == _TASK_NAME and str(device_id) in task["args"]: - return True - return False + current_task_id = ( + current_task.request.id if current_task and current_task.request else None + ) + return any( + task["name"] == _TASK_NAME + and str(device_id) in task["args"] + and task.get("id") != current_task_id + for task_list in active.values() + for task in task_list + ) @shared_task diff --git a/openwisp_controller/connection/tests/test_tasks.py b/openwisp_controller/connection/tests/test_tasks.py index 700dbbf32..857beca8d 100644 --- a/openwisp_controller/connection/tests/test_tasks.py +++ b/openwisp_controller/connection/tests/test_tasks.py @@ -89,6 +89,84 @@ def test_launch_command_exception(self, *args): self.assertEqual(command.output, "Internal system error: test error\n") +class TestIsUpdateInProgress(CreateConnectionsMixin, TestCase): + @mock.patch("openwisp_controller.connection.tasks.current_task") + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_same_worker( + self, mocked_current_app, mocked_current_task + ): + device_id = 1 + mocked_current_task.request.id = "task123" + mocked_inspect = mock.Mock() + mocked_current_app.control.inspect.return_value = mocked_inspect + mocked_inspect.active.return_value = { + "worker1": [ + { + "name": "openwisp_controller.connection.tasks.update_config", + "args": ["1"], + "id": "task123", + } + ] + } + result = tasks._is_update_in_progress(device_id) + self.assertFalse(result) + + @mock.patch("openwisp_controller.connection.tasks.current_task") + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_different_worker( + self, mocked_current_app, mocked_current_task + ): + device_id = 1 + mocked_current_task.request.id = "task123" + mocked_inspect = mock.Mock() + mocked_current_app.control.inspect.return_value = mocked_inspect + mocked_inspect.active.return_value = { + "worker2": [ + { + "name": "openwisp_controller.connection.tasks.update_config", + "args": ["1"], + "id": "task456", + } + ] + } + result = tasks._is_update_in_progress(device_id) + self.assertTrue(result) + + @mock.patch("openwisp_controller.connection.tasks.current_task") + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_no_active_tasks( + self, mocked_current_app, mocked_current_task + ): + device_id = 1 + mocked_current_task.request.id = "task123" + mocked_inspect = mock.Mock() + mocked_current_app.control.inspect.return_value = mocked_inspect + mocked_inspect.active.return_value = {} + result = tasks._is_update_in_progress(device_id) + self.assertFalse(result) + + @mock.patch("openwisp_controller.connection.tasks.current_task") + @mock.patch("openwisp_controller.connection.tasks.current_app") + def test_is_update_in_progress_different_device( + self, mocked_current_app, mocked_current_task + ): + device_id = 1 + mocked_current_task.request.id = "task123" + mocked_inspect = mock.Mock() + mocked_current_app.control.inspect.return_value = mocked_inspect + mocked_inspect.active.return_value = { + "worker1": [ + { + "name": "openwisp_controller.connection.tasks.update_config", + "args": ["2"], + "id": "task456", + } + ] + } + result = tasks._is_update_in_progress(device_id) + self.assertFalse(result) + + class TestTransactionTasks( TestRegistrationMixin, CreateConnectionsMixin, TransactionTestCase ):