From 4d50a2aae73ad95a791945775c9e1770acdb6c20 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sat, 22 Oct 2022 19:58:33 +0200 Subject: [PATCH 1/4] removed starlette dependency for lifespan events --- starlite/app.py | 6 +-- starlite/asgi.py | 65 ++++++++++++++++++++++---- starlite/cache/base.py | 2 +- starlite/cache/simple_cache_backend.py | 3 +- 4 files changed, 61 insertions(+), 15 deletions(-) diff --git a/starlite/app.py b/starlite/app.py index e8c5cb3801..18ded76dc1 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -10,10 +10,10 @@ from typing_extensions import TypedDict from starlite.asgi import ( + ASGIRouter, PathParameterTypePathDesignator, PathParamNode, RouteMapNode, - StarliteASGIRouter, ) from starlite.config import AppConfig, CacheConfig, OpenAPIConfig from starlite.config.logging import get_logger_placeholder @@ -383,7 +383,7 @@ def __init__( self._static_paths.add(static_config.path) self.register(asgi(path=static_config.path, name=static_config.name)(static_config.to_static_files_app())) - self.asgi_router = StarliteASGIRouter(on_shutdown=self.on_shutdown, on_startup=self.on_startup, app=self) + self.asgi_router = ASGIRouter(app=self) self.asgi_handler = self._create_asgi_handler() async def __call__( @@ -406,7 +406,7 @@ async def __call__( """ scope["app"] = self if scope["type"] == "lifespan": - await self.asgi_router.lifespan(scope, receive, send) # type: ignore[arg-type] + await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type] return scope["state"] = {} await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] diff --git a/starlite/asgi.py b/starlite/asgi.py index 95cae97ae1..29731427ec 100644 --- a/starlite/asgi.py +++ b/starlite/asgi.py @@ -2,6 +2,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from pathlib import Path +from traceback import format_exc from typing import ( TYPE_CHECKING, Any, @@ -22,7 +23,6 @@ parse_duration, parse_time, ) -from starlette.routing import Router as StarletteRouter from starlite.enums import ScopeType from starlite.exceptions import ( @@ -38,6 +38,12 @@ from starlite.types import ( ASGIApp, LifeSpanHandler, + LifeSpanReceive, + LifeSpanSend, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownFailedEvent, + LifeSpanStartupCompleteEvent, + LifeSpanStartupFailedEvent, Receive, RouteHandlerType, Scope, @@ -59,18 +65,20 @@ class PathParameterTypePathDesignator: ComponentsSet = Set[Union[str, PathParamPlaceholderType, TerminusNodePlaceholderType]] -class StarliteASGIRouter(StarletteRouter): - """This class extends the Starlette Router class and *is* the ASGI app used - in Starlite.""" +class ASGIRouter: + __slots__ = ("app",) def __init__( self, app: "Starlite", - on_shutdown: List["LifeSpanHandler"], - on_startup: List["LifeSpanHandler"], ) -> None: + """This class is the Starlite ASGI router. It handles both the ASGI + lifespan event and routing connection requests. + + Args: + app: The Starlite app instance + """ self.app = app - super().__init__(on_startup=on_startup, on_shutdown=on_shutdown) def _traverse_route_map(self, path: str, scope: "Scope") -> Tuple[RouteMapNode, List[str]]: """Traverses the application route mapping and retrieves the correct @@ -208,7 +216,7 @@ def _resolve_handler_node( node = asgi_handlers[ScopeType.WEBSOCKET] return node["asgi_app"], node["handler"] - async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: # type: ignore[override] + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: """The main entry point to the Router class.""" try: asgi_handlers, is_asgi = self._parse_scope_to_route(scope=scope) @@ -218,6 +226,43 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No scope["route_handler"] = handler await asgi_app(scope, receive, send) + async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None: + """Handles the ASGI "lifespan" event on application startup and + shutdown. + + Args: + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None. + """ + message = await receive() + try: + if message["type"] == "lifespan.startup": + await self.startup() + startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"} + await send(startup_event) + await receive() + except BaseException as e: + if message["type"] == "lifespan.startup": + startup_failure_event: "LifeSpanStartupFailedEvent" = { + "type": "lifespan.startup.failed", + "message": format_exc(), + } + await send(startup_failure_event) + else: + shutdown_failure_event: "LifeSpanShutdownFailedEvent" = { + "type": "lifespan.shutdown.failed", + "message": format_exc(), + } + await send(shutdown_failure_event) + raise e + else: + await self.shutdown() + shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"} + await send(shutdown_event) + async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None: """Determines whether the lifecycle handler expects an argument, and if so passes the `app.state` to it. If the handler is an async function, @@ -244,7 +289,7 @@ async def startup(self) -> None: for hook in self.app.before_startup: await hook(self.app) - for handler in self.on_startup: + for handler in self.app.on_startup: await self._call_lifespan_handler(handler) for hook in self.app.after_startup: @@ -262,7 +307,7 @@ async def shutdown(self) -> None: for hook in self.app.before_shutdown: await hook(self.app) - for handler in self.on_shutdown: + for handler in self.app.on_shutdown: await self._call_lifespan_handler(handler) for hook in self.app.after_shutdown: diff --git a/starlite/cache/base.py b/starlite/cache/base.py index 622f32349e..d9faa7de3c 100644 --- a/starlite/cache/base.py +++ b/starlite/cache/base.py @@ -1,6 +1,6 @@ -from asyncio import Lock from typing import TYPE_CHECKING, Any, Optional, overload +from anyio import Lock from typing_extensions import Protocol, runtime_checkable from starlite.utils import is_async_callable diff --git a/starlite/cache/simple_cache_backend.py b/starlite/cache/simple_cache_backend.py index 1db8d2ad25..672f9bff76 100644 --- a/starlite/cache/simple_cache_backend.py +++ b/starlite/cache/simple_cache_backend.py @@ -1,8 +1,9 @@ -from asyncio import Lock from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Dict +from anyio import Lock + from starlite.cache.base import CacheBackendProtocol From 6b2fcc211c24dc71c82e87695c17e45c282c9917 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Oct 2022 09:32:17 +0200 Subject: [PATCH 2/4] updated tests --- starlite/connection/websocket.py | 2 +- starlite/utils/sync.py | 6 ++--- tests/asgi_router/test_asgi_router.py | 39 +++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 tests/asgi_router/test_asgi_router.py diff --git a/starlite/connection/websocket.py b/starlite/connection/websocket.py index 04699645cd..d9b20eeba4 100644 --- a/starlite/connection/websocket.py +++ b/starlite/connection/websocket.py @@ -200,7 +200,7 @@ async def receive_data(self, mode: "Literal['binary', 'text']") -> Union[str, by if event["type"] == "websocket.disconnect": raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) if self.connection_state == "disconnect": - raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" async def receive_text(self) -> str: diff --git a/starlite/utils/sync.py b/starlite/utils/sync.py index 26961df671..0c8e621e50 100644 --- a/starlite/utils/sync.py +++ b/starlite/utils/sync.py @@ -1,5 +1,5 @@ from functools import partial -from inspect import getfullargspec, ismethod +from inspect import getfullargspec, isclass, ismethod from typing import ( Any, AsyncGenerator, @@ -34,10 +34,10 @@ def __init__(self, fn: Callable[P, T]) -> None: fn: Callable to wrap - can be any sync or async callable. """ - self.is_method = ismethod(fn) + self.is_method = ismethod(fn) or isclass(type(fn)) self.num_expected_args = len(getfullargspec(fn).args) - (1 if self.is_method else 0) self.wrapped_callable: Dict[Literal["fn"], Callable] = { - "fn": fn if is_async_callable(fn) else async_partial(fn) + "fn": fn if is_async_callable(fn) else async_partial(fn) # pyright: ignore } async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: diff --git a/tests/asgi_router/test_asgi_router.py b/tests/asgi_router/test_asgi_router.py new file mode 100644 index 0000000000..9e59f57a03 --- /dev/null +++ b/tests/asgi_router/test_asgi_router.py @@ -0,0 +1,39 @@ +import pytest + +from starlite.testing import create_test_client + + +class _LifeSpanCallable: + def __init__(self, should_raise: bool = False) -> None: + self.called = False + self.should_raise = should_raise + + def __call__(self) -> None: + self.called = True + if self.should_raise: + raise RuntimeError("damn") + + +def test_life_span_startup() -> None: + life_span_callable = _LifeSpanCallable() + with create_test_client([], on_startup=[life_span_callable]): + assert life_span_callable.called + + +def test_life_span_startup_error_handling() -> None: + life_span_callable = _LifeSpanCallable(should_raise=True) + with pytest.raises(RuntimeError), create_test_client([], on_startup=[life_span_callable]): + pass + + +def test_life_span_shutdown() -> None: + life_span_callable = _LifeSpanCallable() + with create_test_client([], on_shutdown=[life_span_callable]): + pass + assert life_span_callable.called + + +def test_life_span_shutdown_error_handling() -> None: + life_span_callable = _LifeSpanCallable(should_raise=True) + with pytest.raises(RuntimeError), create_test_client([], on_shutdown=[life_span_callable]): + pass From 39162293c2484188e09a30eff3cd3339b0eea6b9 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Oct 2022 10:37:09 +0200 Subject: [PATCH 3/4] fix async callable --- starlite/utils/sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlite/utils/sync.py b/starlite/utils/sync.py index 0c8e621e50..a2b8cf6636 100644 --- a/starlite/utils/sync.py +++ b/starlite/utils/sync.py @@ -1,5 +1,5 @@ from functools import partial -from inspect import getfullargspec, isclass, ismethod +from inspect import getfullargspec, ismethod from typing import ( Any, AsyncGenerator, @@ -34,7 +34,7 @@ def __init__(self, fn: Callable[P, T]) -> None: fn: Callable to wrap - can be any sync or async callable. """ - self.is_method = ismethod(fn) or isclass(type(fn)) + self.is_method = ismethod(fn) or (callable(fn) and ismethod(fn.__call__)) # type: ignore self.num_expected_args = len(getfullargspec(fn).args) - (1 if self.is_method else 0) self.wrapped_callable: Dict[Literal["fn"], Callable] = { "fn": fn if is_async_callable(fn) else async_partial(fn) # pyright: ignore From 62c0b69aaa6897ef7103ce43fefced618b433bf6 Mon Sep 17 00:00:00 2001 From: Na'aman Hirschfeld Date: Sun, 23 Oct 2022 13:13:17 +0200 Subject: [PATCH 4/4] added cookie parser (#645) * added cookie parser * address review comments * Replace get name (#646) * add get_name helper * added helper to handle enum * update enum safeguard * Removed starlette HTTPException as code dependency (#647) * removed starlette HTTPException as code dependency * address review comments --- starlite/connection/base.py | 5 ++- starlite/exceptions/http_exceptions.py | 6 ++-- starlite/middleware/exceptions.py | 10 ++---- starlite/openapi/path_item.py | 2 +- starlite/openapi/responses.py | 5 ++- starlite/parsers.py | 19 +++++++++- starlite/response/base.py | 3 +- starlite/routes/asgi.py | 3 +- starlite/routes/http.py | 3 +- starlite/routes/websocket.py | 3 +- starlite/utils/__init__.py | 7 ++-- starlite/utils/exception.py | 48 +++++++++++++++----------- starlite/utils/extractors.py | 4 +-- starlite/utils/helpers.py | 31 +++++++++++++++++ tests/test_parsers.py | 29 ++++++++++++++-- 15 files changed, 125 insertions(+), 53 deletions(-) create mode 100644 starlite/utils/helpers.py diff --git a/starlite/connection/base.py b/starlite/connection/base.py index 4f953e6882..e7523a2dd5 100644 --- a/starlite/connection/base.py +++ b/starlite/connection/base.py @@ -11,11 +11,10 @@ ) from starlette.datastructures import URL, Address, Headers, URLPath -from starlette.requests import cookie_parser from starlite.datastructures.state import State from starlite.exceptions import ImproperlyConfiguredException -from starlite.parsers import parse_query_params +from starlite.parsers import parse_cookie_string, parse_query_params from starlite.types.empty import Empty if TYPE_CHECKING: @@ -183,7 +182,7 @@ def cookies(self) -> Dict[str, str]: cookies: Dict[str, str] = {} cookie_header = self.headers.get("cookie") if cookie_header: - cookies = cookie_parser(cookie_header) + cookies = parse_cookie_string(cookie_header) self._cookies = self.scope["_cookies"] = cookies # type: ignore[typeddict-item] return cast("Dict[str, str]", self._cookies) diff --git a/starlite/exceptions/http_exceptions.py b/starlite/exceptions/http_exceptions.py index 2fd8fa85a2..ca2fc9b7ba 100644 --- a/starlite/exceptions/http_exceptions.py +++ b/starlite/exceptions/http_exceptions.py @@ -1,8 +1,6 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Union -from starlette.exceptions import HTTPException as StarletteHTTPException - from starlite.exceptions.base_exceptions import StarLiteException from starlite.status_codes import ( HTTP_400_BAD_REQUEST, @@ -16,7 +14,7 @@ ) -class HTTPException(StarletteHTTPException, StarLiteException): +class HTTPException(StarLiteException): status_code: int = HTTP_500_INTERNAL_SERVER_ERROR """Exception status code.""" detail: str @@ -46,7 +44,7 @@ def __init__( extra: An extra mapping to attach to the exception. """ - super().__init__(status_code or self.status_code) + self.status_code = status_code or self.status_code if not detail: detail = args[0] if args else HTTPStatus(self.status_code).phrase diff --git a/starlite/middleware/exceptions.py b/starlite/middleware/exceptions.py index fc78e9751e..0613a3408e 100644 --- a/starlite/middleware/exceptions.py +++ b/starlite/middleware/exceptions.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Any -from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.errors import ServerErrorMiddleware from starlite.connection import Request @@ -61,18 +60,15 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No if isinstance(e, WebSocketException): code = e.code reason = e.detail - elif isinstance(e, StarletteHTTPException): - code = e.status_code + 4000 - reason = e.detail else: - code = HTTP_500_INTERNAL_SERVER_ERROR + 4000 - reason = repr(e) + code = 4000 + getattr(e, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) + reason = getattr(e, "detail", repr(e)) event: "WebSocketCloseEvent" = {"type": "websocket.close", "code": code, "reason": reason} await send(event) def default_http_exception_handler(self, request: Request, exc: Exception) -> "Response[Any]": """Default handler for exceptions subclassed from HTTPException.""" - status_code = exc.status_code if isinstance(exc, StarletteHTTPException) else HTTP_500_INTERNAL_SERVER_ERROR + status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self.debug: # in debug mode, we just use the serve_middleware to create an HTML formatted response for us server_middleware = ServerErrorMiddleware(app=self) # type: ignore[arg-type] diff --git a/starlite/openapi/path_item.py b/starlite/openapi/path_item.py index 3f7093e63c..0bc46076bb 100644 --- a/starlite/openapi/path_item.py +++ b/starlite/openapi/path_item.py @@ -2,11 +2,11 @@ from pydantic_openapi_schema.v3_1_0.operation import Operation from pydantic_openapi_schema.v3_1_0.path_item import PathItem -from starlette.routing import get_name from starlite.openapi.parameters import create_parameter_for_handler from starlite.openapi.request_body import create_request_body from starlite.openapi.responses import create_responses +from starlite.utils import get_name if TYPE_CHECKING: from pydantic import BaseModel diff --git a/starlite/openapi/responses.py b/starlite/openapi/responses.py index 1fde906e08..548e001488 100644 --- a/starlite/openapi/responses.py +++ b/starlite/openapi/responses.py @@ -8,7 +8,6 @@ MediaType as OpenAPISchemaMediaType, ) from pydantic_openapi_schema.v3_1_0.schema import Schema -from starlette.routing import get_name from typing_extensions import get_args, get_origin from starlite.datastructures.response_containers import File, Redirect, Stream, Template @@ -22,7 +21,7 @@ from starlite.openapi.schema import create_schema from starlite.openapi.utils import pascal_case_to_text from starlite.response import Response as StarliteResponse -from starlite.utils.model import create_parsed_model_field +from starlite.utils import create_parsed_model_field, get_enum_string_value, get_name if TYPE_CHECKING: @@ -68,7 +67,7 @@ def create_success_response( return_annotation = signature.return_annotation if signature.return_annotation is Template: return_annotation = str # since templates return str - route_handler.media_type = MediaType.HTML + route_handler.media_type = get_enum_string_value(MediaType.HTML) elif get_origin(signature.return_annotation) is StarliteResponse: return_annotation = get_args(signature.return_annotation)[0] or Any as_parsed_model_field = create_parsed_model_field(return_annotation) diff --git a/starlite/parsers.py b/starlite/parsers.py index 6fd019c1d5..1cfd955f49 100644 --- a/starlite/parsers.py +++ b/starlite/parsers.py @@ -1,7 +1,8 @@ from contextlib import suppress from functools import reduce +from http.cookies import _unquote as unquote_cookie from typing import TYPE_CHECKING, Any, Dict, List, Tuple -from urllib.parse import parse_qsl +from urllib.parse import parse_qsl, unquote from orjson import JSONDecodeError, loads from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON @@ -78,3 +79,19 @@ def parse_form_data(media_type: "RequestEncodingType", form_data: "FormMultiDict if field.shape is SHAPE_SINGLETON and field.type_ in (UploadFile, MultipartUploadFile) and values_dict: return list(values_dict.values())[0] return values_dict + + +def parse_cookie_string(cookie_string: str) -> Dict[str, str]: + """ + Parses a cookie string into a dictionary of values. + Args: + cookie_string: A cookie string. + + Returns: + A string keyed dictionary of values + """ + output: Dict[str, str] = {} + cookies = [cookie.split("=", 1) if "=" in cookie else ("", cookie) for cookie in cookie_string.split(";")] + for k, v in filter(lambda x: x[0] or x[1], ((k.strip(), v.strip()) for k, v in cookies)): + output[k] = unquote(unquote_cookie(v)) + return output diff --git a/starlite/response/base.py b/starlite/response/base.py index 52468c4d4e..c17f6c3fb2 100644 --- a/starlite/response/base.py +++ b/starlite/response/base.py @@ -23,6 +23,7 @@ HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED, ) +from starlite.utils.helpers import get_enum_string_value from starlite.utils.serialization import default_serializer if TYPE_CHECKING: @@ -82,7 +83,7 @@ def __init__( is_head_response: Whether the response should send only the headers ("head" request) or also the content. """ self.status_code = status_code - self.media_type = media_type + self.media_type = get_enum_string_value(media_type) self.background = background self.headers = headers or {} self.cookies = cookies or [] diff --git a/starlite/routes/asgi.py b/starlite/routes/asgi.py index 47684a9d21..0bd7d869a7 100644 --- a/starlite/routes/asgi.py +++ b/starlite/routes/asgi.py @@ -1,11 +1,10 @@ from typing import TYPE_CHECKING, Any, cast -from starlette.routing import get_name - from starlite.connection import ASGIConnection from starlite.controller import Controller from starlite.enums import ScopeType from starlite.routes.base import BaseRoute +from starlite.utils import get_name if TYPE_CHECKING: from starlite.handlers.asgi import ASGIRouteHandler diff --git a/starlite/routes/http.py b/starlite/routes/http.py index d8d6554b59..8719bdadf3 100644 --- a/starlite/routes/http.py +++ b/starlite/routes/http.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from anyio.to_thread import run_sync -from starlette.routing import get_name from starlite.connection import Request from starlite.controller import Controller @@ -13,7 +12,7 @@ from starlite.response import RedirectResponse from starlite.routes.base import BaseRoute from starlite.signature import get_signature_model -from starlite.utils import is_async_callable +from starlite.utils import get_name, is_async_callable if TYPE_CHECKING: from starlite.handlers.http import HTTPRouteHandler diff --git a/starlite/routes/websocket.py b/starlite/routes/websocket.py index 3eb477369a..113dcbfe0a 100644 --- a/starlite/routes/websocket.py +++ b/starlite/routes/websocket.py @@ -1,12 +1,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, cast -from starlette.routing import get_name - from starlite.controller import Controller from starlite.enums import ScopeType from starlite.exceptions import ImproperlyConfiguredException from starlite.routes.base import BaseRoute from starlite.signature import get_signature_model +from starlite.utils import get_name if TYPE_CHECKING: from starlite.connection import WebSocket diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index d375cf4222..4a7b0370ef 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -6,6 +6,7 @@ get_exception_handler, ) from .extractors import ConnectionDataExtractor, ResponseDataExtractor, obfuscate +from .helpers import get_enum_string_value, get_name from .model import ( convert_dataclass_to_model, convert_typeddict_to_model, @@ -38,7 +39,11 @@ "create_parsed_model_field", "default_serializer", "find_index", + "generate_csrf_hash", + "generate_csrf_token", + "get_enum_string_value", "get_exception_handler", + "get_name", "get_serializer_from_scope", "is_async_callable", "is_class_and_subclass", @@ -52,6 +57,4 @@ "obfuscate", "should_skip_dependency_validation", "unique", - "generate_csrf_hash", - "generate_csrf_token", ) diff --git a/starlite/utils/exception.py b/starlite/utils/exception.py index db447305b2..34183e2503 100644 --- a/starlite/utils/exception.py +++ b/starlite/utils/exception.py @@ -2,10 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from pydantic import BaseModel -from starlette.exceptions import HTTPException as StarletteHTTPException from starlite.enums import MediaType -from starlite.exceptions.http_exceptions import HTTPException from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR if TYPE_CHECKING: @@ -36,15 +34,13 @@ def get_exception_handler(exception_handlers: "ExceptionHandlersMap", exc: Excep """ if not exception_handlers: return None - if isinstance(exc, (StarletteHTTPException, HTTPException)) and exc.status_code in exception_handlers: - return exception_handlers[exc.status_code] + status_code: Optional[int] = getattr(exc, "status_code", None) + if status_code in exception_handlers: + return exception_handlers[status_code] for cls in getmro(type(exc)): if cls in exception_handlers: return exception_handlers[cast("Type[Exception]", cls)] - if ( - not isinstance(exc, (StarletteHTTPException, HTTPException)) - and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers - ): + if not hasattr(exc, "status_code") and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers: return exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR] return None @@ -59,6 +55,23 @@ class ExceptionResponseContent(BaseModel): extra: Optional[Union[Dict[str, Any], List[Any]]] = None """An extra mapping to attach to the exception.""" + def to_response(self) -> "Response": + """Creates a response from the model attributes. + + Returns: + A response instance. + """ + from starlite.response import ( # pylint: disable=import-outside-toplevel + Response, + ) + + return Response( + content=self.dict(exclude_none=True, exclude={"headers"}), + headers=self.headers, + media_type=MediaType.JSON, + status_code=self.status_code, + ) + def create_exception_response(exc: Exception) -> "Response": """Constructs a response from an exception. @@ -72,17 +85,10 @@ def create_exception_response(exc: Exception) -> "Response": Returns: Response: HTTP response constructed from exception details. """ - from starlite.response import Response # pylint: disable=import-outside-toplevel - - if isinstance(exc, (HTTPException, StarletteHTTPException)): - content = ExceptionResponseContent(detail=exc.detail, status_code=exc.status_code) - if isinstance(exc, HTTPException): - content.extra = exc.extra - else: - content = ExceptionResponseContent(detail=repr(exc), status_code=HTTP_500_INTERNAL_SERVER_ERROR) - return Response( - media_type=MediaType.JSON, - content=content.dict(exclude_none=True), - status_code=content.status_code, - headers=exc.headers if isinstance(exc, (HTTPException, StarletteHTTPException)) else None, + content = ExceptionResponseContent( + status_code=getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR), + detail=getattr(exc, "detail", repr(exc)), + headers=getattr(exc, "headers", None), + extra=getattr(exc, "extra", None), ) + return content.to_response() diff --git a/starlite/utils/extractors.py b/starlite/utils/extractors.py index ce16e21390..eb5fc32187 100644 --- a/starlite/utils/extractors.py +++ b/starlite/utils/extractors.py @@ -11,12 +11,12 @@ cast, ) -from starlette.requests import cookie_parser from typing_extensions import Literal, TypedDict from starlite.connection import Request from starlite.datastructures.upload_file import UploadFile from starlite.enums import HttpMethod, RequestEncodingType +from starlite.parsers import parse_cookie_string if TYPE_CHECKING: from starlite.connection import ASGIConnection @@ -410,6 +410,6 @@ def extract_cookies(self, messages: Tuple["HTTPResponseStartEvent", "HTTPRespons ) ) if cookie_string: - parsed_cookies = cookie_parser(cookie_string) + parsed_cookies = parse_cookie_string(cookie_string) return obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies return {} diff --git a/starlite/utils/helpers.py b/starlite/utils/helpers.py new file mode 100644 index 0000000000..55feb5b464 --- /dev/null +++ b/starlite/utils/helpers.py @@ -0,0 +1,31 @@ +from enum import Enum +from typing import Any, Union, cast + + +def get_name(value: Any) -> str: + """Helper to get the '__name__' dunder of a value. + + Args: + value: An arbitrary value. + + Returns: + A name string. + """ + + if hasattr(value, "__name__"): + return cast("str", value.__name__) + return type(value).__name__ + + +def get_enum_string_value(value: Union[Enum, str]) -> str: + """A helper function to return the string value of a string enum. + + See: https://github.com/starlite-api/starlite/pull/633#issuecomment-1286519267 + + Args: + value: An enum or string. + + Returns: + A string. + """ + return cast("str", value.value) if isinstance(value, Enum) else value diff --git a/tests/test_parsers.py b/tests/test_parsers.py index eadce7d5a5..d98ca9a235 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -1,9 +1,12 @@ +from typing import Dict + +import pytest from pydantic import BaseConfig from pydantic.fields import ModelField -from starlite import RequestEncodingType +from starlite import Cookie, RequestEncodingType from starlite.datastructures import FormMultiDict -from starlite.parsers import parse_form_data, parse_query_params +from starlite.parsers import parse_cookie_string, parse_form_data, parse_query_params from starlite.testing import RequestFactory @@ -51,3 +54,25 @@ def test_parse_form_data() -> None: "healthy": True, "polluting": False, } + + +@pytest.mark.parametrize( + "cookie_string, expected", + ( + ("ABC = 123; efg = 456", {"ABC": "123", "efg": "456"}), + (("foo= ; bar="), {"foo": "", "bar": ""}), + ('foo="bar=123456789&name=moisheZuchmir"', {"foo": "bar=123456789&name=moisheZuchmir"}), + ("email=%20%22%2c%3b%2f", {"email": ' ",;/'}), + ("foo=%1;bar=bar", {"foo": "%1", "bar": "bar"}), + ("foo=bar;fizz ; buzz", {"": "buzz", "foo": "bar"}), + (" fizz; foo= bar", {"": "fizz", "foo": "bar"}), + ("foo=false;bar=bar;foo=true", {"bar": "bar", "foo": "true"}), + ("foo=;bar=bar;foo=boo", {"bar": "bar", "foo": "boo"}), + ( + Cookie(key="abc", value="123", path="/head", domain="localhost").to_header(header=""), + {"Domain": "localhost", "Path": "/head", "SameSite": "lax", "abc": "123"}, + ), + ), +) +def test_parse_cookie_string(cookie_string: str, expected: Dict[str, str]) -> None: + assert parse_cookie_string(cookie_string) == expected