diff --git a/openwisp_controller/config/tests/test_config.py b/openwisp_controller/config/tests/test_config.py
index 10bedd043..99d73b798 100644
--- a/openwisp_controller/config/tests/test_config.py
+++ b/openwisp_controller/config/tests/test_config.py
@@ -984,7 +984,8 @@ def test_certificate_renew_invalidates_checksum_cache(self):
             vpnclient_cert.renew()
             # An additional call from cache invalidation of
             # DeviceGroupCommonName View
-            self.assertEqual(mocked_delete.call_count, 3)
+            # +1 call from _release_update_config_lock releasing the per-device lock
+            self.assertEqual(mocked_delete.call_count, 4)
             del config.backend_instance
             self.assertNotEqual(config.get_cached_checksum(), old_checksum)
             config.refresh_from_db()
diff --git a/openwisp_controller/connection/tasks.py b/openwisp_controller/connection/tasks.py
index d75bbde20..73fcdcc6a 100644
--- a/openwisp_controller/connection/tasks.py
+++ b/openwisp_controller/connection/tasks.py
@@ -1,9 +1,11 @@
 import logging
 import time
+import uuid
 
 import swapper
-from celery import current_app, shared_task
+from celery import shared_task
 from celery.exceptions import SoftTimeLimitExceeded
+from django.core.cache import cache
 from django.core.exceptions import ObjectDoesNotExist
 from django.utils.translation import gettext_lazy as _
 from swapper import load_model
@@ -13,19 +15,36 @@ from .exceptions import NoWorkingDeviceConnectionError
 
 logger = logging.getLogger(__name__)
 
-_TASK_NAME = "openwisp_controller.connection.tasks.update_config"
+_UPDATE_CONFIG_LOCK_KEY = "ow_update_config_{device_id}"
+# Lock timeout (in seconds) acts as a safety net to release the lock
+# in case the task crashes without proper cleanup.
+_UPDATE_CONFIG_LOCK_TIMEOUT = 300
 
 
-def _is_update_in_progress(device_id):
-    active = current_app.control.inspect().active()
-    if not active:
-        return False
-    # check if there's any other running task before adding it
-    for task_list in active.values():
-        for task in task_list:
-            if task["name"] == _TASK_NAME and str(device_id) in task["args"]:
-                return True
-    return False
+def _acquire_update_config_lock(device_id):
+    """
+    Attempts to atomically acquire a per-device lock using the Django cache.
+    Returns a unique token string if the lock was acquired, None otherwise.
+    The token must be passed to _release_update_config_lock to ensure
+    only the lock owner can release it.
+    """
+    lock_key = _UPDATE_CONFIG_LOCK_KEY.format(device_id=device_id)
+    token = str(uuid.uuid4())
+    # cache.add is atomic: returns True only if the key doesn't already exist
+    if cache.add(lock_key, token, timeout=_UPDATE_CONFIG_LOCK_TIMEOUT):
+        return token
+    return None
+
+
+def _release_update_config_lock(device_id, token):
+    """
+    Releases the per-device update_config lock only if the caller
+    owns it (i.e. the stored token matches).
+    """
+    lock_key = _UPDATE_CONFIG_LOCK_KEY.format(device_id=device_id)
+    stored_token = cache.get(lock_key)
+    if stored_token == token:
+        cache.delete(lock_key)
 
 
 @shared_task
@@ -48,15 +67,26 @@ def update_config(device_id):
     except ObjectDoesNotExist as e:
         logger.warning(f'update_config("{device_id}") failed: {e}')
         return
-    if _is_update_in_progress(device_id):
+    lock_token = _acquire_update_config_lock(device_id)
+    if not lock_token:
+        logger.info(
+            f"update_config for device {device_id} is already in progress, skipping"
+        )
         return
     try:
-        device_conn = DeviceConnection.get_working_connection(device)
-    except NoWorkingDeviceConnectionError:
-        return
-    else:
-        logger.info(f"Updating {device} (pk: {device_id})")
-        device_conn.update_config()
+        try:
+            device_conn = DeviceConnection.get_working_connection(device)
+        except NoWorkingDeviceConnectionError as e:
+            logger.warning(
+                f"update_config for device {device_id}: "
+                f"DeviceConnection.get_working_connection failed: {e}"
+            )
+            return
+        else:
+            logger.info(f"Updating {device} (pk: {device_id})")
+            device_conn.update_config()
+    finally:
+        _release_update_config_lock(device_id, lock_token)
 
 
 # task timeout is SSH_COMMAND_TIMEOUT plus a 20% margin
diff --git a/openwisp_controller/connection/tests/test_models.py b/openwisp_controller/connection/tests/test_models.py
index 14693dfbd..fccf88247 100644
--- a/openwisp_controller/connection/tests/test_models.py
+++ b/openwisp_controller/connection/tests/test_models.py
@@ -21,7 +21,11 @@
 )
 from ..exceptions import NoWorkingDeviceConnectionError
 from ..signals import is_working_changed
-from ..tasks import _TASK_NAME, update_config
+from ..tasks import (
+    _acquire_update_config_lock,
+    _release_update_config_lock,
+    update_config,
+)
 from .utils import CreateConnectionsMixin
 
 Config = load_model("config", "Config")
@@ -1026,20 +1030,19 @@ def _assert_applying_conf_test_command(mocked_exec):
     @mock.patch.object(DeviceConnection, "update_config")
     @mock.patch.object(DeviceConnection, "get_working_connection")
     def test_device_update_config_in_progress(
-        self, mocked_get_working_connection, update_config, mocked_sleep
+        self, mocked_get_working_connection, mocked_update_config, mocked_sleep
     ):
         conf = self._prepare_conf_object()
 
-        with mock.patch("celery.app.control.Inspect.active") as mocked_active:
-            mocked_active.return_value = {
-                "task": [{"name": _TASK_NAME, "args": [str(conf.device.pk)]}]
-            }
+        with mock.patch(
+            "openwisp_controller.connection.tasks._acquire_update_config_lock",
+            return_value=None,
+        ):
             conf.config = {"general": {"timezone": "UTC"}}
             conf.full_clean()
             conf.save()
-            mocked_active.assert_called_once()
             mocked_get_working_connection.assert_not_called()
-            update_config.assert_not_called()
+            mocked_update_config.assert_not_called()
 
     @mock.patch("time.sleep")
     @mock.patch.object(DeviceConnection, "update_config")
@@ -1052,16 +1055,46 @@ def test_device_update_config_not_in_progress(
             conf.device.deviceconnection_set.first()
         )
 
-        with mock.patch("celery.app.control.Inspect.active") as mocked_active:
-            mocked_active.return_value = {
-                "task": [{"name": _TASK_NAME, "args": ["..."]}]
-            }
+        with mock.patch(
+            "openwisp_controller.connection.tasks._acquire_update_config_lock",
+            return_value="fake-lock-token",
+        ), mock.patch(
+            "openwisp_controller.connection.tasks._release_update_config_lock",
+        ) as mocked_release:
             conf.config = {"general": {"timezone": "UTC"}}
             conf.full_clean()
             conf.save()
-            mocked_active.assert_called_once()
             mocked_get_working_connection.assert_called_once()
             mocked_update_config.assert_called_once()
+            mocked_release.assert_called_once_with(
+                str(conf.device.pk), "fake-lock-token"
+            )
+
+    def test_acquire_update_config_lock(self):
+        """Test that the lock can be acquired and prevents duplicate acquisition."""
+        device_id = "test-device-id"
+        # First acquisition should succeed and return a token
+        token = _acquire_update_config_lock(device_id)
+        self.addCleanup(_release_update_config_lock, device_id, token)
+        self.assertIsNotNone(token)
+        # Second acquisition should fail (lock already held)
+        self.assertIsNone(_acquire_update_config_lock(device_id))
+        # After releasing with correct token, acquisition should succeed again
+        _release_update_config_lock(device_id, token)
+        token2 = _acquire_update_config_lock(device_id)
+        self.addCleanup(_release_update_config_lock, device_id, token2)
+        self.assertIsNotNone(token2)
+
+    def test_release_update_config_lock_wrong_token(self):
+        """Only the lock owner can release the lock."""
+        device_id = "test-device-id"
+        token = _acquire_update_config_lock(device_id)
+        self.addCleanup(_release_update_config_lock, device_id, token)
+        self.assertIsNotNone(token)
+        # Releasing with wrong token should not delete the lock
+        _release_update_config_lock(device_id, "wrong-token")
+        # Lock should still be held
+        self.assertIsNone(_acquire_update_config_lock(device_id))
 
     @mock.patch(_connect_path)
     def test_schedule_command_called(self, connect_mocked):