diff --git a/django/middleware/cache.py b/django/middleware/cache.py index 0fdffe1bbeee..e11bb0c63619 100644 --- a/django/middleware/cache.py +++ b/django/middleware/cache.py @@ -43,6 +43,8 @@ """ +from hashlib import md5 +import uuid from django.conf import settings from django.core.cache import DEFAULT_CACHE_ALIAS, caches from django.utils.cache import ( @@ -71,6 +73,7 @@ def __init__(self, get_response): self.page_timeout = None self.key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX self.cache_alias = settings.CACHE_MIDDLEWARE_ALIAS + self.key_group = None @property def cache(self): @@ -115,8 +118,11 @@ def process_response(self, request, response): return response patch_response_headers(response, timeout) if timeout and response.status_code == 200: + + key_group = handle_key_group(self.key_group, self.cache) + cache_key = learn_cache_key( - request, response, timeout, self.key_prefix, cache=self.cache + request, response, timeout, self.key_prefix, key_group, cache=self.cache, ) if hasattr(response, "render") and callable(response.render): response.add_post_render_callback( @@ -140,6 +146,7 @@ def __init__(self, get_response): super().__init__(get_response) self.key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX self.cache_alias = settings.CACHE_MIDDLEWARE_ALIAS + self.key_group = None @property def cache(self): @@ -148,14 +155,16 @@ def cache(self): def process_request(self, request): """ Check whether the page is already cached and return the cached - version if available. + version if available, otherwise returns None. """ if request.method not in ("GET", "HEAD"): request._cache_update_cache = False return None # Don't bother checking the cache. # try and get the cached GET response - cache_key = get_cache_key(request, self.key_prefix, "GET", cache=self.cache) + key_group = handle_key_group(self.key_group, self.cache) + + cache_key = get_cache_key(request, key_group, self.key_prefix, "GET", cache=self.cache) if cache_key is None: request._cache_update_cache = True return None # No cache information available, need to rebuild. @@ -176,6 +185,18 @@ def process_request(self, request): return response +def handle_key_group(key_group, cache): + group_cache_header = "decorators.group_caching" + group_key = f"{group_cache_header}.{key_group}" + result = cache.get(group_key) + if not result: + hash = md5(f"{group_key}.{uuid.uuid4()}".encode("ascii"), usedforsecurity=False).hexdigest() + cache.set(group_key, hash) + result = hash + + return result + + class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware): """ Cache middleware that provides basic behavior for many simple sites. @@ -205,6 +226,13 @@ def __init__(self, get_response, cache_timeout=None, page_timeout=None, **kwargs self.cache_alias = cache_alias except KeyError: pass + try: + key_group = kwargs["key_group"] + if key_group is None: + key_group = "" + self.key_group = key_group + except KeyError: + pass if cache_timeout is not None: self.cache_timeout = cache_timeout diff --git a/django/utils/cache.py b/django/utils/cache.py index cf797d0279a4..8184f0f60c71 100644 --- a/django/utils/cache.py +++ b/django/utils/cache.py @@ -29,6 +29,8 @@ cc_delim_re = _lazy_re_compile(r"\s*,\s*") +PATH_PREFIX = "views.decorators.cache" + def patch_cache_control(response, **kwargs): """ @@ -346,34 +348,40 @@ def _i18n_cache_key_suffix(request, cache_key): return cache_key -def _generate_cache_key(request, method, headerlist, key_prefix): - """Return a cache key from the headers given in the header list.""" +def _generate_cache_key(request, method, headerlist, key_prefix, key_group: str): + """ + Return a cache key from the headers given in the header list. + """ + cache_type = "cache_page" + ctx = md5(usedforsecurity=False) for header in headerlist: value = request.META.get(header) if value is not None: ctx.update(value.encode()) url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False) - cache_key = "views.decorators.cache.cache_page.%s.%s.%s.%s" % ( - key_prefix, - method, - url.hexdigest(), - ctx.hexdigest(), - ) + + url_hex = url.hexdigest() + ctx_hex = ctx.hexdigest() + cache_key = f"{PATH_PREFIX}.{cache_type}.{key_group}.{key_prefix}.{method}.{url_hex}.{ctx_hex}" + return _i18n_cache_key_suffix(request, cache_key) -def _generate_cache_header_key(key_prefix, request): - """Return a cache key for the header cache.""" +def _generate_cache_header_key(key_prefix: str, key_group: str, request) -> str: + """ + Return a cache key for the header cache. + """ + cache_type = "cache_header" + url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False) - cache_key = "views.decorators.cache.cache_header.%s.%s" % ( - key_prefix, - url.hexdigest(), - ) + url_hex = url.hexdigest() + cache_key = f"{PATH_PREFIX}.{cache_type}.{key_group}.{key_prefix}.{url_hex}" + return _i18n_cache_key_suffix(request, cache_key) -def get_cache_key(request, key_prefix=None, method="GET", cache=None): +def get_cache_key(request, key_group, key_prefix=None, method="GET", cache=None): """ Return a cache key based on the request URL and query. It can be used in the request phase because it pulls the list of headers to take into @@ -385,18 +393,27 @@ def get_cache_key(request, key_prefix=None, method="GET", cache=None): """ if key_prefix is None: key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX - cache_key = _generate_cache_header_key(key_prefix, request) + cache_key = _generate_cache_header_key(key_prefix, key_group, request) if cache is None: cache = caches[settings.CACHE_MIDDLEWARE_ALIAS] headerlist = cache.get(cache_key) if headerlist is not None: - return _generate_cache_key(request, method, headerlist, key_prefix) + return _generate_cache_key(request, method, headerlist, key_prefix, key_group) else: return None -def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cache=None): +def learn_cache_key( + request, response, cache_timeout=None, key_prefix=None, key_group=None, cache=None +): """ + :param request: The request object. + :param response: The response object. + :param cache_timeout: The cache timeout in seconds. + :param key_prefix: The cache key prefix. + :param key_group: The cache key group. + :param cache: The cache object. + Learn what headers to take into account for some request URL from the response object. Store those headers in a global URL registry so that later access to that URL will know what headers to take into account @@ -412,9 +429,13 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX if cache_timeout is None: cache_timeout = settings.CACHE_MIDDLEWARE_SECONDS - cache_key = _generate_cache_header_key(key_prefix, request) + + # Generate the cache key for the header cache. + cache_key = _generate_cache_header_key(key_prefix, key_group, request) if cache is None: cache = caches[settings.CACHE_MIDDLEWARE_ALIAS] + + # Handle the Vary header. if response.has_header("Vary"): is_accept_language_redundant = settings.USE_I18N # If i18n is used, the generated cache key will be suffixed with the @@ -428,12 +449,14 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach headerlist.append("HTTP_" + header) headerlist.sort() cache.set(cache_key, headerlist, cache_timeout) - return _generate_cache_key(request, request.method, headerlist, key_prefix) + return _generate_cache_key( + request, request.method, headerlist, key_prefix, key_group + ) else: # if there is no Vary header, we still need a cache key # for the request.build_absolute_uri() cache.set(cache_key, [], cache_timeout) - return _generate_cache_key(request, request.method, [], key_prefix) + return _generate_cache_key(request, request.method, [], key_prefix, key_group) def _to_tuple(s): @@ -441,3 +464,15 @@ def _to_tuple(s): if len(t) == 2: return t[0].lower(), t[1] return t[0].lower(), True + + +def invalidate_cache_by_key_group(key_group: str, cache=None) -> None: + """ + Invalidates all cache entries for a given key group. + """ + if cache is None: + cache = caches[settings.CACHE_MIDDLEWARE_ALIAS] + + # Get the group cache key and delete it. + group_cache_key = f"{PATH_PREFIX}.cache_group.{key_group}" + cache.delete(group_cache_key) diff --git a/django/views/decorators/cache.py b/django/views/decorators/cache.py index aa1679baff1c..f59b8c35c29f 100644 --- a/django/views/decorators/cache.py +++ b/django/views/decorators/cache.py @@ -7,7 +7,7 @@ from django.utils.decorators import decorator_from_middleware_with_args -def cache_page(timeout, *, cache=None, key_prefix=None): +def cache_page(timeout, *, cache=None, key_prefix=None, key_group=None): """ Decorator for views that tries getting the page from the cache and populates the cache if the page isn't in the cache yet. @@ -25,6 +25,7 @@ def cache_page(timeout, *, cache=None, key_prefix=None): page_timeout=timeout, cache_alias=cache, key_prefix=key_prefix, + key_group=key_group, ) diff --git a/tests/cache/tests.py b/tests/cache/tests.py index fcce9579d48c..f1bad69e6684 100644 --- a/tests/cache/tests.py +++ b/tests/cache/tests.py @@ -56,6 +56,7 @@ from django.utils import timezone, translation from django.utils.cache import ( get_cache_key, + invalidate_cache_by_key_group, learn_cache_key, patch_cache_control, patch_vary_headers, @@ -2578,6 +2579,32 @@ def test_middleware(self): self.assertIsNotNone(result) self.assertEqual(result.content, b"Hello World 1") + def test_middleware_group(self): + group_name = "group1" + middleware_group1 = CacheMiddleware(hello_world_view, key_group=group_name) + + request = self.factory.get("/view/") + + # Put the request through the request middleware + result = middleware_group1.process_request(request) + self.assertIsNone(result) # Is None because there's no cache + + response = hello_world_view(request, "1") + + # Now put the response through the response middleware + response = middleware_group1.process_response(request, response) + + # Repeating the request should result in a cache hit + result = middleware_group1.process_request(request) + self.assertIsNotNone(result) + + # Invalidating the whole group + invalidate_key_group(group_name, middleware_group1.cache) + + # Repeating the request should result in a cache non hit + result = middleware_group1.process_request(request) + self.assertIsNone(result) + def test_view_decorator(self): # decorate the same view with different cache decorators default_view = cache_page(3)(hello_world_view) @@ -2938,3 +2965,32 @@ def test_all(self): # .all() initializes all caches. self.assertEqual(len(test_caches.all(initialized_only=True)), 2) self.assertEqual(test_caches.all(), test_caches.all(initialized_only=True)) + + +class TestCacheInvalidation(SimpleTestCase): + def test_view_cache_invalidation(self): + """ + Test that the view cache is correctly invalidated. + """ + # Set up the view. Decorate it with the cache_page decorator. + decorated_view = cache_page(timeout=3, key_group="test-key-group")( + hello_world_view + ) + + # Set up the request. + request = self.factory.get("/view/") + + # Request the view for the first time. + response = decorated_view(request, "1") + self.assertEqual(response.content, b"Hello World 1") + + # Request again -- hit the cache. + response = decorated_view(request, "2") + self.assertEqual(response.content, b"Hello World 1") + + # Run the cache invalidation. + invalidate_cache_by_key_group(key_group="test-key-group") + + # Request again -- should NOT hit the cache. + response = decorated_view(request, "3") + self.assertEqual(response.content, b"Hello World 3")