diff --git a/openwisp_controller/connection/tasks.py b/openwisp_controller/connection/tasks.py index d75bbde20..56e6a7bb1 100644 --- a/openwisp_controller/connection/tasks.py +++ b/openwisp_controller/connection/tasks.py @@ -16,20 +16,25 @@ _TASK_NAME = "openwisp_controller.connection.tasks.update_config" -def _is_update_in_progress(device_id): +def _is_update_in_progress(device_id, current_task_id=None): active = current_app.control.inspect().active() if not active: return False # check if there's any other running task before adding it + # exclude the current task by comparing task IDs for task_list in active.values(): for task in task_list: - if task["name"] == _TASK_NAME and str(device_id) in task["args"]: + if ( + task["name"] == _TASK_NAME + and str(device_id) in task["args"] + and task["id"] != current_task_id + ): return True return False -@shared_task -def update_config(device_id): +@shared_task(bind=True) +def update_config(self, device_id): """ Launches the ``update_config()`` operation of a specific device in the background @@ -48,7 +53,7 @@ def update_config(device_id): except ObjectDoesNotExist as e: logger.warning(f'update_config("{device_id}") failed: {e}') return - if _is_update_in_progress(device_id): + if _is_update_in_progress(device_id, current_task_id=self.request.id): return try: device_conn = DeviceConnection.get_working_connection(device) diff --git a/openwisp_controller/connection/tests/test_models.py b/openwisp_controller/connection/tests/test_models.py index 14693dfbd..4766a3f7e 100644 --- a/openwisp_controller/connection/tests/test_models.py +++ b/openwisp_controller/connection/tests/test_models.py @@ -1,6 +1,7 @@ import socket from unittest import mock from unittest.mock import PropertyMock +from uuid import uuid4 import paramiko from django.contrib.auth.models import ContentType @@ -1026,20 +1027,56 @@ def _assert_applying_conf_test_command(mocked_exec): @mock.patch.object(DeviceConnection, "update_config") @mock.patch.object(DeviceConnection, "get_working_connection") def test_device_update_config_in_progress( - self, mocked_get_working_connection, update_config, mocked_sleep + self, mocked_get_working_connection, mocked_update_config, mocked_sleep ): conf = self._prepare_conf_object() - with mock.patch("celery.app.control.Inspect.active") as mocked_active: - mocked_active.return_value = { - "task": [{"name": _TASK_NAME, "args": [str(conf.device.pk)]}] - } - conf.config = {"general": {"timezone": "UTC"}} - conf.full_clean() - conf.save() - mocked_active.assert_called_once() - mocked_get_working_connection.assert_not_called() - update_config.assert_not_called() + with self.subTest("More than one update_config task active for the device"): + with mock.patch("celery.app.control.Inspect.active") as mocked_active: + mocked_active.return_value = { + "task": [ + { + "name": _TASK_NAME, + "args": [str(conf.device.pk)], + "id": str(uuid4()), + } + ] + } + conf.config = {"general": {"timezone": "UTC"}} + conf.full_clean() + conf.save() + mocked_active.assert_called_once() + mocked_get_working_connection.assert_not_called() + mocked_update_config.assert_not_called() + + Config.objects.update(status="applied") + mocked_get_working_connection.return_value = ( + conf.device.deviceconnection_set.first() + ) + with self.subTest("Only one task is active for the device"): + task_id = str(uuid4()) + with mock.patch( + "celery.app.control.Inspect.active" + ) as mocked_active, mock.patch( + "celery.app.task.Context.id", + new_callable=mock.PropertyMock, + return_value=task_id, + ): + mocked_active.return_value = { + "task": [ + { + "name": _TASK_NAME, + "args": [str(conf.device.pk)], + "id": task_id, + } + ] + } + conf.config = {"general": {"timezone": "Asia/Kolkata"}} + conf.full_clean() + conf.save() + mocked_active.assert_called_once() + mocked_get_working_connection.assert_called_once() + mocked_update_config.assert_called_once() @mock.patch("time.sleep") @mock.patch.object(DeviceConnection, "update_config") @@ -1053,8 +1090,15 @@ def test_device_update_config_not_in_progress( ) with mock.patch("celery.app.control.Inspect.active") as mocked_active: + # Mock a task running for a different device (args is different) mocked_active.return_value = { - "task": [{"name": _TASK_NAME, "args": ["..."]}] + "task": [ + { + "name": _TASK_NAME, + "args": ["another-device-id"], # Different device + "id": "different-task-id", + } + ] } conf.config = {"general": {"timezone": "UTC"}} conf.full_clean() diff --git a/openwisp_controller/connection/tests/test_tasks.py b/openwisp_controller/connection/tests/test_tasks.py index 700dbbf32..8ce18a694 100644 --- a/openwisp_controller/connection/tests/test_tasks.py +++ b/openwisp_controller/connection/tests/test_tasks.py @@ -21,6 +21,56 @@ class TestTasks(CreateConnectionsMixin, TestCase): "openwisp_controller.connection.base.models.AbstractDeviceConnection.connect" ) + def _get_mocked_celery_active(self, device_id, task_id=None): + return { + "worker1": [ + { + "name": tasks._TASK_NAME, + "args": [device_id], + "id": task_id or str(uuid.uuid4()), + } + ] + } + + def test_is_update_in_progress_same_task(self): + device_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + with mock.patch( + "celery.app.control.Inspect.active", + return_value=self._get_mocked_celery_active(device_id, task_id), + ): + result = tasks._is_update_in_progress(device_id, current_task_id=task_id) + self.assertEqual(result, False) + + def test_is_update_in_progress_different_task(self): + device_id = str(uuid.uuid4()) + current_task_id = str(uuid.uuid4()) + other_task_id = str(uuid.uuid4()) + with mock.patch( + "celery.app.control.Inspect.active", + return_value=self._get_mocked_celery_active(device_id, other_task_id), + ): + result = tasks._is_update_in_progress( + device_id, current_task_id=current_task_id + ) + self.assertEqual(result, True) + + def test_is_update_in_progress_no_tasks(self): + device_id = str(uuid.uuid4()) + with mock.patch("celery.app.control.Inspect.active", return_value={}): + result = tasks._is_update_in_progress(device_id) + self.assertEqual(result, False) + + def test_is_update_in_progress_different_device(self): + device_id = str(uuid.uuid4()) + other_device_id = str(uuid.uuid4()) + with mock.patch( + "celery.app.control.Inspect.active", + return_value=self._get_mocked_celery_active(other_device_id), + ): + result = tasks._is_update_in_progress(device_id) + self.assertEqual(result, False) + @mock.patch("logging.Logger.warning") @mock.patch("time.sleep") def test_update_config_missing_config(self, mocked_sleep, mocked_warning):