Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 50 additions & 29 deletions aiohttp_asgi/resource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import asyncio
import logging
from contextlib import contextmanager
from contextvars import ContextVar
from pathlib import Path
from re import Pattern
from types import MappingProxyType
from typing import (
Any, Awaitable, Callable, Coroutine, Dict, Generator, List, Mapping,
Expand All @@ -13,8 +16,12 @@
from aiohttp.abc import AbstractMatchInfo, AbstractStreamWriter
from aiohttp.helpers import DEBUG
from aiohttp.web import (
AbstractResource, Application, HTTPException, Request, StreamResponse,
WebSocketResponse,
AbstractResource, AbstractRoute, Application, HTTPException, Request,
StreamResponse, WebSocketResponse,
)
from aiohttp.web_urldispatcher import AbstractRuleMatching
from aiohttp.web_urldispatcher import (
_default_expect_handler as default_expect_handler,
)
from yarl import URL

Expand All @@ -29,15 +36,26 @@
]


try:
from aiohttp.web_urldispatcher import (
_InfoDict as ResourceInfoDict, # type: ignore
)
except ImportError:
ResourceInfoDict = Dict[str, Any] # type: ignore
log = logging.getLogger(__name__)


log = logging.getLogger(__name__)
class ResourceInfoDict(TypedDict, total=False):
"""
Redefining `aiohttp.web_urldispatcher._InfoDict`.
It is not total and just using for better typing.
Do not be afraid of this.
"""

path: str
formatter: str
pattern: Pattern[str]
directory: Path
prefix: str
routes: Mapping[str, AbstractRoute]
app: Application
domain: str
rule: AbstractRuleMatching
http_exception: HTTPException


class ScopeDict(TypedDict):
Expand Down Expand Up @@ -66,19 +84,23 @@ class LifespanDict(TypedDict):


class ASGIMatchInfo(AbstractMatchInfo):
CURRENT_APP: ContextVar[Application] = ContextVar("CURRENT_APP")

def __init__(self, handler: Callable[..., Any]):
self._handler = handler
self._apps: List[Application] = list()
self._current_app: Optional[Application] = None
self._frozen = False
self._apps: Union[List[Application], Tuple[Application, ...]] = list()

@property
def frozen(self) -> bool:
return isinstance(self._apps, tuple)

@property
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
return self._handler

@property
def expect_handler(self) -> Callable[[Request], Awaitable[None]]:
raise NotImplementedError
def expect_handler(self) -> Optional[Callable[[Request], Awaitable[None]]]:
return default_expect_handler

@property
def http_exception(self) -> Optional[HTTPException]:
Expand All @@ -93,15 +115,15 @@ def get_info(self) -> Dict[str, Any]:

@property
def apps(self) -> Tuple[Application, ...]:
if not isinstance(self._apps, tuple):
return tuple(self._apps)
return self._apps
if isinstance(self._apps, tuple):
return self._apps
return tuple(self._apps)

def add_app(self, app: Application) -> None:
if self._frozen:
if isinstance(self._apps, tuple):
raise RuntimeError("Cannot change apps stack after .freeze() call")
if self._current_app is None:
self._current_app = app

self.CURRENT_APP.set(app)
self._apps.insert(0, app)

@contextmanager
Expand All @@ -114,19 +136,18 @@ def set_current_app(
" instead (https://github.com/mosquito/aiohttp-asgi/pull/11)!",
DeprecationWarning,
)
prev = self._current_app
self.add_app(app)
self._current_app = app
prev_app = self.CURRENT_APP.get()
self.CURRENT_APP.set(app)
try:
yield
finally:
self._current_app = prev
self._apps.pop(0)
self.CURRENT_APP.set(prev_app)

@property
def current_app(self) -> Application:
app = self._current_app
assert app is not None
app = self.CURRENT_APP.get(None)
if app is None:
raise RuntimeError("No current app set, use add_app() method first")
return app

@current_app.setter
Expand All @@ -138,10 +159,10 @@ def current_app(self, app: Application) -> None:
self._apps, app,
),
)
self._current_app = app
self.CURRENT_APP.set(app)

def freeze(self) -> None:
self._frozen = True
self._apps = tuple(self._apps)


_ResponseType = Optional[Union[StreamResponse, WebSocketResponse]]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_fastapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ async def websocket_endpoint(websocket: ASGIWebSocket):
except WebSocketDisconnect:
return

@asgi_app.post("/upload")
async def upload_endpoint(request: ASGIRequest):
headers = dict(request.scope["headers"])
expect_header = headers.get(b"expect", b"").decode().lower()

body = await request.body()

return {
"received_expect": expect_header,
"body_size": len(body),
"message": "success",
}


@pytest.fixture
def asgi_resource(routes, asgi_app):
Expand Down Expand Up @@ -146,3 +159,18 @@ def test_get_routes_from_resource(asgi_resource):
for _ in asgi_resource:
# Should be unreachable
pytest.fail("ASGIResource should not return routes during iteration")


async def test_expect_handler_basic(client: test_utils.TestClient):
test_data = b"This is test data for upload"

async with client.post(
"/upload",
data=test_data,
headers={"Expect": "100-continue", "Content-Type": "application/octet-stream"},
) as resp:
assert resp.status == 200
data = await resp.json()

assert data["body_size"] == len(test_data)
assert data["message"] == "success"