Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ASGI router #644

Merged
merged 4 commits into from
Oct 23, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions starlite/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from typing_extensions import TypedDict

from starlite.asgi import (
ASGIRouter,
PathParameterTypePathDesignator,
PathParamNode,
RouteMapNode,
StarliteASGIRouter,
)
from starlite.config import AppConfig, CacheConfig, OpenAPIConfig
from starlite.config.logging import get_logger_placeholder
Expand Down Expand Up @@ -383,7 +383,7 @@ def __init__(
self._static_paths.add(static_config.path)
self.register(asgi(path=static_config.path, name=static_config.name)(static_config.to_static_files_app()))

self.asgi_router = StarliteASGIRouter(on_shutdown=self.on_shutdown, on_startup=self.on_startup, app=self)
self.asgi_router = ASGIRouter(app=self)
self.asgi_handler = self._create_asgi_handler()

async def __call__(
Expand All @@ -406,7 +406,7 @@ async def __call__(
"""
scope["app"] = self
if scope["type"] == "lifespan":
await self.asgi_router.lifespan(scope, receive, send) # type: ignore[arg-type]
await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type]
return
scope["state"] = {}
await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type]
Expand Down
65 changes: 55 additions & 10 deletions starlite/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from pathlib import Path
from traceback import format_exc
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -22,7 +23,6 @@
parse_duration,
parse_time,
)
from starlette.routing import Router as StarletteRouter

from starlite.enums import ScopeType
from starlite.exceptions import (
Expand All @@ -38,6 +38,12 @@
from starlite.types import (
ASGIApp,
LifeSpanHandler,
LifeSpanReceive,
LifeSpanSend,
LifeSpanShutdownCompleteEvent,
LifeSpanShutdownFailedEvent,
LifeSpanStartupCompleteEvent,
LifeSpanStartupFailedEvent,
Receive,
RouteHandlerType,
Scope,
Expand All @@ -59,18 +65,20 @@ class PathParameterTypePathDesignator:
ComponentsSet = Set[Union[str, PathParamPlaceholderType, TerminusNodePlaceholderType]]


class StarliteASGIRouter(StarletteRouter):
"""This class extends the Starlette Router class and *is* the ASGI app used
in Starlite."""
class ASGIRouter:
__slots__ = ("app",)

def __init__(
self,
app: "Starlite",
on_shutdown: List["LifeSpanHandler"],
on_startup: List["LifeSpanHandler"],
) -> None:
"""This class is the Starlite ASGI router. It handles both the ASGI
lifespan event and routing connection requests.
Args:
app: The Starlite app instance
"""
self.app = app
super().__init__(on_startup=on_startup, on_shutdown=on_shutdown)

def _traverse_route_map(self, path: str, scope: "Scope") -> Tuple[RouteMapNode, List[str]]:
"""Traverses the application route mapping and retrieves the correct
Expand Down Expand Up @@ -208,7 +216,7 @@ def _resolve_handler_node(
node = asgi_handlers[ScopeType.WEBSOCKET]
return node["asgi_app"], node["handler"]

async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: # type: ignore[override]
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""The main entry point to the Router class."""
try:
asgi_handlers, is_asgi = self._parse_scope_to_route(scope=scope)
Expand All @@ -218,6 +226,43 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No
scope["route_handler"] = handler
await asgi_app(scope, receive, send)

async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None:
"""Handles the ASGI "lifespan" event on application startup and
shutdown.
Args:
receive: The ASGI receive function.
send: The ASGI send function.
Returns:
None.
"""
message = await receive()
try:
if message["type"] == "lifespan.startup":
await self.startup()
startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"}
await send(startup_event)
await receive()
except BaseException as e:
if message["type"] == "lifespan.startup":
startup_failure_event: "LifeSpanStartupFailedEvent" = {
"type": "lifespan.startup.failed",
"message": format_exc(),
}
await send(startup_failure_event)
else:
shutdown_failure_event: "LifeSpanShutdownFailedEvent" = {
"type": "lifespan.shutdown.failed",
"message": format_exc(),
}
await send(shutdown_failure_event)
raise e
else:
await self.shutdown()
shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"}
await send(shutdown_event)

async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None:
"""Determines whether the lifecycle handler expects an argument, and if
so passes the `app.state` to it. If the handler is an async function,
Expand All @@ -244,7 +289,7 @@ async def startup(self) -> None:
for hook in self.app.before_startup:
await hook(self.app)

for handler in self.on_startup:
for handler in self.app.on_startup:
await self._call_lifespan_handler(handler)

for hook in self.app.after_startup:
Expand All @@ -262,7 +307,7 @@ async def shutdown(self) -> None:
for hook in self.app.before_shutdown:
await hook(self.app)

for handler in self.on_shutdown:
for handler in self.app.on_shutdown:
await self._call_lifespan_handler(handler)

for hook in self.app.after_shutdown:
Expand Down
2 changes: 1 addition & 1 deletion starlite/cache/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from asyncio import Lock
from typing import TYPE_CHECKING, Any, Optional, overload

from anyio import Lock
from typing_extensions import Protocol, runtime_checkable

from starlite.utils import is_async_callable
Expand Down
3 changes: 2 additions & 1 deletion starlite/cache/simple_cache_backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from asyncio import Lock
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict

from anyio import Lock

from starlite.cache.base import CacheBackendProtocol


Expand Down
2 changes: 1 addition & 1 deletion starlite/connection/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def receive_data(self, mode: "Literal['binary', 'text']") -> Union[str, by
if event["type"] == "websocket.disconnect":
raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
if self.connection_state == "disconnect":
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
return event.get("text") or "" if mode == "text" else event.get("bytes") or b""

async def receive_text(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions starlite/utils/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def __init__(self, fn: Callable[P, T]) -> None:
fn: Callable to wrap - can be any sync or async callable.
"""

self.is_method = ismethod(fn)
self.is_method = ismethod(fn) or (callable(fn) and ismethod(fn.__call__)) # type: ignore
self.num_expected_args = len(getfullargspec(fn).args) - (1 if self.is_method else 0)
self.wrapped_callable: Dict[Literal["fn"], Callable] = {
"fn": fn if is_async_callable(fn) else async_partial(fn)
"fn": fn if is_async_callable(fn) else async_partial(fn) # pyright: ignore
}

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
Expand Down
39 changes: 39 additions & 0 deletions tests/asgi_router/test_asgi_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from starlite.testing import create_test_client


class _LifeSpanCallable:
def __init__(self, should_raise: bool = False) -> None:
self.called = False
self.should_raise = should_raise

def __call__(self) -> None:
self.called = True
if self.should_raise:
raise RuntimeError("damn")


def test_life_span_startup() -> None:
life_span_callable = _LifeSpanCallable()
with create_test_client([], on_startup=[life_span_callable]):
assert life_span_callable.called


def test_life_span_startup_error_handling() -> None:
life_span_callable = _LifeSpanCallable(should_raise=True)
with pytest.raises(RuntimeError), create_test_client([], on_startup=[life_span_callable]):
pass


def test_life_span_shutdown() -> None:
life_span_callable = _LifeSpanCallable()
with create_test_client([], on_shutdown=[life_span_callable]):
pass
assert life_span_callable.called


def test_life_span_shutdown_error_handling() -> None:
life_span_callable = _LifeSpanCallable(should_raise=True)
with pytest.raises(RuntimeError), create_test_client([], on_shutdown=[life_span_callable]):
pass