Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tornado/httputil.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,14 @@ def parse(cls, headers: str, *, _chars_are_bytes: bool = True) -> HTTPHeaders:
# MutableMapping abstract method implementations.

def __setitem__(self, name: str, value: str) -> None:
# `MutableMapping` lets any object reach the indexer; non-string keys
# used to leak an `AttributeError` from inside `_normalize_header`
# (which is `lru_cache`d and splits on `-`). Reject them up-front with a
# clear message; `__contains__` already short-circuits the same way.
if not isinstance(name, str):
raise TypeError(
"HTTPHeaders keys must be str, not %s" % type(name).__name__
)
norm_name = _normalize_header(name)
self._combined_cache[norm_name] = value
self._as_list[norm_name] = [value]
Expand All @@ -338,12 +346,20 @@ def __contains__(self, name: object) -> bool:
return norm_name in self._as_list

def __getitem__(self, name: str) -> str:
if not isinstance(name, str):
raise TypeError(
"HTTPHeaders keys must be str, not %s" % type(name).__name__
)
header = _normalize_header(name)
if header not in self._combined_cache:
self._combined_cache[header] = ",".join(self._as_list[header])
return self._combined_cache[header]

def __delitem__(self, name: str) -> None:
if not isinstance(name, str):
raise TypeError(
"HTTPHeaders keys must be str, not %s" % type(name).__name__
)
norm_name = _normalize_header(name)
del self._combined_cache[norm_name]
del self._as_list[norm_name]
Expand Down
42 changes: 42 additions & 0 deletions tornado/test/httputil_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,48 @@ def test_setdefault(self):
self.assertEqual(headers["quux"], "xyzzy")
self.assertEqual(sorted(headers.get_all()), [("Foo", "bar"), ("Quux", "xyzzy")])

def test_non_string_key_setitem_raises_type_error(self):
# HTTPHeaders indexes via _normalize_header, which is decorated with
# @lru_cache and only accepts str. Non-string keys used to leak an
# AttributeError out of the cache wrapper; now they raise TypeError
# at the call site so callers get a useful error.
headers = HTTPHeaders()
for bad in (1, 1.5, None, b"Foo", ("Foo",), object()):
with self.assertRaises(TypeError):
headers[bad] = "value"
# __setitem__ must not silently store partial state for the bad
# key, and must not corrupt the lru_cache for valid lookups.
self.assertEqual(len(headers), 0)
self.assertNotIn(bad, headers)

def test_non_string_key_getitem_raises_type_error(self):
headers = HTTPHeaders()
headers["Foo"] = "bar"
for bad in (1, 1.5, None, b"Foo", ("Foo",), object()):
with self.assertRaises(TypeError):
headers[bad]
# Pre-existing string key still reads back.
self.assertEqual(headers["Foo"], "bar")

def test_non_string_key_delitem_raises_type_error(self):
headers = HTTPHeaders()
headers["Foo"] = "bar"
for bad in (1, 1.5, None, b"Foo", ("Foo",), object()):
with self.assertRaises(TypeError):
del headers[bad]
# Pre-existing string entry must still be intact.
self.assertEqual(headers["Foo"], "bar")

def test_non_string_key_contains_returns_false(self):
# __contains__ already guarded against non-strings; this pins the
# behaviour so a future refactor of __setitem__/__getitem__ cannot
# regress it.
headers = HTTPHeaders()
headers["Foo"] = "bar"
for bad in (1, 1.5, None, b"Foo", ("Foo",), object()):
self.assertFalse(bad in headers)
self.assertIn("Foo", headers)

def test_string(self):
headers = HTTPHeaders()
headers.add("Foo", "1")
Expand Down