diff --git a/aiohttp_asgi/resource.py b/aiohttp_asgi/resource.py index 679cc8a..b99c432 100644 --- a/aiohttp_asgi/resource.py +++ b/aiohttp_asgi/resource.py @@ -1,6 +1,9 @@ import asyncio import logging from contextlib import contextmanager +from contextvars import ContextVar +from pathlib import Path +from re import Pattern from types import MappingProxyType from typing import ( Any, Awaitable, Callable, Coroutine, Dict, Generator, List, Mapping, @@ -13,8 +16,12 @@ from aiohttp.abc import AbstractMatchInfo, AbstractStreamWriter from aiohttp.helpers import DEBUG from aiohttp.web import ( - AbstractResource, Application, HTTPException, Request, StreamResponse, - WebSocketResponse, + AbstractResource, AbstractRoute, Application, HTTPException, Request, + StreamResponse, WebSocketResponse, +) +from aiohttp.web_urldispatcher import AbstractRuleMatching +from aiohttp.web_urldispatcher import ( + _default_expect_handler as default_expect_handler, ) from yarl import URL @@ -29,15 +36,26 @@ ] -try: - from aiohttp.web_urldispatcher import ( - _InfoDict as ResourceInfoDict, # type: ignore - ) -except ImportError: - ResourceInfoDict = Dict[str, Any] # type: ignore +log = logging.getLogger(__name__) -log = logging.getLogger(__name__) +class ResourceInfoDict(TypedDict, total=False): + """ + Redefining `aiohttp.web_urldispatcher._InfoDict`. + It is not total and just using for better typing. + Do not be afraid of this. + """ + + path: str + formatter: str + pattern: Pattern[str] + directory: Path + prefix: str + routes: Mapping[str, AbstractRoute] + app: Application + domain: str + rule: AbstractRuleMatching + http_exception: HTTPException class ScopeDict(TypedDict): @@ -66,19 +84,23 @@ class LifespanDict(TypedDict): class ASGIMatchInfo(AbstractMatchInfo): + CURRENT_APP: ContextVar[Application] = ContextVar("CURRENT_APP") + def __init__(self, handler: Callable[..., Any]): self._handler = handler - self._apps: List[Application] = list() - self._current_app: Optional[Application] = None - self._frozen = False + self._apps: Union[List[Application], Tuple[Application, ...]] = list() + + @property + def frozen(self) -> bool: + return isinstance(self._apps, tuple) @property def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: return self._handler @property - def expect_handler(self) -> Callable[[Request], Awaitable[None]]: - raise NotImplementedError + def expect_handler(self) -> Optional[Callable[[Request], Awaitable[None]]]: + return default_expect_handler @property def http_exception(self) -> Optional[HTTPException]: @@ -93,15 +115,15 @@ def get_info(self) -> Dict[str, Any]: @property def apps(self) -> Tuple[Application, ...]: - if not isinstance(self._apps, tuple): - return tuple(self._apps) - return self._apps + if isinstance(self._apps, tuple): + return self._apps + return tuple(self._apps) def add_app(self, app: Application) -> None: - if self._frozen: + if isinstance(self._apps, tuple): raise RuntimeError("Cannot change apps stack after .freeze() call") - if self._current_app is None: - self._current_app = app + + self.CURRENT_APP.set(app) self._apps.insert(0, app) @contextmanager @@ -114,19 +136,18 @@ def set_current_app( " instead (https://github.com/mosquito/aiohttp-asgi/pull/11)!", DeprecationWarning, ) - prev = self._current_app - self.add_app(app) - self._current_app = app + prev_app = self.CURRENT_APP.get() + self.CURRENT_APP.set(app) try: yield finally: - self._current_app = prev - self._apps.pop(0) + self.CURRENT_APP.set(prev_app) @property def current_app(self) -> Application: - app = self._current_app - assert app is not None + app = self.CURRENT_APP.get(None) + if app is None: + raise RuntimeError("No current app set, use add_app() method first") return app @current_app.setter @@ -138,10 +159,10 @@ def current_app(self, app: Application) -> None: self._apps, app, ), ) - self._current_app = app + self.CURRENT_APP.set(app) def freeze(self) -> None: - self._frozen = True + self._apps = tuple(self._apps) _ResponseType = Optional[Union[StreamResponse, WebSocketResponse]] diff --git a/tests/test_fastapi_integration.py b/tests/test_fastapi_integration.py index e68367d..3d0ea69 100644 --- a/tests/test_fastapi_integration.py +++ b/tests/test_fastapi_integration.py @@ -35,6 +35,19 @@ async def websocket_endpoint(websocket: ASGIWebSocket): except WebSocketDisconnect: return + @asgi_app.post("/upload") + async def upload_endpoint(request: ASGIRequest): + headers = dict(request.scope["headers"]) + expect_header = headers.get(b"expect", b"").decode().lower() + + body = await request.body() + + return { + "received_expect": expect_header, + "body_size": len(body), + "message": "success", + } + @pytest.fixture def asgi_resource(routes, asgi_app): @@ -146,3 +159,18 @@ def test_get_routes_from_resource(asgi_resource): for _ in asgi_resource: # Should be unreachable pytest.fail("ASGIResource should not return routes during iteration") + + +async def test_expect_handler_basic(client: test_utils.TestClient): + test_data = b"This is test data for upload" + + async with client.post( + "/upload", + data=test_data, + headers={"Expect": "100-continue", "Content-Type": "application/octet-stream"}, + ) as resp: + assert resp.status == 200 + data = await resp.json() + + assert data["body_size"] == len(test_data) + assert data["message"] == "success"