Skip to content
34 changes: 31 additions & 3 deletions django/middleware/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

"""

from hashlib import md5
import uuid
from django.conf import settings
from django.core.cache import DEFAULT_CACHE_ALIAS, caches
from django.utils.cache import (
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(self, get_response):
self.page_timeout = None
self.key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX
self.cache_alias = settings.CACHE_MIDDLEWARE_ALIAS
self.key_group = None

@property
def cache(self):
Expand Down Expand Up @@ -115,8 +118,11 @@ def process_response(self, request, response):
return response
patch_response_headers(response, timeout)
if timeout and response.status_code == 200:

key_group = handle_key_group(self.key_group, self.cache)

cache_key = learn_cache_key(
request, response, timeout, self.key_prefix, cache=self.cache
request, response, timeout, self.key_prefix, key_group, cache=self.cache,
)
if hasattr(response, "render") and callable(response.render):
response.add_post_render_callback(
Expand All @@ -140,6 +146,7 @@ def __init__(self, get_response):
super().__init__(get_response)
self.key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX
self.cache_alias = settings.CACHE_MIDDLEWARE_ALIAS
self.key_group = None

@property
def cache(self):
Expand All @@ -148,14 +155,16 @@ def cache(self):
def process_request(self, request):
"""
Check whether the page is already cached and return the cached
version if available.
version if available, otherwise returns None.
"""
if request.method not in ("GET", "HEAD"):
request._cache_update_cache = False
return None # Don't bother checking the cache.

# try and get the cached GET response
cache_key = get_cache_key(request, self.key_prefix, "GET", cache=self.cache)
key_group = handle_key_group(self.key_group, self.cache)

cache_key = get_cache_key(request, key_group, self.key_prefix, "GET", cache=self.cache)
if cache_key is None:
request._cache_update_cache = True
return None # No cache information available, need to rebuild.
Expand All @@ -176,6 +185,18 @@ def process_request(self, request):
return response


def handle_key_group(key_group, cache):
group_cache_header = "decorators.group_caching"
group_key = f"{group_cache_header}.{key_group}"
result = cache.get(group_key)
if not result:
hash = md5(f"{group_key}.{uuid.uuid4()}".encode("ascii"), usedforsecurity=False).hexdigest()
cache.set(group_key, hash)
result = hash

return result


class CacheMiddleware(UpdateCacheMiddleware, FetchFromCacheMiddleware):
"""
Cache middleware that provides basic behavior for many simple sites.
Expand Down Expand Up @@ -205,6 +226,13 @@ def __init__(self, get_response, cache_timeout=None, page_timeout=None, **kwargs
self.cache_alias = cache_alias
except KeyError:
pass
try:
key_group = kwargs["key_group"]
if key_group is None:
key_group = ""
self.key_group = key_group
except KeyError:
pass

if cache_timeout is not None:
self.cache_timeout = cache_timeout
Expand Down
77 changes: 56 additions & 21 deletions django/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

cc_delim_re = _lazy_re_compile(r"\s*,\s*")

PATH_PREFIX = "views.decorators.cache"


def patch_cache_control(response, **kwargs):
"""
Expand Down Expand Up @@ -346,34 +348,40 @@ def _i18n_cache_key_suffix(request, cache_key):
return cache_key


def _generate_cache_key(request, method, headerlist, key_prefix):
"""Return a cache key from the headers given in the header list."""
def _generate_cache_key(request, method, headerlist, key_prefix, key_group: str):
"""
Return a cache key from the headers given in the header list.
"""
cache_type = "cache_page"

ctx = md5(usedforsecurity=False)
for header in headerlist:
value = request.META.get(header)
if value is not None:
ctx.update(value.encode())
url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False)
cache_key = "views.decorators.cache.cache_page.%s.%s.%s.%s" % (
key_prefix,
method,
url.hexdigest(),
ctx.hexdigest(),
)

url_hex = url.hexdigest()
ctx_hex = ctx.hexdigest()
cache_key = f"{PATH_PREFIX}.{cache_type}.{key_group}.{key_prefix}.{method}.{url_hex}.{ctx_hex}"

return _i18n_cache_key_suffix(request, cache_key)


def _generate_cache_header_key(key_prefix, request):
"""Return a cache key for the header cache."""
def _generate_cache_header_key(key_prefix: str, key_group: str, request) -> str:
"""
Return a cache key for the header cache.
"""
cache_type = "cache_header"

url = md5(request.build_absolute_uri().encode("ascii"), usedforsecurity=False)
cache_key = "views.decorators.cache.cache_header.%s.%s" % (
key_prefix,
url.hexdigest(),
)
url_hex = url.hexdigest()
cache_key = f"{PATH_PREFIX}.{cache_type}.{key_group}.{key_prefix}.{url_hex}"

return _i18n_cache_key_suffix(request, cache_key)


def get_cache_key(request, key_prefix=None, method="GET", cache=None):
def get_cache_key(request, key_group, key_prefix=None, method="GET", cache=None):
"""
Return a cache key based on the request URL and query. It can be used
in the request phase because it pulls the list of headers to take into
Expand All @@ -385,18 +393,27 @@ def get_cache_key(request, key_prefix=None, method="GET", cache=None):
"""
if key_prefix is None:
key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX
cache_key = _generate_cache_header_key(key_prefix, request)
cache_key = _generate_cache_header_key(key_prefix, key_group, request)
if cache is None:
cache = caches[settings.CACHE_MIDDLEWARE_ALIAS]
headerlist = cache.get(cache_key)
if headerlist is not None:
return _generate_cache_key(request, method, headerlist, key_prefix)
return _generate_cache_key(request, method, headerlist, key_prefix, key_group)
else:
return None


def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cache=None):
def learn_cache_key(
request, response, cache_timeout=None, key_prefix=None, key_group=None, cache=None
):
"""
:param request: The request object.
:param response: The response object.
:param cache_timeout: The cache timeout in seconds.
:param key_prefix: The cache key prefix.
:param key_group: The cache key group.
:param cache: The cache object.

Learn what headers to take into account for some request URL from the
response object. Store those headers in a global URL registry so that
later access to that URL will know what headers to take into account
Expand All @@ -412,9 +429,13 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach
key_prefix = settings.CACHE_MIDDLEWARE_KEY_PREFIX
if cache_timeout is None:
cache_timeout = settings.CACHE_MIDDLEWARE_SECONDS
cache_key = _generate_cache_header_key(key_prefix, request)

# Generate the cache key for the header cache.
cache_key = _generate_cache_header_key(key_prefix, key_group, request)
if cache is None:
cache = caches[settings.CACHE_MIDDLEWARE_ALIAS]

# Handle the Vary header.
if response.has_header("Vary"):
is_accept_language_redundant = settings.USE_I18N
# If i18n is used, the generated cache key will be suffixed with the
Expand All @@ -428,16 +449,30 @@ def learn_cache_key(request, response, cache_timeout=None, key_prefix=None, cach
headerlist.append("HTTP_" + header)
headerlist.sort()
cache.set(cache_key, headerlist, cache_timeout)
return _generate_cache_key(request, request.method, headerlist, key_prefix)
return _generate_cache_key(
request, request.method, headerlist, key_prefix, key_group
)
else:
# if there is no Vary header, we still need a cache key
# for the request.build_absolute_uri()
cache.set(cache_key, [], cache_timeout)
return _generate_cache_key(request, request.method, [], key_prefix)
return _generate_cache_key(request, request.method, [], key_prefix, key_group)


def _to_tuple(s):
t = s.split("=", 1)
if len(t) == 2:
return t[0].lower(), t[1]
return t[0].lower(), True


def invalidate_cache_by_key_group(key_group: str, cache=None) -> None:
"""
Invalidates all cache entries for a given key group.
"""
if cache is None:
cache = caches[settings.CACHE_MIDDLEWARE_ALIAS]

# Get the group cache key and delete it.
group_cache_key = f"{PATH_PREFIX}.cache_group.{key_group}"
cache.delete(group_cache_key)
3 changes: 2 additions & 1 deletion django/views/decorators/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.utils.decorators import decorator_from_middleware_with_args


def cache_page(timeout, *, cache=None, key_prefix=None):
def cache_page(timeout, *, cache=None, key_prefix=None, key_group=None):
"""
Decorator for views that tries getting the page from the cache and
populates the cache if the page isn't in the cache yet.
Expand All @@ -25,6 +25,7 @@ def cache_page(timeout, *, cache=None, key_prefix=None):
page_timeout=timeout,
cache_alias=cache,
key_prefix=key_prefix,
key_group=key_group,
)


Expand Down
56 changes: 56 additions & 0 deletions tests/cache/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from django.utils import timezone, translation
from django.utils.cache import (
get_cache_key,
invalidate_cache_by_key_group,
learn_cache_key,
patch_cache_control,
patch_vary_headers,
Expand Down Expand Up @@ -2578,6 +2579,32 @@ def test_middleware(self):
self.assertIsNotNone(result)
self.assertEqual(result.content, b"Hello World 1")

def test_middleware_group(self):
group_name = "group1"
middleware_group1 = CacheMiddleware(hello_world_view, key_group=group_name)

request = self.factory.get("/view/")

# Put the request through the request middleware
result = middleware_group1.process_request(request)
self.assertIsNone(result) # Is None because there's no cache

response = hello_world_view(request, "1")

# Now put the response through the response middleware
response = middleware_group1.process_response(request, response)

# Repeating the request should result in a cache hit
result = middleware_group1.process_request(request)
self.assertIsNotNone(result)

# Invalidating the whole group
invalidate_key_group(group_name, middleware_group1.cache)

# Repeating the request should result in a cache non hit
result = middleware_group1.process_request(request)
self.assertIsNone(result)

def test_view_decorator(self):
# decorate the same view with different cache decorators
default_view = cache_page(3)(hello_world_view)
Expand Down Expand Up @@ -2938,3 +2965,32 @@ def test_all(self):
# .all() initializes all caches.
self.assertEqual(len(test_caches.all(initialized_only=True)), 2)
self.assertEqual(test_caches.all(), test_caches.all(initialized_only=True))


class TestCacheInvalidation(SimpleTestCase):
def test_view_cache_invalidation(self):
"""
Test that the view cache is correctly invalidated.
"""
# Set up the view. Decorate it with the cache_page decorator.
decorated_view = cache_page(timeout=3, key_group="test-key-group")(
hello_world_view
)

# Set up the request.
request = self.factory.get("/view/")

# Request the view for the first time.
response = decorated_view(request, "1")
self.assertEqual(response.content, b"Hello World 1")

# Request again -- hit the cache.
response = decorated_view(request, "2")
self.assertEqual(response.content, b"Hello World 1")

# Run the cache invalidation.
invalidate_cache_by_key_group(key_group="test-key-group")

# Request again -- should NOT hit the cache.
response = decorated_view(request, "3")
self.assertEqual(response.content, b"Hello World 3")