Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/12497.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a race condition in :py:class:`~aiohttp.TCPConnector` where closing the connector while a DNS resolution was in-flight could raise :py:exc:`AttributeError` instead of :py:exc:`~aiohttp.ClientConnectionError` -- by :user:`goingforstudying-ctrl`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ Gary Wilson Jr.
Gene Hoffman
Gennady Andreyev
Georges Dubus
goingforstudying-ctrl
Greg Holt
Gregory Haynes
Grigoriy Soldatov
Expand Down
7 changes: 5 additions & 2 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,10 +1005,10 @@ async def close(self, *, abort_ssl: bool = False) -> None:
- If ssl_shutdown_timeout=0: connections are aborted
- If ssl_shutdown_timeout>0: graceful shutdown is performed
"""
if self._resolver_owner:
await self._resolver.close()
# Use abort_ssl param if explicitly set, otherwise use ssl_shutdown_timeout default
await super().close(abort_ssl=abort_ssl or self._ssl_shutdown_timeout == 0)
if self._resolver_owner:
await self._resolver.close()

def _close_immediately(self, *, abort_ssl: bool = False) -> list[Awaitable[object]]:
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
Expand Down Expand Up @@ -1062,6 +1062,9 @@ async def _resolve_host(
for trace in traces:
await trace.send_dns_resolvehost_start(host)

if self._closed:
raise ClientConnectionError("Connector is closed")

res = await self._resolver.resolve(host, port, family=self._family)

if traces:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4698,3 +4698,49 @@ async def test_connect_tunnel_connection_release() -> None:

# Clean up to avoid resource warning
conn.close()


async def test_tcp_connector_close_race_condition() -> None:
"""Test closing TCPConnector while DNS resolution is in-flight."""
loop = asyncio.get_running_loop()
resolve_started = loop.create_future()
close_started = loop.create_future()

class FakeResolver(AbstractResolver):
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> list[ResolveResult]:
resolve_started.set_result(None)
await close_started
return [
{
"hostname": host,
"host": host,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}
]

async def close(self) -> None:
assert False

connector = TCPConnector(use_dns_cache=False, resolver=FakeResolver())

async def resolve_host() -> None:
# The in-flight resolve should complete normally since close()
# happens after the resolver returns
result = await connector._resolve_host("localhost", 80)
assert len(result) == 1

async def close_connector() -> None:
await resolve_started
close_started.set_result(None)
await connector.close()

await asyncio.gather(resolve_host(), close_connector())

# After close, new resolves should raise ClientConnectionError
with pytest.raises(aiohttp.ClientConnectionError, match="Connector is closed"):
await connector._resolve_host("localhost", 80)
Loading