Skip to content

Commit

Permalink
Update ASGI router (#644)
Browse files Browse the repository at this point in the history
* removed starlette dependency for lifespan events

* updated tests

* fix async callable

* added cookie parser (#645)

* added cookie parser

* address review comments

* Replace get name (#646)

* add get_name helper

* added helper to handle enum

* update enum safeguard

* Removed starlette HTTPException as code dependency (#647)

* removed starlette HTTPException as code dependency

* address review comments
  • Loading branch information
Goldziher committed Oct 23, 2022
1 parent 6cb5d1f commit 28f0540
Show file tree
Hide file tree
Showing 22 changed files with 228 additions and 71 deletions.
6 changes: 3 additions & 3 deletions starlite/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from typing_extensions import TypedDict

from starlite.asgi import (
ASGIRouter,
PathParameterTypePathDesignator,
PathParamNode,
RouteMapNode,
StarliteASGIRouter,
)
from starlite.config import AppConfig, CacheConfig, OpenAPIConfig
from starlite.config.logging import get_logger_placeholder
Expand Down Expand Up @@ -389,7 +389,7 @@ def __init__(
self._static_paths.add(static_config.path)
self.register(asgi(path=static_config.path, name=static_config.name)(static_config.to_static_files_app()))

self.asgi_router = StarliteASGIRouter(on_shutdown=self.on_shutdown, on_startup=self.on_startup, app=self)
self.asgi_router = ASGIRouter(app=self)
self.asgi_handler = self._create_asgi_handler()

async def __call__(
Expand All @@ -412,7 +412,7 @@ async def __call__(
"""
scope["app"] = self
if scope["type"] == "lifespan":
await self.asgi_router.lifespan(scope, receive, send) # type: ignore[arg-type]
await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type]
return
scope["state"] = {}
await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type]
Expand Down
65 changes: 55 additions & 10 deletions starlite/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from pathlib import Path
from traceback import format_exc
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -22,7 +23,6 @@
parse_duration,
parse_time,
)
from starlette.routing import Router as StarletteRouter

from starlite.enums import ScopeType
from starlite.exceptions import (
Expand All @@ -38,6 +38,12 @@
from starlite.types import (
ASGIApp,
LifeSpanHandler,
LifeSpanReceive,
LifeSpanSend,
LifeSpanShutdownCompleteEvent,
LifeSpanShutdownFailedEvent,
LifeSpanStartupCompleteEvent,
LifeSpanStartupFailedEvent,
Receive,
RouteHandlerType,
Scope,
Expand All @@ -59,18 +65,20 @@ class PathParameterTypePathDesignator:
ComponentsSet = Set[Union[str, PathParamPlaceholderType, TerminusNodePlaceholderType]]


class StarliteASGIRouter(StarletteRouter):
"""This class extends the Starlette Router class and *is* the ASGI app used
in Starlite."""
class ASGIRouter:
__slots__ = ("app",)

def __init__(
self,
app: "Starlite",
on_shutdown: List["LifeSpanHandler"],
on_startup: List["LifeSpanHandler"],
) -> None:
"""This class is the Starlite ASGI router. It handles both the ASGI
lifespan event and routing connection requests.
Args:
app: The Starlite app instance
"""
self.app = app
super().__init__(on_startup=on_startup, on_shutdown=on_shutdown)

def _traverse_route_map(self, path: str, scope: "Scope") -> Tuple[RouteMapNode, List[str]]:
"""Traverses the application route mapping and retrieves the correct
Expand Down Expand Up @@ -208,7 +216,7 @@ def _resolve_handler_node(
node = asgi_handlers[ScopeType.WEBSOCKET]
return node["asgi_app"], node["handler"]

async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: # type: ignore[override]
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""The main entry point to the Router class."""
try:
asgi_handlers, is_asgi = self._parse_scope_to_route(scope=scope)
Expand All @@ -218,6 +226,43 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No
scope["route_handler"] = handler
await asgi_app(scope, receive, send)

async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None:
"""Handles the ASGI "lifespan" event on application startup and
shutdown.
Args:
receive: The ASGI receive function.
send: The ASGI send function.
Returns:
None.
"""
message = await receive()
try:
if message["type"] == "lifespan.startup":
await self.startup()
startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"}
await send(startup_event)
await receive()
except BaseException as e:
if message["type"] == "lifespan.startup":
startup_failure_event: "LifeSpanStartupFailedEvent" = {
"type": "lifespan.startup.failed",
"message": format_exc(),
}
await send(startup_failure_event)
else:
shutdown_failure_event: "LifeSpanShutdownFailedEvent" = {
"type": "lifespan.shutdown.failed",
"message": format_exc(),
}
await send(shutdown_failure_event)
raise e
else:
await self.shutdown()
shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"}
await send(shutdown_event)

async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None:
"""Determines whether the lifecycle handler expects an argument, and if
so passes the `app.state` to it. If the handler is an async function,
Expand All @@ -244,7 +289,7 @@ async def startup(self) -> None:
for hook in self.app.before_startup:
await hook(self.app)

for handler in self.on_startup:
for handler in self.app.on_startup:
await self._call_lifespan_handler(handler)

for hook in self.app.after_startup:
Expand All @@ -262,7 +307,7 @@ async def shutdown(self) -> None:
for hook in self.app.before_shutdown:
await hook(self.app)

for handler in self.on_shutdown:
for handler in self.app.on_shutdown:
await self._call_lifespan_handler(handler)

for hook in self.app.after_shutdown:
Expand Down
2 changes: 1 addition & 1 deletion starlite/cache/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from asyncio import Lock
from typing import TYPE_CHECKING, Any, Optional, overload

from anyio import Lock
from typing_extensions import Protocol, runtime_checkable

from starlite.utils import is_async_callable
Expand Down
3 changes: 2 additions & 1 deletion starlite/cache/simple_cache_backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from asyncio import Lock
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict

from anyio import Lock

from starlite.cache.base import CacheBackendProtocol


Expand Down
5 changes: 2 additions & 3 deletions starlite/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
)

from starlette.datastructures import URL, Address, Headers, URLPath
from starlette.requests import cookie_parser

from starlite.datastructures.state import State
from starlite.exceptions import ImproperlyConfiguredException
from starlite.parsers import parse_query_params
from starlite.parsers import parse_cookie_string, parse_query_params
from starlite.types.empty import Empty

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,7 +182,7 @@ def cookies(self) -> Dict[str, str]:
cookies: Dict[str, str] = {}
cookie_header = self.headers.get("cookie")
if cookie_header:
cookies = cookie_parser(cookie_header)
cookies = parse_cookie_string(cookie_header)
self._cookies = self.scope["_cookies"] = cookies # type: ignore[typeddict-item]
return cast("Dict[str, str]", self._cookies)

Expand Down
2 changes: 1 addition & 1 deletion starlite/connection/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def receive_data(self, mode: "Literal['binary', 'text']") -> Union[str, by
if event["type"] == "websocket.disconnect":
raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
if self.connection_state == "disconnect":
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
return event.get("text") or "" if mode == "text" else event.get("bytes") or b""

async def receive_text(self) -> str:
Expand Down
6 changes: 2 additions & 4 deletions starlite/exceptions/http_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union

from starlette.exceptions import HTTPException as StarletteHTTPException

from starlite.exceptions.base_exceptions import StarLiteException
from starlite.status_codes import (
HTTP_400_BAD_REQUEST,
Expand All @@ -16,7 +14,7 @@
)


class HTTPException(StarletteHTTPException, StarLiteException):
class HTTPException(StarLiteException):
status_code: int = HTTP_500_INTERNAL_SERVER_ERROR
"""Exception status code."""
detail: str
Expand Down Expand Up @@ -46,7 +44,7 @@ def __init__(
extra: An extra mapping to attach to the exception.
"""

super().__init__(status_code or self.status_code)
self.status_code = status_code or self.status_code

if not detail:
detail = args[0] if args else HTTPStatus(self.status_code).phrase
Expand Down
10 changes: 3 additions & 7 deletions starlite/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Any

from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.errors import ServerErrorMiddleware

from starlite.connection import Request
Expand Down Expand Up @@ -61,18 +60,15 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No
if isinstance(e, WebSocketException):
code = e.code
reason = e.detail
elif isinstance(e, StarletteHTTPException):
code = e.status_code + 4000
reason = e.detail
else:
code = HTTP_500_INTERNAL_SERVER_ERROR + 4000
reason = repr(e)
code = 4000 + getattr(e, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
reason = getattr(e, "detail", repr(e))
event: "WebSocketCloseEvent" = {"type": "websocket.close", "code": code, "reason": reason}
await send(event)

def default_http_exception_handler(self, request: Request, exc: Exception) -> "Response[Any]":
"""Default handler for exceptions subclassed from HTTPException."""
status_code = exc.status_code if isinstance(exc, StarletteHTTPException) else HTTP_500_INTERNAL_SERVER_ERROR
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self.debug:
# in debug mode, we just use the serve_middleware to create an HTML formatted response for us
server_middleware = ServerErrorMiddleware(app=self) # type: ignore[arg-type]
Expand Down
2 changes: 1 addition & 1 deletion starlite/openapi/path_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from pydantic_openapi_schema.v3_1_0.operation import Operation
from pydantic_openapi_schema.v3_1_0.path_item import PathItem
from starlette.routing import get_name

from starlite.openapi.parameters import create_parameter_for_handler
from starlite.openapi.request_body import create_request_body
from starlite.openapi.responses import create_responses
from starlite.utils import get_name

if TYPE_CHECKING:
from pydantic import BaseModel
Expand Down
5 changes: 2 additions & 3 deletions starlite/openapi/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
MediaType as OpenAPISchemaMediaType,
)
from pydantic_openapi_schema.v3_1_0.schema import Schema
from starlette.routing import get_name
from typing_extensions import get_args, get_origin

from starlite.datastructures.response_containers import File, Redirect, Stream, Template
Expand All @@ -22,7 +21,7 @@
from starlite.openapi.schema import create_schema
from starlite.openapi.utils import pascal_case_to_text
from starlite.response import Response as StarliteResponse
from starlite.utils.model import create_parsed_model_field
from starlite.utils import create_parsed_model_field, get_enum_string_value, get_name

if TYPE_CHECKING:

Expand Down Expand Up @@ -68,7 +67,7 @@ def create_success_response(
return_annotation = signature.return_annotation
if signature.return_annotation is Template:
return_annotation = str # since templates return str
route_handler.media_type = MediaType.HTML
route_handler.media_type = get_enum_string_value(MediaType.HTML)
elif get_origin(signature.return_annotation) is StarliteResponse:
return_annotation = get_args(signature.return_annotation)[0] or Any
as_parsed_model_field = create_parsed_model_field(return_annotation)
Expand Down
19 changes: 18 additions & 1 deletion starlite/parsers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from contextlib import suppress
from functools import reduce
from http.cookies import _unquote as unquote_cookie
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from urllib.parse import parse_qsl
from urllib.parse import parse_qsl, unquote

from orjson import JSONDecodeError, loads
from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON
Expand Down Expand Up @@ -78,3 +79,19 @@ def parse_form_data(media_type: "RequestEncodingType", form_data: "FormMultiDict
if field.shape is SHAPE_SINGLETON and field.type_ in (UploadFile, MultipartUploadFile) and values_dict:
return list(values_dict.values())[0]
return values_dict


def parse_cookie_string(cookie_string: str) -> Dict[str, str]:
"""
Parses a cookie string into a dictionary of values.
Args:
cookie_string: A cookie string.
Returns:
A string keyed dictionary of values
"""
output: Dict[str, str] = {}
cookies = [cookie.split("=", 1) if "=" in cookie else ("", cookie) for cookie in cookie_string.split(";")]
for k, v in filter(lambda x: x[0] or x[1], ((k.strip(), v.strip()) for k, v in cookies)):
output[k] = unquote(unquote_cookie(v))
return output
3 changes: 2 additions & 1 deletion starlite/response/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
HTTP_204_NO_CONTENT,
HTTP_304_NOT_MODIFIED,
)
from starlite.utils.helpers import get_enum_string_value
from starlite.utils.serialization import default_serializer

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,7 +83,7 @@ def __init__(
is_head_response: Whether the response should send only the headers ("head" request) or also the content.
"""
self.status_code = status_code
self.media_type = media_type
self.media_type = get_enum_string_value(media_type)
self.background = background
self.headers = headers or {}
self.cookies = cookies or []
Expand Down
3 changes: 1 addition & 2 deletions starlite/routes/asgi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import TYPE_CHECKING, Any, cast

from starlette.routing import get_name

from starlite.connection import ASGIConnection
from starlite.controller import Controller
from starlite.enums import ScopeType
from starlite.routes.base import BaseRoute
from starlite.utils import get_name

if TYPE_CHECKING:
from starlite.handlers.asgi import ASGIRouteHandler
Expand Down
3 changes: 1 addition & 2 deletions starlite/routes/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

from anyio.to_thread import run_sync
from starlette.routing import get_name

from starlite.connection import Request
from starlite.controller import Controller
Expand All @@ -13,7 +12,7 @@
from starlite.response import RedirectResponse
from starlite.routes.base import BaseRoute
from starlite.signature import get_signature_model
from starlite.utils import is_async_callable
from starlite.utils import get_name, is_async_callable

if TYPE_CHECKING:
from starlite.handlers.http import HTTPRouteHandler
Expand Down
3 changes: 1 addition & 2 deletions starlite/routes/websocket.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, cast

from starlette.routing import get_name

from starlite.controller import Controller
from starlite.enums import ScopeType
from starlite.exceptions import ImproperlyConfiguredException
from starlite.routes.base import BaseRoute
from starlite.signature import get_signature_model
from starlite.utils import get_name

if TYPE_CHECKING:
from starlite.connection import WebSocket
Expand Down
Loading

0 comments on commit 28f0540

Please sign in to comment.