diff --git a/examples/ace_editor.py b/examples/ace_editor.py index d785bcc..22d32f2 100644 --- a/examples/ace_editor.py +++ b/examples/ace_editor.py @@ -1,11 +1,19 @@ -import uvicorn +from contextlib import asynccontextmanager from pydantic import BaseModel from uiwiz import UiwizApp, ui +from uiwiz.server import run -app = UiwizApp() +@asynccontextmanager +async def lifespan(app): + print("") + yield + print("") + +app = UiwizApp(lifespan=lifespan) + class DataInput(BaseModel): ace_data: str @@ -35,4 +43,4 @@ async def home_page(): if __name__ == "__main__": - uvicorn.run("ace_editor:app", reload=True) + run("ace_editor:app") diff --git a/examples/dict_example.py b/examples/dict_example.py index 181543e..04e42fc 100644 --- a/examples/dict_example.py +++ b/examples/dict_example.py @@ -1,4 +1,4 @@ -import uvicorn +from uiwiz import server from uiwiz import UiwizApp, ui @@ -42,4 +42,4 @@ async def test(): if __name__ == "__main__": - uvicorn.run("dict_example:app", reload=True) + server.run("dict_example:app") diff --git a/examples/drawer_example.py b/examples/drawer_example.py index 1ca3463..1403349 100644 --- a/examples/drawer_example.py +++ b/examples/drawer_example.py @@ -1,6 +1,6 @@ import logging -import uvicorn +from uiwiz import server import uiwiz.ui as ui from uiwiz.app import UiwizApp @@ -57,4 +57,4 @@ async def test(): if __name__ == "__main__": - uvicorn.run("drawer_example:app", reload=True) + server.run("drawer_example:app") diff --git a/examples/echart_example.py b/examples/echart_example.py index 0de2d12..42c8421 100644 --- a/examples/echart_example.py +++ b/examples/echart_example.py @@ -1,6 +1,6 @@ from random import randint -import uvicorn +from uiwiz import server from uiwiz import UiwizApp, ui @@ -90,4 +90,4 @@ async def test(): if __name__ == "__main__": - uvicorn.run("echart_example:app", reload=True) + server.run("echart_example:app") diff --git a/examples/edit_table.py b/examples/edit_table.py index 8dece39..8be6905 100644 --- a/examples/edit_table.py +++ b/examples/edit_table.py @@ -1,6 +1,6 @@ from typing import Annotated -import uvicorn +from uiwiz import server from fastapi import Request from pydantic import BaseModel, Field @@ -79,4 +79,4 @@ async def test(request: Request): if __name__ == "__main__": - uvicorn.run("edit_table:app", reload=True) + server.run("edit_table:app") diff --git a/examples/input_example.py b/examples/input_example.py index edab438..715e37b 100644 --- a/examples/input_example.py +++ b/examples/input_example.py @@ -3,8 +3,8 @@ from typing import Annotated, Optional import pandas as pd -import uvicorn -from fastapi import Depends +from uiwiz import server +from fastapi import Depends, Request from pydantic import BaseModel from uiwiz import Element, PageDefinition, PageRouter, UiwizApp, ui @@ -59,12 +59,10 @@ def footer(self, content): route = PageRouter(page_definition_class=MyDefinition) -PageRouter() - # set static page title @route.page("/", title="Input Example") -async def test(page: Annotated[MyDefinition, Depends()]): +async def test(page: Annotated[MyDefinition, Depends()], req: Request): # set dynamic page title page.title = "Dynamic title" # set dynamic page lang @@ -152,4 +150,4 @@ async def test(page: Annotated[MyDefinition, Depends()]): if __name__ == "__main__": - uvicorn.run("input_example:app", reload=True) + server.run("input_example:app") diff --git a/examples/multipage/main.py b/examples/multipage/main.py index 856fc6a..fb10acc 100644 --- a/examples/multipage/main.py +++ b/examples/multipage/main.py @@ -1,8 +1,7 @@ -import uvicorn +from uiwiz import server -from examples.multipage.second_page import router +from multipage.second_page import router from uiwiz import UiwizApp, ui -from uiwiz.shared import page_map app = UiwizApp() app.include_router(router) @@ -13,10 +12,11 @@ async def test(): with ui.element().classes("col lg:px-80"): with ui.element().classes("w-full"): ui.element(content="Hello, world!") - for route in page_map.values(): - with ui.col(): - ui.link(route, route) + with ui.col(): + ui.link("Home", "/") + with ui.col(): + ui.link("Second page", "/second_page") if __name__ == "__main__": - uvicorn.run("examples.multipage.main:app", reload=True, port=8000) + server.run("multipage.main:app") diff --git a/examples/run_simple.py b/examples/run_simple.py index a0d80e1..4eb5195 100644 --- a/examples/run_simple.py +++ b/examples/run_simple.py @@ -1,11 +1,11 @@ from io import BytesIO import pandas as pd -import uvicorn from fastapi import Request, UploadFile import uiwiz.ui as ui from uiwiz.app import UiwizApp +from uiwiz import server app = UiwizApp(theme="aqua") @@ -62,4 +62,4 @@ async def test(request: Request): if __name__ == "__main__": - uvicorn.run("run_simple:app", reload=True) + server.run("run_simple:app") diff --git a/examples/run_tabs.py b/examples/run_tabs.py index 7249252..92c3750 100644 --- a/examples/run_tabs.py +++ b/examples/run_tabs.py @@ -1,7 +1,5 @@ -import uvicorn - import uiwiz.ui as ui -from uiwiz.app import UiwizApp +from uiwiz import UiwizApp, server app = UiwizApp() @@ -35,4 +33,4 @@ async def test(): if __name__ == "__main__": - uvicorn.run("run_tabs:app", reload=True) + server.run("run_tabs:app") diff --git a/examples/sample.py b/examples/sample.py index 379da09..2bf3e2a 100644 --- a/examples/sample.py +++ b/examples/sample.py @@ -1,7 +1,4 @@ -import uvicorn - -from uiwiz import ui -from uiwiz.app import UiwizApp +from uiwiz import ui, UiwizApp, server app = UiwizApp() @@ -12,4 +9,4 @@ async def home_page(): if __name__ == "__main__": - uvicorn.run(app) + server.run(app) diff --git a/examples/validate_form_example.py b/examples/validate_form_example.py index a13696f..ad2dfff 100644 --- a/examples/validate_form_example.py +++ b/examples/validate_form_example.py @@ -2,11 +2,10 @@ from datetime import date from typing import Annotated, Literal -import uvicorn from pydantic import BaseModel, Field import uiwiz.ui as ui -from uiwiz.app import UiwizApp +from uiwiz import UiwizApp, server from uiwiz.models.model_handler import UiAnno app = UiwizApp(auto_close_toast_error=False) @@ -90,4 +89,4 @@ async def test(): if __name__ == "__main__": app_name = os.path.basename(__file__).replace(".py", "") - uvicorn.run(app=f"{app_name}:app", host="0.0.0.0", port=8080, workers=1, reload=True) + server.run(app=f"{app_name}:app") diff --git a/src/uiwiz/page_definition.py b/src/uiwiz/page_definition.py index fbf1f55..4ee27b4 100644 --- a/src/uiwiz/page_definition.py +++ b/src/uiwiz/page_definition.py @@ -10,6 +10,7 @@ from uiwiz.version import __version__ + class PageDefinition: html_ele: Element header_ele: Element @@ -51,6 +52,7 @@ def footer(self, content: Element) -> None: """ self._lang: str = "en" + self._title_ele: Optional[Element] = None @property def lang(self) -> str: @@ -63,11 +65,11 @@ def lang(self, value: str) -> None: @property def title(self) -> str: - return self.title_ele.content + return self._title_ele.content @title.setter def title(self, value: str) -> None: - self.title_ele.content = value + self._title_ele.content = value async def render( self, @@ -102,7 +104,7 @@ def render(self): Element("meta", charset="utf-8") Element("meta", description=frame.meta_description_content) - self.title_ele = Element("title", content=page_title) + self._title_ele = Element("title", content=page_title) Element("link", href=f"/_static/{__version__}/libs/output.css", rel="stylesheet", type="text/css") Element("link", href=f"/_static/{__version__}/libs/daisyui.css", rel="stylesheet", type="text/css") diff --git a/src/uiwiz/page_route.py b/src/uiwiz/page_route.py index 0cab9f6..d1e028b 100644 --- a/src/uiwiz/page_route.py +++ b/src/uiwiz/page_route.py @@ -175,8 +175,13 @@ async def decorated(*dec_args, **dec_kwargs: DecKwargs) -> Response: Frame.get_stack().del_stack() # Create frame before function is called - request = dec_kwargs["request"] - response = dec_kwargs["response"] + request = None + response = None + for value in dec_kwargs.values(): + if isinstance(value, Request): + request = value + if isinstance(value, Response): + response = value if self.page_definition_class is None: self.page_definition_class = request.app.page_definition_class @@ -185,7 +190,7 @@ async def decorated(*dec_args, **dec_kwargs: DecKwargs) -> Response: page = page_class() - dec_kwargs = {k: v for k, v in dec_kwargs.items() if k in parameters_of_decorated_func} + dec_kwargs = {k: v if not isinstance(v, PageDefinition) else page for k, v in dec_kwargs.items() if k in parameters_of_decorated_func} user_method = partial(func, *dec_args, **dec_kwargs) result = await page.render(user_method=user_method, request=request, title=cap_title) if isinstance(result, Response): diff --git a/src/uiwiz/server/__init__.py b/src/uiwiz/server/__init__.py new file mode 100644 index 0000000..a560aaf --- /dev/null +++ b/src/uiwiz/server/__init__.py @@ -0,0 +1,7 @@ +from uiwiz.server._server import Config, Server + +def run(app: str, host: str = "localhost", port:int = 8080): + config = Config(host=host, port=port, app=app, root_path="") + + server = Server(config) + server.run() \ No newline at end of file diff --git a/src/uiwiz/server/_server.py b/src/uiwiz/server/_server.py new file mode 100644 index 0000000..c47231d --- /dev/null +++ b/src/uiwiz/server/_server.py @@ -0,0 +1,452 @@ +import asyncio +from asyncio import Queue, Event, TimerHandle +import re +from typing import Any, Optional +import urllib +import httptools +import http +import importlib +from dataclasses import dataclass +from collections import deque +from contextlib import suppress + +from uvicorn._types import ( + ASGI3Application, +) +from uvicorn.protocols.http.flow_control import FlowControl +from uvicorn.protocols.http.httptools_impl import RequestResponseCycle +from time import perf_counter + +from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT +from uiwiz.app import UiwizApp +import logging + +formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" +) +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger(__name__) +sc = logging.StreamHandler() +sc.setFormatter(formatter) +logger.addHandler(sc) + +HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]') +HEADER_VALUE_RE = re.compile(b"[\x00-\x08\x0a-\x1f\x7f]") + +@dataclass +class Config: + host: str + port: int + root_path: str + app: Optional[str] = None + app_instance: Optional[UiwizApp] = None + + +def import_app_instance(config: Config) -> None: + start = perf_counter() + if isinstance(config.app, str): + module_name, _, app = config.app.partition(":") + module = importlib.import_module(module_name) + module = importlib.reload(module) + config.app_instance = getattr(module, app) + else: + config.app_instance = config.app + end = perf_counter() + logger.info(f"Module reloaded in: {end - start}") + + +class LifespanHandler: + def __init__(self, config: Config): + self.config = config + self.receive_queue: Queue = Queue() + self.state: dict[str, Any] = {} + self.startup_done_event = Event() + self.shutdown_done_event = Event() + + async def startup(self) -> None: + logger.info("Calling lifespan") + loop = asyncio.get_running_loop() + lifespan_task = loop.create_task(self.execute()) + await self.receive_queue.put({"type": "lifespan.startup"}) + await self.startup_done_event.wait() + + logger.info("Startup complete") + + async def shutdown(self) -> None: + logger.info("Waiting for application shutdown.") + shutdown_event = {"type": "lifespan.shutdown"} + await self.receive_queue.put(shutdown_event) + await self.shutdown_done_event.wait() + + logger.info("Application shutdown complete.") + + async def execute(self) -> None: + app = self.config.app_instance + scope = { + "type": "lifespan", + "asgi": {"version": "3", "spec_version": "2.0"}, + "state": self.state, + } + await app(scope, self.receive, self.send) + + async def send(self, message: dict) -> None: + task = { + "lifespan.startup.complete": lambda: self.startup_done_event.set(), + "lifespan.startup.failed": lambda: self.startup_done_event.set(), + "lifespan.shutdown.complete": lambda: self.shutdown_done_event.set(), + "lifespan.shutdown.failed": lambda: self.shutdown_done_event.set(), + } + task.get(message["type"], lambda: 1)() + + async def receive(self): + with suppress(asyncio.CancelledError): + return await self.receive_queue.get() + +class ServerState: + def __init__(self, config: Config): + self.total_requests = 0 + self.connections = set() + self.tasks: set[asyncio.Task[None]] = set() + self.default_headers: list[tuple[bytes, bytes]] = [] + self.init_load: bool = True + self.lifespan = LifespanHandler(config) + +class HttpToolsImpl(asyncio.Protocol): + def __init__(self, config: Config, server_state: ServerState): + self.config = config + self.app = config.app + self.state: dict[str, Any] = {} + self.loop = asyncio.get_event_loop() + self.parser = httptools.HttpRequestParser(self) + self.timeout_keep_alive_task: TimerHandle | None = None + self.transport: asyncio.Transport = None + self.app_state = dict() + self.root_path = config.root_path + self.cycle: RequestResponseCycle = None + + self.timeout_keep_alive = 10 + + self.flow: FlowControl = None # type: ignore[assignment] + self.pipeline: deque[tuple[RequestResponseCycle, ASGI3Application]] = deque() + self.logger = logger + self.access_logger = logger + self.access_log = logger.hasHandlers() + self.server_state = server_state + self.tasks = self.server_state.tasks + + def data_received(self, data: bytes) -> None: + self._unset_keepalive_if_required() + + try: + self.parser.feed_data(data) + except httptools.HttpParserError: + msg = "Invalid HTTP request received." + logger.warning(msg) + self.send_400_response(msg) + return + except httptools.HttpParserUpgrade: + if self._should_upgrade(): + self.handle_websocket_upgrade() + else: + self._unsupported_upgrade_warning() + + def eof_received(self) -> None: ... + + def _unset_keepalive_if_required(self) -> None: + if self.timeout_keep_alive_task is not None: + self.timeout_keep_alive_task.cancel() + self.timeout_keep_alive_task = None + + def connection_made(self, transport: asyncio.Transport) -> None: + self.transport = transport + self.flow = FlowControl(transport) + self.server = self.get_local_addr() + self.client = self.get_remote_addr() + self.scheme = "https" if bool(transport.get_extra_info("sslcontext")) else "http" + + def connection_lost(self, exc: Optional[Exception]) -> None: + if self.cycle and not self.cycle.response_complete: + self.cycle.disconnected = True + if self.cycle is not None: + self.cycle.message_event.set() + if self.flow is not None: + self.flow.resume_writing() + if exc is None: + self.transport.close() + self._unset_keepalive_if_required() + + self.parser = None + + def on_url(self, url: bytes) -> None: + self.url += url + + def on_header(self, name: bytes, value: bytes) -> None: + name = name.lower() + if name == b"expect" and value.lower() == b"100-continue": + self.expect_100_continue = True + self.headers.append((name, value)) + + def on_headers_complete(self) -> None: + http_version = self.parser.get_http_version() + method = self.parser.get_method() + self.scope["method"] = method.decode("ascii") + if http_version != "1.1": + self.scope["http_version"] = http_version + if self.parser.should_upgrade() and self._should_upgrade(): + return + parsed_url = httptools.parse_url(self.url) + raw_path = parsed_url.path + path = raw_path.decode("ascii") + if "%" in path: + path = urllib.parse.unquote(path) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path + self.scope["path"] = full_path + self.scope["raw_path"] = full_raw_path + self.scope["query_string"] = parsed_url.query or b"" + + import_app_instance(self.config) + + existing_cycle = self.cycle + self.cycle = RRCycle( + scope=self.scope, + transport=self.transport, + flow=self.flow, + logger=self.logger, + access_logger=self.access_logger, + access_log=self.access_log, + default_headers=self.server_state.default_headers, + message_event=asyncio.Event(), + expect_100_continue=self.expect_100_continue, + keep_alive=http_version != "1.0", + on_response=self.on_response_complete, + ) + if existing_cycle is None or existing_cycle.response_complete: + self.tasks.add(self.loop.create_task(self.cycle.run_asgi(self.config.app_instance))) + else: + # Pipelined HTTP requests need to be queued up. + self.flow.pause_reading() + self.pipeline.appendleft((self.cycle, self.config.app_instance)) + + def _shutdown(self, *args): + task = self.loop.create_task(self.lifespan.shutdown()) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) + + done = {ta for ta in set(self.tasks) if ta.done()} + self.tasks.difference_update(done) + + def on_message_begin(self) -> None: + self.url = b"" + self.expect_100_continue = False + self.headers = [] + self.scope = { # type: ignore[typeddict-item] + "type": "http", + "asgi": {"version": "asgi3", "spec_version": "2.3"}, + "http_version": "1.1", + "server": self.server, + "client": self.client, + "scheme": self.scheme, # type: ignore[typeddict-item] + "root_path": "", + "headers": self.headers, + "state": self.app_state, + } + + + def shutdown(self) -> None: + """ + Called by the server to commence a graceful shutdown. + """ + if self.cycle is None or self.cycle.response_complete: + self.transport.close() + else: + self.cycle.keep_alive = False + + def on_body(self, body: bytes) -> None: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: + return + self.cycle.body += body + if len(self.cycle.body) > HIGH_WATER_LIMIT: + self.flow.pause_reading() + self.cycle.message_event.set() + + + def on_message_complete(self) -> None: + if (self.parser.should_upgrade() and self._should_upgrade()) or self.cycle.response_complete: + return + self.cycle.more_body = False + self.cycle.message_event.set() + + def on_response_complete(self) -> None: + if self.transport.is_closing(): + return + + self._unset_keepalive_if_required() + + # Unpause data reads if needed. + self.flow.resume_reading() + + # Unblock any pipelined events. If there are none, arm the + # Keep-Alive timeout instead. + if self.pipeline: + cycle, app = self.pipeline.pop() + task = self.loop.create_task(cycle.run_asgi(app)) + task.add_done_callback(self.tasks.discard) + self.tasks.add(task) + else: + self.timeout_keep_alive_task = self.loop.call_later( + self.timeout_keep_alive, self.timeout_keep_alive_handler + ) + + def _get_upgrade(self) -> Optional[bytes]: + connection = [] + upgrade = None + for name, value in self.headers: + if name == b"connection": + connection = [token.lower().strip() for token in value.split(b",")] + if name == b"upgrade": + upgrade = value.lower() + if b"upgrade" in connection: + return upgrade + return None # pragma: full coverage + + def _should_upgrade_to_ws(self) -> bool: + if self.ws_protocol_class is None: + return False + return True + + def _unsupported_upgrade_warning(self) -> None: + logger.warning("Unsupported upgrade request.") + if not self._should_upgrade_to_ws(): + msg = "No supported WebSocket library detected. Please use \"pip install 'uvicorn[standard]'\", or install 'websockets' or 'wsproto' manually." # noqa: E501 + logger.warning(msg) + + def _should_upgrade(self) -> bool: + upgrade = self._get_upgrade() + return upgrade == b"websocket" and self._should_upgrade_to_ws() + + def get_local_addr(self) -> Optional[tuple[str, int]]: + socket_info = self.transport.get_extra_info("socket") + if socket_info is not None: + info = socket_info.getsockname() + + return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None + info = self.transport.get_extra_info("sockname") + if info is not None and isinstance(info, (list, tuple)) and len(info) == 2: + return (str(info[0]), int(info[1])) + return None + + def get_remote_addr(self) -> Optional[tuple[str, int]]: + socket_info = self.transport.get_extra_info("socket") + if socket_info is not None: + try: + info = socket_info.getpeername() + return (str(info[0]), int(info[1])) if isinstance(info, tuple) else None + except OSError: + return None + + info = self.transport.get_extra_info("peername") + if info is not None and isinstance(info, (list, tuple)) and len(info) == 2: + return (str(info[0]), int(info[1])) + return None + + def send_400_response(self, msg: str) -> None: + message = [http.HTTPStatus(400).phrase.encode()] + # for name, value in self.server_state.default_headers: + # message.extend([name, b": ", value, b"\r\n"]) # pragma: full coverage + message.extend( + [ + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg.encode("ascii"), + ] + ) + self.transport.write(b"".join(message)) + self.transport.close() + + def pause_writing(self) -> None: + """ + Called by the transport when the write buffer exceeds the high water mark. + """ + self.flow.pause_writing() # pragma: full coverage + + def resume_writing(self) -> None: + """ + Called by the transport when the write buffer drops below the low water mark. + """ + self.flow.resume_writing() # pragma: full coverage + + def timeout_keep_alive_handler(self) -> None: + """ + Called on a keep-alive connection if no new data is received after a short + delay. + """ + if not self.transport.is_closing(): + self.transport.close() + +class RRCycle(RequestResponseCycle): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # ASGI exception wrapper + async def run_asgi(self, app: ASGI3Application) -> None: + try: + result = await app( # type: ignore[func-returns-value] + self.scope, self.receive, self.send + ) + except BaseException as exc: + msg = "Exception in ASGI application\n" + self.logger.error(msg, exc_info=exc) + if not self.response_started: + await self.send_500_response() + else: + self.transport.close() + else: + if result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + elif not self.response_started and not self.disconnected: + msg = "ASGI callable returned without starting response." + self.logger.error(msg) + await self.send_500_response() + elif not self.response_complete and not self.disconnected: + msg = "ASGI callable returned without completing response." + self.logger.error(msg) + self.transport.close() + finally: + self.on_response = lambda: None + + +class Server: + def __init__(self, config: Config): + self.config = config + self.server_state = None + import_app_instance(config) + + def run(self) -> None: + try: + return asyncio.run(self._serve(), debug=True) + except KeyboardInterrupt: + return + + async def _serve(self) -> None: + loop = asyncio.get_running_loop() + self.server_state = ServerState(self.config) + server = await loop.create_server( + lambda: HttpToolsImpl(self.config, self.server_state), host=self.config.host, port=self.config.port + ) + async with server: + await loop.create_task(self.server_state.lifespan.startup()) + try: + await server.serve_forever() + except asyncio.exceptions.CancelledError: + await self.server_state.lifespan.shutdown() + return + await self.server_state.lifespan.shutdown() + diff --git a/src/uiwiz/server/main.py b/src/uiwiz/server/main.py new file mode 100644 index 0000000..3e955bf --- /dev/null +++ b/src/uiwiz/server/main.py @@ -0,0 +1,23 @@ +from uiwiz import ui, UiwizApp +import logging +from contextlib import asynccontextmanager +logger = logging.getLogger(__name__) + +@asynccontextmanager +async def lifespan(app): + logger.info("Start app") + yield + logger.info("Closing app") + + +app = UiwizApp(lifespan=lifespan) + + +@app.page("/") +def index(): + ui.element(content="This is pretty cool") + ui.element(content="Custom server working!") + + logger.info("Calling index") + + return {} \ No newline at end of file