diff --git a/.flake8 b/.flake8 index 47b222ae75..7e6c6e1e06 100644 --- a/.flake8 +++ b/.flake8 @@ -1,7 +1,7 @@ [flake8] max-line-length = 120 max-complexity = 12 -ignore = E501, B008, W503, C408, B009, B023, C417, PT006, PT007, PT004, PT012, SIM401, E225, E203, SCS108 +ignore = E501,E225,W503,B008,E203 type-checking-pydantic-enabled = true type-checking-fastapi-enabled = true classmethod-decorators = @@ -9,5 +9,6 @@ classmethod-decorators = validator root_validator per-file-ignores = - examples/dependency_injection/dependency_non_optional_not_provided.py:E800 - starlite/types/builtin_types.py:E800,F401 + examples/*:SCS108 + starlite/types/builtin_types.py:F401 + tests/*:SCS108 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 757bbfd360..c396a43605 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: hooks: - id: blacken-docs - repo: https://github.com/pre-commit/mirrors-prettier - rev: "v3.0.0-alpha.2" + rev: "v3.0.0-alpha.3" hooks: - id: prettier exclude: docs @@ -70,12 +70,10 @@ repos: "flake8-print", "flake8-simplify", "flake8-type-checking", - "flake8-pytest-style", "flake8-implicit-str-concat", "flake8-noqa", "flake8-return", "flake8-secure-coding-standard", - "flake8-eradicate", "flake8-encodings", "flake8-use-fstring", "flake8-use-pathlib", @@ -85,22 +83,39 @@ repos: hooks: - id: flake8-markdown - repo: https://github.com/dosisod/refurb - rev: v1.3.0 + rev: v1.4.0 hooks: - id: refurb args: [--ignore, "120", --ignore, "128"] additional_dependencies: [ + aiomcache, + beautifulsoup4, + brotli, + cryptography, + freezegun, + httpx, + hypothesis, + mako, orjson, piccolo, picologging, pydantic, pydantic_factories, pydantic_openapi_schema, + pytest, + pyyaml, + redis, sqlalchemy, + sqlalchemy2-stubs, starlette, starlite_multipart, structlog, + tortoise-orm, + types-PyYAML, + types-beautifulsoup4, + types-freezegun, + types-redis, ] - repo: https://github.com/ariebovenberg/slotscheck rev: v0.14.1 @@ -131,7 +146,7 @@ repos: aiomcache, ] - repo: https://github.com/pycqa/pylint - rev: "v2.15.4" + rev: "v2.15.5" hooks: - id: pylint exclude: "test_*" @@ -155,37 +170,50 @@ repos: rev: "v0.982" hooks: - id: mypy - exclude: "test_starlette_tests" + args: [--show-traceback, --pdb] additional_dependencies: [ + aiomcache, + beautifulsoup4, + brotli, + cryptography, + freezegun, + httpx, + hypothesis, + mako, orjson, piccolo, picologging, pydantic, pydantic_factories, pydantic_openapi_schema, + pytest, + pyyaml, + redis, sqlalchemy, sqlalchemy2-stubs, starlette, starlite_multipart, structlog, + tortoise-orm, types-PyYAML, + types-beautifulsoup4, types-freezegun, types-redis, - redis, ] - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.275 + rev: v1.1.276 hooks: - id: pyright - exclude: "tests" additional_dependencies: [ + aiomcache, + beautifulsoup4, brotli, cryptography, freezegun, + httpx, hypothesis, - jinja2, mako, orjson, piccolo, @@ -196,11 +224,14 @@ repos: pytest, pyyaml, redis, - httpx, sqlalchemy, sqlalchemy2-stubs, starlette, starlite_multipart, structlog, - aiomcache, + tortoise-orm, + types-PyYAML, + types-beautifulsoup4, + types-freezegun, + types-redis, ] diff --git a/docs/reference/4-response.md b/docs/reference/4-response.md deleted file mode 100644 index 4c569db275..0000000000 --- a/docs/reference/4-response.md +++ /dev/null @@ -1,13 +0,0 @@ -# Responses - -::: starlite.response.Response - options: - members: - - __init__ - - serializer - - render - -::: starlite.response.TemplateResponse - options: - members: - - __init__ diff --git a/docs/reference/datastructures/5-response-containers.md b/docs/reference/datastructures/5-response-containers.md index ca917e5e5a..ef1962be19 100644 --- a/docs/reference/datastructures/5-response-containers.md +++ b/docs/reference/datastructures/5-response-containers.md @@ -4,34 +4,58 @@ options: members: - background - - headers - cookies + - encoding + - headers + - media_type - to_response ::: starlite.datastructures.File options: members: - - path + - background + - chunk_size + - content_disposition_type + - cookies + - encoding + - etag - filename + - headers + - media_type + - path - stat_result - to_response + - to_response ::: starlite.datastructures.Redirect options: members: + - background + - cookies + - headers - path - to_response ::: starlite.datastructures.Stream options: members: + - background + - cookies + - encoding + - headers - iterator + - media_type - to_response ::: starlite.datastructures.Template options: members: - - name + - background - context + - cookies + - encoding + - headers + - media_type + - name - to_response - update_context diff --git a/docs/reference/response/0-base.md b/docs/reference/response/0-base.md new file mode 100644 index 0000000000..07d9b4f6d9 --- /dev/null +++ b/docs/reference/response/0-base.md @@ -0,0 +1,18 @@ +# Base HTTP Response + +::: starlite.response.Response + options: + members: + - __init__ + - __call__ + - after_response + - content_length + - delete_cookie + - encoded_headers + - render + - send_body + - serializer + - set_cookie + - set_etag + - set_header + - start_response diff --git a/docs/reference/response/1-streaming.md b/docs/reference/response/1-streaming.md new file mode 100644 index 0000000000..a153a188e1 --- /dev/null +++ b/docs/reference/response/1-streaming.md @@ -0,0 +1,16 @@ +# Streaming Response + +::: starlite.response.StreamingResponse + options: + members: + - __init__ + - __call__ + - after_response + - content_length + - delete_cookie + - encoded_headers + - send_body + - set_cookie + - set_etag + - set_header + - start_response diff --git a/docs/reference/response/2-file.md b/docs/reference/response/2-file.md new file mode 100644 index 0000000000..08f8bbca45 --- /dev/null +++ b/docs/reference/response/2-file.md @@ -0,0 +1,16 @@ +# File Response + +::: starlite.response.FileResponse + options: + members: + - __init__ + - __call__ + - after_response + - content_length + - delete_cookie + - encoded_headers + - send_body + - set_cookie + - set_etag + - set_header + - start_response diff --git a/docs/reference/response/3-template.md b/docs/reference/response/3-template.md new file mode 100644 index 0000000000..24b433ae18 --- /dev/null +++ b/docs/reference/response/3-template.md @@ -0,0 +1,16 @@ +# Template Response + +::: starlite.response.TemplateResponse + options: + members: + - __init__ + - __call__ + - after_response + - content_length + - delete_cookie + - encoded_headers + - send_body + - set_cookie + - set_etag + - set_header + - start_response diff --git a/docs/reference/response/4-redirect.md b/docs/reference/response/4-redirect.md new file mode 100644 index 0000000000..bff093726e --- /dev/null +++ b/docs/reference/response/4-redirect.md @@ -0,0 +1,16 @@ +# Redirect Response + +::: starlite.response.RedirectResponse + options: + members: + - __init__ + - __call__ + - after_response + - content_length + - delete_cookie + - encoded_headers + - send_body + - set_cookie + - set_etag + - set_header + - start_response diff --git a/docs/reference/utils/0-predicate-utils.md b/docs/reference/utils/0-predicate-utils.md index 781e56f86c..d5c31f5576 100644 --- a/docs/reference/utils/0-predicate-utils.md +++ b/docs/reference/utils/0-predicate-utils.md @@ -4,8 +4,6 @@ ::: starlite.utils.predicates.T -::: starlite.utils.predicates.is_async_callable - ::: starlite.utils.predicates.is_class_and_subclass ::: starlite.utils.predicates.is_dataclass_class_or_instance_typeguard diff --git a/docs/reference/utils/1-sync-utils.md b/docs/reference/utils/1-sync-utils.md index 6f0d7d1673..f014f8b43c 100644 --- a/docs/reference/utils/1-sync-utils.md +++ b/docs/reference/utils/1-sync-utils.md @@ -4,6 +4,8 @@ ::: starlite.utils.sync.T +::: starlite.utils.is_async_callable + ::: starlite.utils.AsyncCallable options: members: @@ -13,3 +15,8 @@ ::: starlite.utils.as_async_callable_list ::: starlite.utils.async_partial + +::: starlite.utils.AsyncIteratorWrapper + options: + members: + - __init__ diff --git a/docs/usage/12-openapi/2-route-handler-configuration.md b/docs/usage/12-openapi/2-route-handler-configuration.md index 94f027abff..309de5526c 100644 --- a/docs/usage/12-openapi/2-route-handler-configuration.md +++ b/docs/usage/12-openapi/2-route-handler-configuration.md @@ -34,35 +34,35 @@ You can also modify the generated schema for the route handler using the followi The expected content should be based on a Pydantic model describing its structure. It can also include a description and the expected media type. For example: - ```python - from datetime import datetime - from typing import Optional +```python +from datetime import datetime +from typing import Optional - from pydantic import BaseModel +from pydantic import BaseModel - from starlite import ResponseSpec, get +from starlite import ResponseSpec, get - class Item(BaseModel): - ... +class Item(BaseModel): + ... - class ItemNotFound(BaseModel): - was_removed: bool - removed_at: Optional[datetime] +class ItemNotFound(BaseModel): + was_removed: bool + removed_at: Optional[datetime] - @get( - path="/items/{pk:int}", - responses={ - 404: ResponseSpec( - model=ItemNotFound, description="Item was removed or not found" - ) - }, - ) - def retrieve_item(pk: int) -> Item: - ... - ``` +@get( + path="/items/{pk:int}", + responses={ + 404: ResponseSpec( + model=ItemNotFound, description="Item was removed or not found" + ) + }, +) +def retrieve_item(pk: int) -> Item: + ... +``` You can also specify `security` and `tags` on higher level of the application, e.g. on a controller, router or the app instance itself. For example: diff --git a/docs/usage/14-testing.md b/docs/usage/14-testing.md index d1dbf8e0e3..f68142fdbb 100644 --- a/docs/usage/14-testing.md +++ b/docs/usage/14-testing.md @@ -29,7 +29,7 @@ app = Starlite(route_handlers=[health_check]) We would then test it using the test client like so: ```python title="tests/test_health_check.py" -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient from my_app.main import app @@ -60,7 +60,7 @@ def test_client() -> TestClient: We would then be able to rewrite our test like so: ```python title="tests/test_health_check.py" -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient @@ -174,7 +174,7 @@ here you can also pass individual values. For example, you can do this: ```python title="my_app/tests/test_health_check.py" -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client from my_app.main import health_check @@ -190,7 +190,7 @@ def test_health_check(): But also this: ```python title="my_app/tests/test_health_check.py" -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client from my_app.main import health_check @@ -237,7 +237,7 @@ We could test the `/item` route like so: ```python title="tests/conftest.py" import pytest -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Provide, create_test_client from my_app.main import Service, Item, get_item @@ -269,7 +269,7 @@ from typing import Protocol, runtime_checkable import pytest from pydantic import BaseModel from pydantic_factories import ModelFactory -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Provide, get from starlite.testing import create_test_client diff --git a/docs/usage/16-templating/2-template-functions.md b/docs/usage/16-templating/2-template-functions.md index db6722adeb..4c09265f9a 100644 --- a/docs/usage/16-templating/2-template-functions.md +++ b/docs/usage/16-templating/2-template-functions.md @@ -24,7 +24,7 @@ URLs for static files can be created using the `url_for_static_asset` function. The Starlite [TemplateEngineProtocol][starlite.template.base.TemplateEngineProtocol] specifies the method `register_template_callable` that allows defining a custom callable on a template engine. This method is implemented -for the two built in engine and it can be used to register callables that will be inject on the template. The callable +for the two built in engines, and it can be used to register callables that will be injected into the template. The callable should expect one argument - the context dictionary. It can be any callable - a function, method or class that defines the call method. For example: @@ -40,7 +40,7 @@ def my_template_function(ctx: dict) -> str: template_config.engine.register_template_callable( - "check_context_key", my_template_function + "check_context_key", template_callable=my_template_function ) app = Starlite( diff --git a/docs/usage/17-exceptions.md b/docs/usage/17-exceptions.md index f57122a36b..3e77939169 100644 --- a/docs/usage/17-exceptions.md +++ b/docs/usage/17-exceptions.md @@ -55,7 +55,7 @@ or `exception classes`, to callables. For example, if you would like to replace handler that returns plain-text responses you could do this: ```python -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from starlite import HTTPException, MediaType, Request, Response, Starlite @@ -84,7 +84,7 @@ The above will define a top level exception handler that will apply the `plain_t exceptions that inherit from `HTTPException`. You could of course be more granular: ```python -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from starlite import ValidationException, Request, Response, Starlite @@ -175,10 +175,10 @@ layers for this purpose. ```python import logging -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR -from starlette.responses import Response +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR +from starlite.response import Response from starlite.utils import create_exception_response -from starlite.types import Request +from starlite.connection import Request from starlite import Starlite logger = logging.getLogger(__name__) diff --git a/docs/usage/2-route-handlers/3-asgi-route-handlers.md b/docs/usage/2-route-handlers/3-asgi-route-handlers.md index 2d76a2c1cb..988f5db497 100644 --- a/docs/usage/2-route-handlers/3-asgi-route-handlers.md +++ b/docs/usage/2-route-handlers/3-asgi-route-handlers.md @@ -5,7 +5,7 @@ If you need to write your own ASGI application, you can do so using the `asgi` d ```python from starlite.types import Scope, Receive, Send from starlite.enums import MediaType -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import Response, asgi @@ -32,7 +32,7 @@ the code below is equivalent to the one above: ```python from starlite.types import Scope, Receive, Send from starlite.enums import MediaType -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import ASGIRouteHandler, Response diff --git a/docs/usage/5-responses/10-custom-responses.md b/docs/usage/5-responses/10-custom-responses.md index c6c84528c4..55d78195ab 100644 --- a/docs/usage/5-responses/10-custom-responses.md +++ b/docs/usage/5-responses/10-custom-responses.md @@ -1,9 +1,8 @@ # Custom Responses -You can use a subclass of `starlite.responses.Response` and specify it as the response class using the `response_class` -kwarg. - -For example, lets say we want to handle subclasses of `Document` from the `elasticsearch_dsl` package as shown below: +You can use a subclass of [Response][starlite.response.Response] and specify it as the response class using +the `response_class` kwarg. For example, lets say we want to handle subclasses of `Document` from +the `elasticsearch_dsl` package as shown below: ```python from elasticsearch_dsl import Document, Integer, Keyword @@ -22,8 +21,8 @@ We could of course convert it to a dictionary of values in the route handler, an return `Document` subclasses in many route handlers, it makes sense to create a custom response to handle the serialization. -We will therefore create a subclass of `starlite.response.Response` that implements a serializer method that is capable -of handling `Document` subclasses: +We will therefore create a subclass of [Response][starlite.response.Response] that implements a serializer method that +is capable of handling `Document` subclasses: ```python from typing import Any, Dict diff --git a/docs/usage/5-responses/11-background-tasks.md b/docs/usage/5-responses/11-background-tasks.md new file mode 100644 index 0000000000..27842978a7 --- /dev/null +++ b/docs/usage/5-responses/11-background-tasks.md @@ -0,0 +1,49 @@ +# Background Tasks + +All Starlite responses and response containers (e.g. `File`, `Template` etc.) allow passing in a `background_task` +kwarg. This kwarg accepts either an instance of [BackgroundTask][starlite.datastructures.background_tasks.BackgroundTask] +or +an instance of [BackgroundTasks][starlite.datastructures.background_tasks.BackgroundTasks], which wraps an iterable +of [BackgroundTask][starlite.datastructures.background_tasks.BackgroundTask] instances. + +A background task is a sync or async callable (function, method or class that implements the `__call__` dunder method) +that will be called after the response finishes sending the data. + +Thus, in the following example the passed in background task will be executed after the response sends: + +```python +import logging + +from starlite import BackgroundTask, get + +logger = logging.getLogger(__name__) + + +async def logging_task(identifier: str, message: str) -> None: + logger.info(f"{identifier}: {message}") + + +@get("/", background=BackgroundTask(logging_task, "greeter", message="was called")) +def greeter() -> dict[str, str]: + return {"hello": "world"} +``` + +When the `greeter` handler is called, the logging task will be called with any `*args` and `**kwargs` passed into the +`BackgroundTask`. + +!!! note +In the above example `"greeter"` is an arg and `message="was called"` is a kwarg. The function signature of +`logging_task` allows for this, so this should pose no problem. Starlite uses [ParamSpec][typing.ParamSpec] to ensure +that a [BackgroundTask][starlite.datastructures.background_tasks.BackgroundTask] is properly typed, so will get +type checking for any passed in args and kwargs. + +## Executing Multiple BackgroundTasks + +You can also use the [BackgroundTasks][starlite.datastructures.background_tasks.BackgroundTasks] class instead, and pass +to it an iterable (list, tuple etc.) of [BackgroundTask][starlite.datastructures.background_tasks.BackgroundTask] +instances. This class accepts one optional kwargs aside from the tasks - `run_in_task_group`, which is a boolean flag +that defaults to `False`. If you set this value to `True` than the tasks will run concurrently, using +an [anyio.task_group](https://anyio.readthedocs.io/en/stable/tasks.html). + +!!! note + Setting `run_in_task_group` to `True` will not preserve execution order. diff --git a/docs/usage/5-responses/2-status-codes.md b/docs/usage/5-responses/2-status-codes.md index faa44f9398..44911c5c75 100644 --- a/docs/usage/5-responses/2-status-codes.md +++ b/docs/usage/5-responses/2-status-codes.md @@ -5,7 +5,7 @@ You can control the response `status_code` by setting the corresponding kwarg to ```python from pydantic import BaseModel from starlite import get -from starlette.status import HTTP_202_ACCEPTED +from starlite.status_codes import HTTP_202_ACCEPTED class Resource(BaseModel): @@ -37,6 +37,6 @@ If `status_code` is not set by the user, the following defaults are used: !!! tip While you can write integers as the value for `status_code`, e.g. `200`, it's best practice to use constants (also in - tests). Starlette includes easy to use statuses that are exported from `starlette.status`, e.g. `HTTP_200_OK` + tests). Starlite includes easy to use statuses that are exported from `starlite.status_codes`, e.g. `HTTP_200_OK` and `HTTP_201_CREATED`. Another option is the `http.HTTPStatus`enum from the standard library, which also offers extra functionality. For this see [the standard library documentation](https://docs.python.org/3/library/http.html#http.HTTPStatus). diff --git a/docs/usage/5-responses/3-returning-responses.md b/docs/usage/5-responses/3-returning-responses.md index 11e2dc6520..ef9b9ac944 100644 --- a/docs/usage/5-responses/3-returning-responses.md +++ b/docs/usage/5-responses/3-returning-responses.md @@ -3,7 +3,7 @@ While the default response handling fits most use cases, in some cases you need to be able to return a response instance directly. -Starlite allows you to return any class inheriting from the `starlette.responses.Response` class. Thus, the below +Starlite allows you to return any class inheriting from the [Response][starlite.response.Response] class. Thus, the below example will work perfectly fine: ```python @@ -11,7 +11,7 @@ from pydantic import BaseModel from starlite import Response, get from starlite.datastructures import Cookie from starlite.enums import MediaType -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK class Resource(BaseModel): @@ -33,34 +33,103 @@ def retrieve_resource() -> Response[Resource]: ) ``` -The caveat of using a Starlette response though is that Starlite will not be able to infer the OpenAPI documentation. +!!! important + In the case of the builtin [TemplateResponse][starlite.response.TemplateResponse], + [FileResponse][starlite.response.FileResponse], [StreamingResponse][starlite.response.StreamingResponse] and + [RedirectResponse][starlite.response.RedirectResponse] you should use the response "response containers", otherwise + OpenAPI documentation will not be generated correctly. For more details see the respective documentation sections + for the [Template](9-template-responses.md), [File](7-file-responses.md), [Stream](8-streaming-responses.md) + and [Redirect](6-redirect-responses.md). -## Annotated Responses +## Annotating Responses -To solve this issue, use you can use the `starlite.response.Response` class which supports type annotations: +As you can see above, the [Response][starlite.response.Response] class accepts a generic argument. This allows Starlite +to infer the response body when generating the OpenAPI docs. + +!!! note + If the generic argument is not provided, and thus defaults to `Any`, the OpenAPI docs will be imprecise. So make sure + to type this argument even when returning an empty or `null` body, i.e. use `None`. + +## Returning ASGI Applications + +Starlite also supports returning ASGI applications directly, as you would responses. For example: ```python -from pydantic import BaseModel -from starlite import Response, get -from starlite.datastructures import Cookie +from starlite import get +from starlite.types import ASGIApp, Receive, Scope, Send -class Resource(BaseModel): - id: int - name: str +@get("/") +def handler() -> ASGIApp: + async def my_asgi_app(scope: Scope, receive: Receive, send: Send) -> None: + ... + return my_asgi_app +``` -@get("/resources") -def retrieve_resource() -> Response[Resource]: - return Response( - Resource( - id=1, - name="my resource", - headers={"MY-HEADER": "xyz"}, - cookies=[Cookie("my-cookie", value="abc")], - ) - ) +### What is an ASGI Application? + +An ASGI application in this context is any async callable (function, class method or simply a class that implements +that special `__call__` dunder method) that accepts the three ASGI arguments: `scope`, `receive` and `send`. + +For example, all the following examples are ASGI applications: + +#### Function ASGI Application + +```python +from starlite.types import Receive, Scope, Send + + +async def my_asgi_app_function(scope: Scope, receive: Receive, send: Send) -> None: + # do something here + ... +``` + +#### Method ASGI Application + +```python +from starlite.types import Receive, Scope, Send + + +class MyClass: + async def my_asgi_app_method( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + # do something here + ... +``` + +#### Class ASGI Application + +```python +from starlite.types import Receive, Scope, Send + + +class ASGIApp: + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # do something here + ... +``` + +### Returning Other Library Responses + +Because you can return any ASGI Application from a route handler, you can also use any ASGI application from other +libraries. For example, you can return the response classes from Starlette or FastAPI directly from route handlers: + +```python +from starlette.responses import JSONResponse + +from starlite import get +from starlite.types import ASGIApp + + +@get("/") +def handler() -> ASGIApp: + return JSONResponse(content={"hello": "world"}) # type: ignore ``` -As you can see above, the `starlite.response.Response` class accepts a generic argument - in this case the pydantic -model `Resource`. This allows Starlite to infer from the `Response` type the correct typing for OpenAPI generation. +!!! important + Starlite offers strong typing for the ASGI arguments. Other libraries often offer less strict typing, which might + cause type checkers to complain when using ASGI apps from them inside Starlite. + For the time being, the only solution is to add `# type: ignore` comments in the pertinent places. + Nonetheless, the above example will work perfectly fine. diff --git a/docs/usage/5-responses/4-response-headers.md b/docs/usage/5-responses/4-response-headers.md index a4cfcc39c1..3814363769 100644 --- a/docs/usage/5-responses/4-response-headers.md +++ b/docs/usage/5-responses/4-response-headers.md @@ -77,7 +77,7 @@ as you see fit, e.g.: ```python from pydantic import BaseModel -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Response, get from starlite.datastructures import ResponseHeader from starlite.enums import MediaType @@ -124,7 +124,7 @@ the headers on the corresponding layer: from random import randint from pydantic import BaseModel -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Response, Router, get from starlite.datastructures import ResponseHeader from starlite.enums import MediaType diff --git a/docs/usage/5-responses/5-response-cookies.md b/docs/usage/5-responses/5-response-cookies.md index 5a9d150d7d..c9f82a8668 100644 --- a/docs/usage/5-responses/5-response-cookies.md +++ b/docs/usage/5-responses/5-response-cookies.md @@ -111,7 +111,7 @@ as you see fit, e.g.: from random import randint from pydantic import BaseModel -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Response, get from starlite.datastructures import Cookie from starlite.enums import MediaType @@ -210,7 +210,7 @@ different value range: from random import randint from pydantic import BaseModel -from starlette.status import HTTP_200_OK +from starlite.status_codes import HTTP_200_OK from starlite import Response, Router, get from starlite.datastructures import Cookie from starlite.enums import MediaType diff --git a/docs/usage/5-responses/6-redirect-responses.md b/docs/usage/5-responses/6-redirect-responses.md index 1ff4f2270b..d9126c22c9 100644 --- a/docs/usage/5-responses/6-redirect-responses.md +++ b/docs/usage/5-responses/6-redirect-responses.md @@ -6,9 +6,8 @@ status code in the 30x range. In Starlite, a redirect response looks like this: ```python -from starlette.status import HTTP_307_TEMPORARY_REDIRECT -from starlite import get -from starlite.datastructures import Redirect +from starlite.status_codes import HTTP_307_TEMPORARY_REDIRECT +from starlite import Redirect, get @get(path="/some-path", status_code=HTTP_307_TEMPORARY_REDIRECT) diff --git a/docs/usage/5-responses/7-file-responses.md b/docs/usage/5-responses/7-file-responses.md index de81b6b7ff..c10f4a35f3 100644 --- a/docs/usage/5-responses/7-file-responses.md +++ b/docs/usage/5-responses/7-file-responses.md @@ -4,8 +4,7 @@ File responses send a file: ```python from pathlib import Path -from starlite import get -from starlite.datastructures import File +from starlite import File, get @get(path="/file-download") @@ -33,8 +32,7 @@ For example: ```python from pathlib import Path -from starlite import get -from starlite.datastructures import File +from starlite import File, get @get(path="/file-download", media_type="application/pdf") diff --git a/docs/usage/5-responses/8-streaming-responses.md b/docs/usage/5-responses/8-streaming-responses.md index d94955efc9..6692e3a326 100644 --- a/docs/usage/5-responses/8-streaming-responses.md +++ b/docs/usage/5-responses/8-streaming-responses.md @@ -5,8 +5,7 @@ To return a streaming response use the `Stream` class. The Stream class receives ```python from typing import AsyncGenerator from asyncio import sleep -from starlite import get -from starlite.datastructures import Stream +from starlite import Stream, get from datetime import datetime from orjson import dumps diff --git a/docs/usage/7-middleware/2-creating-middleware/2-using-middleware-protocol.md b/docs/usage/7-middleware/2-creating-middleware/2-using-middleware-protocol.md index 916761ef0b..6667a95f79 100644 --- a/docs/usage/7-middleware/2-creating-middleware/2-using-middleware-protocol.md +++ b/docs/usage/7-middleware/2-creating-middleware/2-using-middleware-protocol.md @@ -63,8 +63,9 @@ another example - redirecting the request to a different url from a middleware: ```python from starlite.types import ASGIApp, Receive, Scope, Send -from starlette.responses import RedirectResponse -from starlette.status import HTTP_307_TEMPORARY_REDIRECT +from starlite.status_codes import HTTP_307_TEMPORARY_REDIRECT + +from starlite.response import RedirectResponse from starlite import Request from starlite.middleware.base import MiddlewareProtocol diff --git a/docs/usage/7-middleware/3-builtin-middlewares/2-allowed-hosts-middleware.md b/docs/usage/7-middleware/3-builtin-middlewares/2-allowed-hosts-middleware.md index 2e3802d992..9020d74e6d 100644 --- a/docs/usage/7-middleware/3-builtin-middlewares/2-allowed-hosts-middleware.md +++ b/docs/usage/7-middleware/3-builtin-middlewares/2-allowed-hosts-middleware.md @@ -8,7 +8,7 @@ trusted hosts to the Starlite constructor: from starlite import Starlite app = Starlite( - request_handlers=[...], allowed_hosts=["*.example.com", "www.wikipedia.org"] + route_handlers=[...], allowed_hosts=["*.example.com", "www.wikipedia.org"] ) ``` diff --git a/docs/usage/7-middleware/3-builtin-middlewares/4-compression-middleware.md b/docs/usage/7-middleware/3-builtin-middlewares/4-compression-middleware.md index 354059be5f..31b680fe14 100644 --- a/docs/usage/7-middleware/3-builtin-middlewares/4-compression-middleware.md +++ b/docs/usage/7-middleware/3-builtin-middlewares/4-compression-middleware.md @@ -17,7 +17,7 @@ You can configure the following additional gzip-specific values: from starlite import Starlite, CompressionConfig app = Starlite( - request_handlers=[...], + route_handlers=[...], compression_config=CompressionConfig(backend="gzip", gzip_compress_level=9), ) ``` @@ -42,7 +42,7 @@ from starlite import Starlite from starlite.config import CompressionConfig app = Starlite( - request_handlers=[...], + route_handlers=[...], compression_config=CompressionConfig(backend="brotli", brotli_gzip_fallback=True), ) ``` diff --git a/docs/usage/8-authentication/1-abstract-authentication-middleware.md b/docs/usage/8-authentication/1-abstract-authentication-middleware.md index 3a27a0e8a0..ff0ef75321 100644 --- a/docs/usage/8-authentication/1-abstract-authentication-middleware.md +++ b/docs/usage/8-authentication/1-abstract-authentication-middleware.md @@ -229,7 +229,7 @@ async def site_index() -> Response: raise NotFoundException("Site index was not found") -app = Starlite(request_handlers=[site_index], middleware=[auth_mw]) +app = Starlite(route_handlers=[site_index], middleware=[auth_mw]) ``` And of course use the same kind of mechanism for dependencies: diff --git a/examples/application_hooks/after_exception_hook.py b/examples/application_hooks/after_exception_hook.py index 43c474cb98..8bab561b2b 100644 --- a/examples/application_hooks/after_exception_hook.py +++ b/examples/application_hooks/after_exception_hook.py @@ -1,9 +1,8 @@ import logging from typing import TYPE_CHECKING -from starlette.status import HTTP_400_BAD_REQUEST - from starlite import HTTPException, Starlite, get +from starlite.status_codes import HTTP_400_BAD_REQUEST logger = logging.getLogger() diff --git a/examples/tests/application_hooks/test_application_before_send.py b/examples/tests/application_hooks/test_application_before_send.py index 62b2143432..cdefe3b8e9 100644 --- a/examples/tests/application_hooks/test_application_before_send.py +++ b/examples/tests/application_hooks/test_application_before_send.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from examples.application_hooks import before_send_hook +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/examples/tests/dependency_injection/test_dependency_default_value_no_dependency_fn.py b/examples/tests/dependency_injection/test_dependency_default_value_no_dependency_fn.py index 29e8faaf78..439befc7b9 100644 --- a/examples/tests/dependency_injection/test_dependency_default_value_no_dependency_fn.py +++ b/examples/tests/dependency_injection/test_dependency_default_value_no_dependency_fn.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from examples.dependency_injection import dependency_default_value_no_dependency_fn +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/examples/tests/dependency_injection/test_dependency_default_value_with_dependency_fn.py b/examples/tests/dependency_injection/test_dependency_default_value_with_dependency_fn.py index f5e9f4afa4..e5572b65e3 100644 --- a/examples/tests/dependency_injection/test_dependency_default_value_with_dependency_fn.py +++ b/examples/tests/dependency_injection/test_dependency_default_value_with_dependency_fn.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from examples.dependency_injection import dependency_default_value_with_dependency_fn +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/examples/tests/dependency_injection/test_dependency_skip_validation.py b/examples/tests/dependency_injection/test_dependency_skip_validation.py index b03306dde1..fb3bc85b9c 100644 --- a/examples/tests/dependency_injection/test_dependency_skip_validation.py +++ b/examples/tests/dependency_injection/test_dependency_skip_validation.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from examples.dependency_injection import dependency_skip_validation +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/examples/tests/dependency_injection/test_dependency_validation_error.py b/examples/tests/dependency_injection/test_dependency_validation_error.py index 0c5e11993d..58b017bfd2 100644 --- a/examples/tests/dependency_injection/test_dependency_validation_error.py +++ b/examples/tests/dependency_injection/test_dependency_validation_error.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR - from examples.dependency_injection import dependency_validation_error +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from starlite.testing import TestClient diff --git a/examples/tests/middleware/test_session_middleware.py b/examples/tests/middleware/test_session_middleware.py index f79104957a..6ddfd4704a 100644 --- a/examples/tests/middleware/test_session_middleware.py +++ b/examples/tests/middleware/test_session_middleware.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT - from examples.middleware import session_middleware +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite.testing import TestClient diff --git a/examples/tests/test_hello_world.py b/examples/tests/test_hello_world.py index 8483177616..a1afddc211 100644 --- a/examples/tests/test_hello_world.py +++ b/examples/tests/test_hello_world.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from examples import hello_world +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/examples/tests/test_startup_and_shutdown.py b/examples/tests/test_startup_and_shutdown.py index c9f5c169b4..295fdfbba1 100644 --- a/examples/tests/test_startup_and_shutdown.py +++ b/examples/tests/test_startup_and_shutdown.py @@ -6,7 +6,7 @@ from starlite.testing import TestClient if TYPE_CHECKING: - from pytest import MonkeyPatch # noqa:PT013 + from pytest import MonkeyPatch class FakeAsyncEngine: diff --git a/examples/tests/test_using_application_state.py b/examples/tests/test_using_application_state.py index dd780aaf1d..112f2b24a2 100644 --- a/examples/tests/test_using_application_state.py +++ b/examples/tests/test_using_application_state.py @@ -1,9 +1,8 @@ import logging from typing import Any -from starlette.status import HTTP_200_OK - from examples import using_application_state +from starlite.status_codes import HTTP_200_OK from starlite.testing import TestClient diff --git a/mkdocs.yml b/mkdocs.yml index a149a60012..fc7c6f7e46 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -77,7 +77,7 @@ nav: - usage/4-request-data/1-the-body-function.md - usage/4-request-data/2-url-encoded-form-data.md - usage/4-request-data/3-multipart-form-data.md - - Returning Responses: + - Responses: - usage/5-responses/0-responses-intro.md - usage/5-responses/1-media-type.md - usage/5-responses/2-status-codes.md @@ -89,6 +89,7 @@ nav: - usage/5-responses/8-streaming-responses.md - usage/5-responses/9-template-responses.md - usage/5-responses/10-custom-responses.md + - usage/5-responses/11-background-tasks.md - Dependency Injection: - usage/6-dependency-injection/0-dependency-injection-intro.md - usage/6-dependency-injection/1-dependency-kwargs.md @@ -151,7 +152,6 @@ nav: - reference/1-app.md - reference/2-router.md - reference/3-controller.md - - reference/4-response.md - reference/5-dto.md - reference/6-enums.md - Cache: @@ -173,6 +173,12 @@ nav: - reference/connection/0-asgi-connection.md - reference/connection/1-request.md - reference/connection/2-websocket.md + - Responses: + - reference/response/0-base.md + - reference/response/1-streaming.md + - reference/response/2-file.md + - reference/response/3-template.md + - reference/response/4-redirect.md - Datastructures: - reference/datastructures/0-state.md - reference/datastructures/1-cookie.md diff --git a/mypy.ini b/mypy.ini index 9f2d186810..2676a9b92d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -30,6 +30,8 @@ warn_untyped_fields = True [mypy-picologging.*] ignore_missing_imports = True - [mypy-brotli.*] ignore_missing_imports = True + +[mypy-mako.*] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 289f177962..5282e695f2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1119,7 +1119,7 @@ testing = ["httpx", "cryptography"] [metadata] lock-version = "1.1" python-versions = ">=3.7,<4.0" -content-hash = "a2b19b96e3988eae2ccd8b0d7636bd6de3646d4deeb2cadf1c5de01f918af8e0" +content-hash = "583eaaa2535ff838b7cf3936e8d96b92eda7ce445a0fb1a77e6b659168eb820b" [metadata.files] aiomcache = [ diff --git a/pyproject.toml b/pyproject.toml index f75c07df17..8db4563749 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ packages = [ [tool.poetry.dependencies] python = ">=3.7,<4.0" +aiomcache = {version = "*", optional = true} +anyio = ">=3" brotli = { version = "*", optional = true } cryptography = { version = "*", optional = true } httpx = { version = ">=0.22", optional = true } @@ -43,12 +45,11 @@ pydantic = "*" pydantic-factories = "*" pydantic-openapi-schema = "*" pyyaml = "*" +redis = {version = "*", optional = true, extras = ["hiredis"]} starlette = ">=0.21" starlite-multipart = ">=1.2.0" structlog = { version = "*", optional = true } typing-extensions = "*" -redis = {version = "*", optional = true, extras = ["hiredis"]} -aiomcache = {version = "*", optional = true} [tool.poetry.group.dev.dependencies] aiomcache = "*" @@ -102,6 +103,7 @@ multi_line_output = 3 disable = [ "cyclic-import", "duplicate-code", + "fixme", "line-too-long", "missing-class-docstring", "missing-module-docstring", @@ -131,7 +133,7 @@ max-line-length = "120" ignored-argument-names = "args|kwargs|_|__" [tool.pylint.BASIC] -good-names = "_,__,i,e,k,v,fn,get,post,put,patch,delete,route,asgi,websocket,Dependency,Body,Parameter,HandlerType,ScopeType,Auth,User" +good-names = "_,__,i,e,k,v,fn,get,post,put,patch,delete,route,asgi,websocket,Dependency,Body,Parameter,HandlerType,ScopeType,Auth,User,it" no-docstring-rgx="(__.*__|main|test.*|.*test|.*Test|^_.*)$" [tool.pylint.LOGGING] diff --git a/starlite/app.py b/starlite/app.py index 24b7d05c5d..675c5eb2d0 100644 --- a/starlite/app.py +++ b/starlite/app.py @@ -10,10 +10,10 @@ from typing_extensions import TypedDict from starlite.asgi import ( + ASGIRouter, PathParameterTypePathDesignator, PathParamNode, RouteMapNode, - StarliteASGIRouter, ) from starlite.config import AppConfig, CacheConfig, OpenAPIConfig from starlite.config.logging import get_logger_placeholder @@ -394,7 +394,7 @@ def __init__( self._static_paths.add(static_config.path) self.register(asgi(path=static_config.path, name=static_config.name)(static_config.to_static_files_app())) - self.asgi_router = StarliteASGIRouter(on_shutdown=self.on_shutdown, on_startup=self.on_startup, app=self) + self.asgi_router = ASGIRouter(app=self) self.asgi_handler = self._create_asgi_handler() async def __call__( @@ -417,7 +417,7 @@ async def __call__( """ scope["app"] = self if scope["type"] == "lifespan": - await self.asgi_router.lifespan(scope, receive, send) # type: ignore[arg-type] + await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type] return scope["state"] = {} await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] diff --git a/starlite/asgi.py b/starlite/asgi.py index 95cae97ae1..29731427ec 100644 --- a/starlite/asgi.py +++ b/starlite/asgi.py @@ -2,6 +2,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal from pathlib import Path +from traceback import format_exc from typing import ( TYPE_CHECKING, Any, @@ -22,7 +23,6 @@ parse_duration, parse_time, ) -from starlette.routing import Router as StarletteRouter from starlite.enums import ScopeType from starlite.exceptions import ( @@ -38,6 +38,12 @@ from starlite.types import ( ASGIApp, LifeSpanHandler, + LifeSpanReceive, + LifeSpanSend, + LifeSpanShutdownCompleteEvent, + LifeSpanShutdownFailedEvent, + LifeSpanStartupCompleteEvent, + LifeSpanStartupFailedEvent, Receive, RouteHandlerType, Scope, @@ -59,18 +65,20 @@ class PathParameterTypePathDesignator: ComponentsSet = Set[Union[str, PathParamPlaceholderType, TerminusNodePlaceholderType]] -class StarliteASGIRouter(StarletteRouter): - """This class extends the Starlette Router class and *is* the ASGI app used - in Starlite.""" +class ASGIRouter: + __slots__ = ("app",) def __init__( self, app: "Starlite", - on_shutdown: List["LifeSpanHandler"], - on_startup: List["LifeSpanHandler"], ) -> None: + """This class is the Starlite ASGI router. It handles both the ASGI + lifespan event and routing connection requests. + + Args: + app: The Starlite app instance + """ self.app = app - super().__init__(on_startup=on_startup, on_shutdown=on_shutdown) def _traverse_route_map(self, path: str, scope: "Scope") -> Tuple[RouteMapNode, List[str]]: """Traverses the application route mapping and retrieves the correct @@ -208,7 +216,7 @@ def _resolve_handler_node( node = asgi_handlers[ScopeType.WEBSOCKET] return node["asgi_app"], node["handler"] - async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: # type: ignore[override] + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: """The main entry point to the Router class.""" try: asgi_handlers, is_asgi = self._parse_scope_to_route(scope=scope) @@ -218,6 +226,43 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No scope["route_handler"] = handler await asgi_app(scope, receive, send) + async def lifespan(self, receive: "LifeSpanReceive", send: "LifeSpanSend") -> None: + """Handles the ASGI "lifespan" event on application startup and + shutdown. + + Args: + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None. + """ + message = await receive() + try: + if message["type"] == "lifespan.startup": + await self.startup() + startup_event: "LifeSpanStartupCompleteEvent" = {"type": "lifespan.startup.complete"} + await send(startup_event) + await receive() + except BaseException as e: + if message["type"] == "lifespan.startup": + startup_failure_event: "LifeSpanStartupFailedEvent" = { + "type": "lifespan.startup.failed", + "message": format_exc(), + } + await send(startup_failure_event) + else: + shutdown_failure_event: "LifeSpanShutdownFailedEvent" = { + "type": "lifespan.shutdown.failed", + "message": format_exc(), + } + await send(shutdown_failure_event) + raise e + else: + await self.shutdown() + shutdown_event: "LifeSpanShutdownCompleteEvent" = {"type": "lifespan.shutdown.complete"} + await send(shutdown_event) + async def _call_lifespan_handler(self, handler: "LifeSpanHandler") -> None: """Determines whether the lifecycle handler expects an argument, and if so passes the `app.state` to it. If the handler is an async function, @@ -244,7 +289,7 @@ async def startup(self) -> None: for hook in self.app.before_startup: await hook(self.app) - for handler in self.on_startup: + for handler in self.app.on_startup: await self._call_lifespan_handler(handler) for hook in self.app.after_startup: @@ -262,7 +307,7 @@ async def shutdown(self) -> None: for hook in self.app.before_shutdown: await hook(self.app) - for handler in self.on_shutdown: + for handler in self.app.on_shutdown: await self._call_lifespan_handler(handler) for hook in self.app.after_shutdown: diff --git a/starlite/cache/base.py b/starlite/cache/base.py index 622f32349e..d9faa7de3c 100644 --- a/starlite/cache/base.py +++ b/starlite/cache/base.py @@ -1,6 +1,6 @@ -from asyncio import Lock from typing import TYPE_CHECKING, Any, Optional, overload +from anyio import Lock from typing_extensions import Protocol, runtime_checkable from starlite.utils import is_async_callable diff --git a/starlite/cache/memcached_cache_backend.py b/starlite/cache/memcached_cache_backend.py index 28b01bb4fe..45e5917a82 100644 --- a/starlite/cache/memcached_cache_backend.py +++ b/starlite/cache/memcached_cache_backend.py @@ -34,18 +34,19 @@ class MemcachedCacheBackendConfig(BaseModel): class MemcachedCacheBackend(CacheBackendProtocol): - def __init__(self, config: MemcachedCacheBackendConfig): + _client: Client + + def __init__(self, config: MemcachedCacheBackendConfig) -> None: """This class offers a cache backend based on memcached. Args: config: required configuration to connect to memcached. """ self._config = config - self._client: Client = None # pyright: ignore @property def _memcached_client(self) -> Client: - if not self._client: + if not hasattr(self, "_client"): self._client = Client(**self._config.dict(exclude_unset=True)) return self._client @@ -60,7 +61,7 @@ async def get(self, key: str) -> Any: # pylint: disable=invalid-overridden-meth Cached value if existing else `None`. """ - value = await self._memcached_client.get(key=key.encode()) # pyright: ignore + value = await self._memcached_client.get(key=key.encode("utf-8")) # type: ignore return self._config.deserialize(value) async def set(self, key: str, value: Any, expiration: int) -> None: # pylint: disable=invalid-overridden-method diff --git a/starlite/cache/simple_cache_backend.py b/starlite/cache/simple_cache_backend.py index 1db8d2ad25..672f9bff76 100644 --- a/starlite/cache/simple_cache_backend.py +++ b/starlite/cache/simple_cache_backend.py @@ -1,8 +1,9 @@ -from asyncio import Lock from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Dict +from anyio import Lock + from starlite.cache.base import CacheBackendProtocol diff --git a/starlite/connection/__init__.py b/starlite/connection/__init__.py index 5a0a27b214..7834046fee 100644 --- a/starlite/connection/__init__.py +++ b/starlite/connection/__init__.py @@ -1,3 +1,36 @@ +"""Some code in this module was adapted from +https://github.com/encode/starlette/blob/master/starlette/requests.py and +https://github.com/encode/starlette/blob/master/starlette/websockets.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + from starlite.connection.base import ASGIConnection, empty_receive, empty_send from starlite.connection.request import Request from starlite.connection.websocket import WebSocket diff --git a/starlite/connection/base.py b/starlite/connection/base.py index 8ffd144772..e7523a2dd5 100644 --- a/starlite/connection/base.py +++ b/starlite/connection/base.py @@ -11,11 +11,10 @@ ) from starlette.datastructures import URL, Address, Headers, URLPath -from starlette.requests import cookie_parser from starlite.datastructures.state import State from starlite.exceptions import ImproperlyConfiguredException -from starlite.parsers import parse_query_params +from starlite.parsers import parse_cookie_string, parse_query_params from starlite.types.empty import Empty if TYPE_CHECKING: @@ -150,6 +149,7 @@ def headers(self) -> Headers: A Headers instance with the request's scope["headers"] value. """ if self._headers is Empty: + self.scope.setdefault("headers", []) self._headers = self.scope["_headers"] = Headers(scope=self.scope) # type: ignore[typeddict-item] return cast("Headers", self._headers) @@ -182,7 +182,7 @@ def cookies(self) -> Dict[str, str]: cookies: Dict[str, str] = {} cookie_header = self.headers.get("cookie") if cookie_header: - cookies = cookie_parser(cookie_header) + cookies = parse_cookie_string(cookie_header) self._cookies = self.scope["_cookies"] = cookies # type: ignore[typeddict-item] return cast("Dict[str, str]", self._cookies) diff --git a/starlite/connection/request.py b/starlite/connection/request.py index 43dafab91e..79106a745f 100644 --- a/starlite/connection/request.py +++ b/starlite/connection/request.py @@ -2,7 +2,6 @@ from urllib.parse import parse_qsl from orjson import loads -from starlette.requests import SERVER_PUSH_HEADERS_TO_COPY from starlite_multipart import MultipartFormDataParser from starlite_multipart import UploadFile as MultipartUploadFile from starlite_multipart import parse_options_header @@ -27,6 +26,15 @@ from starlite.types.asgi_types import HTTPScope, Method, Receive, Scope, Send +SERVER_PUSH_HEADERS = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + class Request(Generic[User, Auth], ASGIConnection["HTTPRouteHandler", User, Auth]): __slots__ = ("_json", "_form", "_body", "_content_type", "is_connected") @@ -188,7 +196,7 @@ async def send_push_promise(self, path: str) -> None: extensions: Dict[str, Dict[Any, Any]] = self.scope.get("extensions") or {} if "http.response.push" in extensions: raw_headers = [] - for name in SERVER_PUSH_HEADERS_TO_COPY: + for name in SERVER_PUSH_HEADERS: for value in self.headers.getlist(name): raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) await self.send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/starlite/connection/websocket.py b/starlite/connection/websocket.py index 0e2e5dcb1f..d9b20eeba4 100644 --- a/starlite/connection/websocket.py +++ b/starlite/connection/websocket.py @@ -13,7 +13,6 @@ from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps, loads from starlette.datastructures import Headers -from starlette.status import WS_1000_NORMAL_CLOSURE from starlite.connection.base import ( ASGIConnection, @@ -22,7 +21,8 @@ empty_receive, empty_send, ) -from starlite.exceptions import WebSocketException +from starlite.exceptions import WebSocketDisconnect, WebSocketException +from starlite.status_codes import WS_1000_NORMAL_CLOSURE from starlite.utils.serialization import default_serializer if TYPE_CHECKING: @@ -112,7 +112,7 @@ def send_wrapper(self, send: "Send") -> "Send": async def wrapped_send(message: "Message") -> None: if self.connection_state == "disconnect": - raise WebSocketException(detail=DISCONNECT_MESSAGE) # pragma: no cover + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover await send(message) return wrapped_send @@ -198,9 +198,9 @@ async def receive_data(self, mode: "Literal['binary', 'text']") -> Union[str, by await self.accept() event = cast("Union['WebSocketReceiveEvent', 'WebSocketDisconnectEvent']", await self.receive()) if event["type"] == "websocket.disconnect": - raise WebSocketException(detail="disconnect event", code=event["code"]) + raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) if self.connection_state == "disconnect": - raise WebSocketException(detail=DISCONNECT_MESSAGE) + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" async def receive_text(self) -> str: diff --git a/starlite/datastructures/background_tasks.py b/starlite/datastructures/background_tasks.py index a66c813cf1..2ed3dca4c4 100644 --- a/starlite/datastructures/background_tasks.py +++ b/starlite/datastructures/background_tasks.py @@ -1,32 +1,67 @@ -from typing import Any, Callable, List, TypeVar +from typing import Any, Callable, Iterable -from starlette.background import BackgroundTask as StarletteBackgroundTask -from starlette.background import BackgroundTasks as StarletteBackgroundTasks +from anyio import create_task_group from typing_extensions import ParamSpec +from starlite.utils.sync import AsyncCallable + P = ParamSpec("P") -T = TypeVar("T") -class BackgroundTask(StarletteBackgroundTask): - def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: +class BackgroundTask: + __slots__ = ("fn", "args", "kwargs") + + def __init__(self, fn: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: """A container for a 'background' task function. Background tasks are called once a Response finishes. - Args: - func: A sync or async function to call as the background task. + fn: A sync or async function to call as the background task. *args: Args to pass to the func. **kwargs: Kwargs to pass to the func """ - super().__init__(func, *args, **kwargs) + self.fn = AsyncCallable(fn) + self.args = args + self.kwargs = kwargs + async def __call__(self) -> None: + """Calls the wrapped function with the passed in arguments. + + Returns: + None. + """ + await self.fn(*self.args, **self.kwargs) -class BackgroundTasks(StarletteBackgroundTasks): - def __init__(self, tasks: List[BackgroundTask]) -> None: - """A container for multiple 'background' task functions. Background - tasks are called once a Response finishes. +class BackgroundTasks: + __slots__ = ( + "tasks", + "run_in_task_group", + ) + + def __init__(self, tasks: Iterable[BackgroundTask], run_in_task_group: bool = False) -> None: + """A container for multiple 'background' task functions. + + Background + tasks are called once a Response finishes. Args: - tasks: A list of [BackgroundTask][starlite.datastructures.BackgroundTask] instances. + tasks: An iterable of [BackgroundTask][starlite.datastructures.BackgroundTask] instances. + run_in_task_group: If you set this value to `True` than the tasks will run concurrently, using + an [anyio.task_group](https://anyio.readthedocs.io/en/stable/tasks.html). Note: this will + not preserve execution order. + """ + self.tasks = tasks + self.run_in_task_group = run_in_task_group + + async def __call__(self) -> None: + """Calls the wrapped background tasks. + + Returns: + None """ - super().__init__(tasks=tasks) + if self.run_in_task_group: + async with create_task_group() as task_group: + for task in self.tasks: + task_group.start_soon(task) + else: + for task in self.tasks: + await task() diff --git a/starlite/datastructures/cookie.py b/starlite/datastructures/cookie.py index 056ebf1542..6a4d34ece5 100644 --- a/starlite/datastructures/cookie.py +++ b/starlite/datastructures/cookie.py @@ -51,3 +51,18 @@ def to_header(self, **kwargs: Any) -> str: if value is not None: simple_cookie[self.key][key] = value return simple_cookie.output(**kwargs).strip() + + def __eq__(self, other: Any) -> bool: + """Custom implementation to allow using equality operators on class + instances. + + Args: + other: An arbitrary value + + Returns: + Determines whether two cookie instances are equal according to the cookie spec, i.e. + they have a similar path, domain and key. + """ + if isinstance(other, type(self)): + return other.key == self.key and other.path == self.path and other.domain == self.domain + return False diff --git a/starlite/datastructures/provide.py b/starlite/datastructures/provide.py index dfcb72160a..fc536fafbf 100644 --- a/starlite/datastructures/provide.py +++ b/starlite/datastructures/provide.py @@ -1,8 +1,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast from starlite.types import Empty -from starlite.utils.predicates import is_async_callable -from starlite.utils.sync import AsyncCallable +from starlite.utils.sync import AsyncCallable, is_async_callable if TYPE_CHECKING: from typing import Type diff --git a/starlite/datastructures/response_containers.py b/starlite/datastructures/response_containers.py index 042a6af88b..c66f85e9d4 100644 --- a/starlite/datastructures/response_containers.py +++ b/starlite/datastructures/response_containers.py @@ -5,18 +5,15 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Dict, - Generator, Generic, Iterable, Iterator, List, Optional, - Type, TypeVar, Union, cast, @@ -24,24 +21,29 @@ from pydantic import BaseConfig, FilePath, validator from pydantic.generics import GenericModel -from starlette.responses import FileResponse, RedirectResponse -from starlette.responses import Response as StarletteResponse -from starlette.responses import StreamingResponse +from typing_extensions import Literal from starlite.datastructures.background_tasks import BackgroundTask, BackgroundTasks from starlite.datastructures.cookie import Cookie from starlite.enums import MediaType from starlite.exceptions import ImproperlyConfiguredException -from starlite.response import TemplateResponse +from starlite.response import ( + FileResponse, + RedirectResponse, + StreamingResponse, + TemplateResponse, +) +from starlite.response.file import ONE_MEGA_BYTE +from starlite.types.composite import StreamType if TYPE_CHECKING: from starlite.app import Starlite from starlite.connection import Request -R = TypeVar("R", bound=StarletteResponse) +R = TypeVar("R") -class ResponseContainer(GenericModel, ABC, Generic[R]): +class ResponseContainer(ABC, GenericModel, Generic[R]): class Config(BaseConfig): arbitrary_types_allowed = True @@ -57,6 +59,8 @@ class Config(BaseConfig): """A list of Cookie instances to be set under the response 'Set-Cookie' header. Defaults to None.""" media_type: Optional[Union[MediaType, str]] = None """If defined, overrides the media type configured in the route decorator""" + encoding: str = "utf-8" + """The encoding to be used for the response headers.""" @abstractmethod def to_response( @@ -66,7 +70,7 @@ def to_response( status_code: int, app: "Starlite", request: "Request", - ) -> R: # pragma: no cover + ) -> "R": # pragma: no cover """Abstract method that should be implemented by subclasses. Returns a Starlette compatible Response instance. @@ -88,22 +92,34 @@ class File(ResponseContainer[FileResponse]): path: FilePath """Path to the file to send""" - filename: str + filename: Optional[str] = None """The filename""" stat_result: Optional[os.stat_result] = None """File statistics""" + chunk_size: int = ONE_MEGA_BYTE + """The size of chunks to use when streaming the file""" + content_disposition_type: "Literal['attachment', 'inline']" = "attachment" + """The type of the 'Content-Disposition'. Either 'inline' or 'attachment'.""" @validator("stat_result", always=True) def validate_status_code( # pylint: disable=no-self-argument cls, value: Optional[os.stat_result], values: Dict[str, Any] ) -> os.stat_result: - """Set the stat_result value for the given filepath.""" + """Set the stat_result value for the given filepath. + + Args: + value: An optional result [stat][os.stat] result. + values: The values dict. + + Returns: + A stat_result + """ return value or Path(cast("str", values.get("path"))).stat() def to_response( self, headers: Dict[str, Any], - media_type: Union["MediaType", str], + media_type: Optional[Union["MediaType", str]], status_code: int, app: "Starlite", request: "Request", @@ -122,6 +138,9 @@ def to_response( """ return FileResponse( background=self.background, + chunk_size=self.chunk_size, + content_disposition_type=self.content_disposition_type, + encoding=self.encoding, filename=self.filename, headers=headers, media_type=media_type, @@ -137,11 +156,13 @@ class Redirect(ResponseContainer[RedirectResponse]): path: str """Redirection path""" - def to_response( + def to_response( # type: ignore[override] self, headers: Dict[str, Any], + # TODO: update the redirect response to support HTML as well. + # This argument is currently ignored. media_type: Union["MediaType", str], - status_code: int, + status_code: "Literal[301, 302, 303, 307, 308]", app: "Starlite", request: "Request", ) -> RedirectResponse: @@ -157,23 +178,36 @@ def to_response( Returns: A RedirectResponse instance """ - return RedirectResponse(headers=headers, status_code=status_code, url=self.path, background=self.background) + return RedirectResponse( + background=self.background, + encoding=self.encoding, + headers=headers, + status_code=status_code, + url=self.path, + ) class Stream(ResponseContainer[StreamingResponse]): """Container type for returning Stream responses.""" - iterator: Union[ - Iterator[Union[str, bytes]], - Generator[Union[str, bytes], Any, Any], - AsyncIterator[Union[str, bytes]], - AsyncGenerator[Union[str, bytes], Any], - Type[Iterator[Union[str, bytes]]], - Type[AsyncIterator[Union[str, bytes]]], - Callable[[], AsyncGenerator[Union[str, bytes], Any]], - Callable[[], Generator[Union[str, bytes], Any, Any]], - ] - """Iterator, Generator or async Iterator or Generator returning stream chunks""" + iterator: Union[StreamType[Union[str, bytes]], Callable[[], StreamType[Union[str, bytes]]]] + """Iterator, Iterable,Generator or async Iterator, Iterable or Generator returning chunks to stream.""" + + @validator("iterator", always=True) + def validate_iterator( # pylint: disable=no-self-argument + cls, + value: Union[StreamType[Union[str, bytes]], Callable[[], StreamType[Union[str, bytes]]]], + ) -> StreamType[Union[str, bytes]]: + """Set the iterator value by ensuring that the return value is + iterable. + + Args: + value: An iterable or callable returning an iterable. + + Returns: + A sync or async iterable. + """ + return value if isinstance(value, (Iterable, Iterator, AsyncIterable, AsyncIterator)) else value() def to_response( self, @@ -199,6 +233,7 @@ def to_response( return StreamingResponse( background=self.background, content=self.iterator if isinstance(self.iterator, (Iterable, AsyncIterable)) else self.iterator(), + encoding=self.encoding, headers=headers, media_type=media_type, status_code=status_code, @@ -237,13 +272,13 @@ def to_response( Returns: A TemplateResponse instance """ - if not app.template_engine: raise ImproperlyConfiguredException("Template engine is not configured") return TemplateResponse( background=self.background, context=self.create_template_context(request=request), + encoding=self.encoding, headers=headers, status_code=status_code, template_engine=app.template_engine, diff --git a/starlite/exceptions/__init__.py b/starlite/exceptions/__init__.py index f594a6d9eb..820975c473 100644 --- a/starlite/exceptions/__init__.py +++ b/starlite/exceptions/__init__.py @@ -13,22 +13,23 @@ TooManyRequestsException, ValidationException, ) -from .websocket_exceptions import WebSocketException +from .websocket_exceptions import WebSocketDisconnect, WebSocketException __all__ = ( - "MissingDependencyException", - "StarLiteException", "HTTPException", "ImproperlyConfiguredException", - "ValidationException", - "NotAuthorizedException", - "PermissionDeniedException", + "InternalServerException", + "MethodNotAllowedException", + "MissingDependencyException", "NoRouteMatchFoundException", + "NotAuthorizedException", "NotFoundException", - "MethodNotAllowedException", - "TooManyRequestsException", - "InternalServerException", + "PermissionDeniedException", "ServiceUnavailableException", + "StarLiteException", "TemplateNotFoundException", + "TooManyRequestsException", + "ValidationException", + "WebSocketDisconnect", "WebSocketException", ) diff --git a/starlite/exceptions/http_exceptions.py b/starlite/exceptions/http_exceptions.py index e0653b3476..b6bc6be45d 100644 --- a/starlite/exceptions/http_exceptions.py +++ b/starlite/exceptions/http_exceptions.py @@ -1,8 +1,8 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Union -from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.status import ( +from starlite.exceptions.base_exceptions import StarLiteException +from starlite.status_codes import ( HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, @@ -13,10 +13,8 @@ HTTP_503_SERVICE_UNAVAILABLE, ) -from starlite.exceptions.base_exceptions import StarLiteException - -class HTTPException(StarletteHTTPException, StarLiteException): +class HTTPException(StarLiteException): status_code: int = HTTP_500_INTERNAL_SERVER_ERROR """Exception status code.""" detail: str @@ -45,8 +43,8 @@ def __init__( headers: Headers to set on the response. extra: An extra mapping to attach to the exception. """ - - super().__init__(status_code or self.status_code) + super().__init__() + self.status_code = status_code or self.status_code if not detail: detail = args[0] if args else HTTPStatus(self.status_code).phrase diff --git a/starlite/exceptions/websocket_exceptions.py b/starlite/exceptions/websocket_exceptions.py index 6ad86fe27b..2f0490e1dc 100644 --- a/starlite/exceptions/websocket_exceptions.py +++ b/starlite/exceptions/websocket_exceptions.py @@ -1,6 +1,7 @@ from typing import Any from starlite.exceptions.base_exceptions import StarLiteException +from starlite.status_codes import WS_1000_NORMAL_CLOSURE class WebSocketException(StarLiteException): @@ -12,8 +13,20 @@ def __init__(self, *args: Any, detail: str, code: int = 4500) -> None: Args: *args: Any exception args. - detail: - code: Exception code. Should be a number in the 4000+ range. + detail: Exception details. + code: Exception code. Should be a number in the >= 1000. """ super().__init__(*args, detail=detail) self.code = code + + +class WebSocketDisconnect(WebSocketException): + def __init__(self, *args: Any, detail: str, code: int = WS_1000_NORMAL_CLOSURE) -> None: + """Exception class for websocket disconnect events. + + Args: + *args: Any exception args. + detail: Exception details. + code: Exception code. Should be a number in the >= 1000. + """ + super().__init__(*args, detail=detail, code=code) diff --git a/starlite/handlers/http.py b/starlite/handlers/http.py index b5e3580107..01788e85c5 100644 --- a/starlite/handlers/http.py +++ b/starlite/handlers/http.py @@ -15,13 +15,6 @@ from pydantic import validate_arguments from pydantic_openapi_schema.v3_1_0 import SecurityRequirement -from starlette.responses import Response as StarletteResponse -from starlette.status import ( - HTTP_200_OK, - HTTP_201_CREATED, - HTTP_204_NO_CONTENT, - HTTP_304_NOT_MODIFIED, -) from starlite.constants import REDIRECT_STATUS_CODES from starlite.datastructures import CacheControlHeader, ETag, Provide, ResponseHeader @@ -41,9 +34,16 @@ from starlite.openapi.datastructures import ResponseSpec from starlite.plugins import get_plugin_for_value from starlite.response import Response +from starlite.status_codes import ( + HTTP_200_OK, + HTTP_201_CREATED, + HTTP_204_NO_CONTENT, + HTTP_304_NOT_MODIFIED, +) from starlite.types import ( AfterRequestHookHandler, AfterResponseHookHandler, + ASGIApp, BeforeRequestHookHandler, CacheKeyBuilder, Empty, @@ -56,7 +56,8 @@ ResponseHeadersMap, ResponseType, ) -from starlite.utils.predicates import is_async_callable, is_class_and_subclass +from starlite.utils import is_async_callable, unique +from starlite.utils.predicates import is_class_and_subclass from starlite.utils.sync import AsyncCallable if TYPE_CHECKING: @@ -74,11 +75,11 @@ def _normalize_cookies(local_cookies: "ResponseCookies", layered_cookies: "Respo returns a normalized dict ready to be set on the response.""" filtered_cookies = [*local_cookies] for cookie in layered_cookies: - if not any(cookie.key == c.key for c in filtered_cookies): + if not any(c == cookie for c in filtered_cookies): filtered_cookies.append(cookie) return [ cookie.dict(exclude_none=True, exclude={"documentation_only", "description"}) - for cookie in filtered_cookies + for cookie in unique(filtered_cookies) if not cookie.documentation_only ] @@ -133,7 +134,7 @@ def _create_response_container_handler( ) -> "AsyncAnyCallable": """Creates a handler function for ResponseContainers.""" - async def handler(data: ResponseContainer, app: "Starlite", request: "Request", **kwargs: Any) -> StarletteResponse: + async def handler(data: ResponseContainer, app: "Starlite", request: "Request", **kwargs: Any) -> "ASGIApp": normalized_headers = {**_normalize_headers(headers), **data.headers} normalized_cookies = _normalize_cookies(data.cookies, cookies) response = data.to_response( @@ -155,7 +156,7 @@ def _create_response_handler( ) -> "AsyncAnyCallable": """Creates a handler function for Starlite Responses.""" - async def handler(data: Response, **kwargs: Any) -> StarletteResponse: + async def handler(data: Response, **kwargs: Any) -> "ASGIApp": normalized_cookies = _normalize_cookies(data.cookies, cookies) for cookie in normalized_cookies: data.set_cookie(**cookie) @@ -164,15 +165,16 @@ async def handler(data: Response, **kwargs: Any) -> StarletteResponse: return handler -def _create_starlette_response_handler( +def _create_generic_asgi_response_handler( after_request: Optional["AfterRequestHookHandler"], cookies: "ResponseCookies" ) -> "AsyncAnyCallable": """Creates a handler function for Starlette Responses.""" - async def handler(data: StarletteResponse, **kwargs: Any) -> StarletteResponse: + async def handler(data: "ASGIApp", **kwargs: Any) -> "ASGIApp": normalized_cookies = _normalize_cookies(cookies, []) - for cookie in normalized_cookies: - data.set_cookie(**cookie) + if hasattr(data, "set_cookie"): + for cookie in normalized_cookies: + data.set_cookie(**cookie) # type: ignore return await after_request(data) if after_request else data # type: ignore return handler @@ -189,7 +191,7 @@ def _create_data_handler( ) -> "AsyncAnyCallable": """Creates a handler function for arbitrary data.""" - async def handler(data: Any, plugins: List["PluginProtocol"], **kwargs: Any) -> StarletteResponse: + async def handler(data: Any, plugins: List["PluginProtocol"], **kwargs: Any) -> "ASGIApp": data = await _normalize_response_data(data=data, plugins=plugins) normalized_cookies = _normalize_cookies(cookies, []) normalized_headers = _normalize_headers(headers) @@ -399,7 +401,7 @@ def __init__( # memoized attributes, defaulted to Empty self._resolved_after_response: Union[Optional[AfterResponseHookHandler], EmptyType] = Empty self._resolved_before_request: Union[Optional[BeforeRequestHookHandler], EmptyType] = Empty - self._resolved_response_handler: Union["Callable[[Any], Awaitable[StarletteResponse]]", EmptyType] = Empty + self._resolved_response_handler: Union["Callable[[Any], Awaitable[ASGIApp]]", EmptyType] = Empty def __call__(self, fn: "AnyCallable") -> "HTTPRouteHandler": """Replaces a function with itself.""" @@ -503,7 +505,7 @@ def resolve_after_response(self) -> Optional["AfterResponseHookHandler"]: def resolve_response_handler( self, - ) -> Callable[[Any], Awaitable[StarletteResponse]]: + ) -> Callable[[Any], Awaitable["ASGIApp"]]: """Resolves the response_handler function for the route handler. This method is memoized so the computation occurs only once. @@ -534,8 +536,8 @@ def resolve_response_handler( ) elif is_class_and_subclass(self.signature.return_annotation, Response): handler = _create_response_handler(cookies=cookies, after_request=after_request) - elif is_class_and_subclass(self.signature.return_annotation, StarletteResponse): - handler = _create_starlette_response_handler(cookies=cookies, after_request=after_request) + elif is_async_callable(self.signature.return_annotation): + handler = _create_generic_asgi_response_handler(cookies=cookies, after_request=after_request) else: handler = _create_data_handler( after_request=after_request, @@ -548,11 +550,11 @@ def resolve_response_handler( ) self._resolved_response_handler = handler - return cast("Callable[[Any], Awaitable[StarletteResponse]]", self._resolved_response_handler) + return cast("Callable[[Any], Awaitable[ASGIApp]]", self._resolved_response_handler) async def to_response( self, app: "Starlite", data: Any, plugins: List["PluginProtocol"], request: "Request" - ) -> StarletteResponse: + ) -> "ASGIApp": """ Args: diff --git a/starlite/middleware/compression/brotli.py b/starlite/middleware/compression/brotli.py index 1d0934eb3b..7a6e1c52c8 100644 --- a/starlite/middleware/compression/brotli.py +++ b/starlite/middleware/compression/brotli.py @@ -61,25 +61,26 @@ def __init__( self.minimum_size = minimum_size self.lgwin = brotli_lgwin self.lgblock = brotli_lgblock - self.gzip_fallback = brotli_gzip_fallback + self.brotli_responder = BrotliResponder( + app=self.app, + minimum_size=self.minimum_size, + quality=self.quality, + mode=self.mode, + lgwin=self.lgwin, + lgblock=self.lgblock, + ) + self.gzip_responder: Optional[GZipResponder] = None + if brotli_gzip_fallback: + self.gzip_responder = GZipResponder(self.app, self.minimum_size) # type: ignore[arg-type] async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: if scope["type"] == ScopeType.HTTP: headers = Headers(scope=scope) if CompressionEncoding.BROTLI in headers.get("Accept-Encoding", ""): - brotli_responder = BrotliResponder( - app=self.app, - minimum_size=self.minimum_size, - quality=self.quality, - mode=self.mode, - lgwin=self.lgwin, - lgblock=self.lgblock, - ) - await brotli_responder(scope, receive, send) + await self.brotli_responder(scope, receive, send) return - if self.gzip_fallback and CompressionEncoding.GZIP in headers.get("Accept-Encoding", ""): - gzip_responder = GZipResponder(self.app, self.minimum_size) # type: ignore[arg-type] - await gzip_responder(scope, receive, send) # type: ignore[arg-type] + if self.gzip_responder and CompressionEncoding.GZIP in headers.get("Accept-Encoding", ""): + await self.gzip_responder(scope, receive, send) # type: ignore[arg-type] return await self.app(scope, receive, send) @@ -151,7 +152,6 @@ async def send_wrapper(message: "Message") -> None: Args: message (Message): An ASGI Message. """ - if message["type"] == "http.response.start": # Don't send the initial message until we've determined how to # modify the outgoing headers correctly. diff --git a/starlite/middleware/csrf.py b/starlite/middleware/csrf.py index 532b6665c0..2365dfde4d 100644 --- a/starlite/middleware/csrf.py +++ b/starlite/middleware/csrf.py @@ -87,6 +87,7 @@ def create_send_wrapper(self, send: "Send", token: str, csrf_cookie: Optional[st """Wraps 'send' to handle CSRF validation. Args: + token: The CSRF token. send: The ASGI send function. csrf_cookie: CSRF cookie. diff --git a/starlite/middleware/exceptions.py b/starlite/middleware/exceptions.py index 18251bcc42..0613a3408e 100644 --- a/starlite/middleware/exceptions.py +++ b/starlite/middleware/exceptions.py @@ -1,19 +1,16 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.errors import ServerErrorMiddleware -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlite.connection import Request from starlite.enums import ScopeType from starlite.exceptions import WebSocketException +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from starlite.utils import create_exception_response from starlite.utils.exception import get_exception_handler if TYPE_CHECKING: - - from starlette.responses import Response as StarletteResponse - + from starlite.response import Response from starlite.types import ASGIApp, ExceptionHandlersMap, Receive, Scope, Send from starlite.types.asgi_types import WebSocketCloseEvent @@ -57,26 +54,23 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No get_exception_handler(self.exception_handlers, e) or self.default_http_exception_handler ) response = exception_handler(Request(scope=scope, receive=receive, send=send), e) - await response(scope=scope, receive=receive, send=send) # type: ignore[arg-type] + await response(scope=scope, receive=receive, send=send) return if isinstance(e, WebSocketException): code = e.code reason = e.detail - elif isinstance(e, StarletteHTTPException): - code = e.status_code + 4000 - reason = e.detail else: - code = HTTP_500_INTERNAL_SERVER_ERROR + 4000 - reason = repr(e) + code = 4000 + getattr(e, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) + reason = getattr(e, "detail", repr(e)) event: "WebSocketCloseEvent" = {"type": "websocket.close", "code": code, "reason": reason} await send(event) - def default_http_exception_handler(self, request: Request, exc: Exception) -> "StarletteResponse": + def default_http_exception_handler(self, request: Request, exc: Exception) -> "Response[Any]": """Default handler for exceptions subclassed from HTTPException.""" - status_code = exc.status_code if isinstance(exc, StarletteHTTPException) else HTTP_500_INTERNAL_SERVER_ERROR + status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self.debug: # in debug mode, we just use the serve_middleware to create an HTML formatted response for us server_middleware = ServerErrorMiddleware(app=self) # type: ignore[arg-type] - return server_middleware.debug_response(request=request, exc=exc) # type: ignore[arg-type] + return server_middleware.debug_response(request=request, exc=exc) # type: ignore return create_exception_response(exc) diff --git a/starlite/openapi/controller.py b/starlite/openapi/controller.py index d77c4d7cb3..b5d401ef9f 100644 --- a/starlite/openapi/controller.py +++ b/starlite/openapi/controller.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Callable, Dict from orjson import OPT_INDENT_2, dumps -from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND from starlite.connection import Request from starlite.controller import Controller @@ -9,6 +8,7 @@ from starlite.exceptions import ImproperlyConfiguredException from starlite.handlers import get from starlite.response import Response +from starlite.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND if TYPE_CHECKING: from pydantic_openapi_schema.v3_1_0.open_api import OpenAPI diff --git a/starlite/openapi/path_item.py b/starlite/openapi/path_item.py index 3f7093e63c..0bc46076bb 100644 --- a/starlite/openapi/path_item.py +++ b/starlite/openapi/path_item.py @@ -2,11 +2,11 @@ from pydantic_openapi_schema.v3_1_0.operation import Operation from pydantic_openapi_schema.v3_1_0.path_item import PathItem -from starlette.routing import get_name from starlite.openapi.parameters import create_parameter_for_handler from starlite.openapi.request_body import create_request_body from starlite.openapi.responses import create_responses +from starlite.utils import get_name if TYPE_CHECKING: from pydantic import BaseModel diff --git a/starlite/openapi/responses.py b/starlite/openapi/responses.py index 614c53bbce..548e001488 100644 --- a/starlite/openapi/responses.py +++ b/starlite/openapi/responses.py @@ -8,7 +8,6 @@ MediaType as OpenAPISchemaMediaType, ) from pydantic_openapi_schema.v3_1_0.schema import Schema -from starlette.routing import get_name from typing_extensions import get_args, get_origin from starlite.datastructures.response_containers import File, Redirect, Stream, Template @@ -22,7 +21,7 @@ from starlite.openapi.schema import create_schema from starlite.openapi.utils import pascal_case_to_text from starlite.response import Response as StarliteResponse -from starlite.utils.model import create_parsed_model_field +from starlite.utils import create_parsed_model_field, get_enum_string_value, get_name if TYPE_CHECKING: @@ -68,7 +67,7 @@ def create_success_response( return_annotation = signature.return_annotation if signature.return_annotation is Template: return_annotation = str # since templates return str - route_handler.media_type = MediaType.HTML + route_handler.media_type = get_enum_string_value(MediaType.HTML) elif get_origin(signature.return_annotation) is StarliteResponse: return_annotation = get_args(signature.return_annotation)[0] or Any as_parsed_model_field = create_parsed_model_field(return_annotation) @@ -152,13 +151,13 @@ def create_error_responses(exceptions: List[Type[HTTPException]]) -> Iterator[Tu Schema( type=OpenAPIType.OBJECT, required=["detail", "status_code"], - properties=dict( - status_code=Schema(type=OpenAPIType.INTEGER), - detail=Schema(type=OpenAPIType.STRING), - extra=Schema( + properties={ + "status_code": Schema(type=OpenAPIType.INTEGER), + "detail": Schema(type=OpenAPIType.STRING), + "extra": Schema( type=[OpenAPIType.NULL, OpenAPIType.OBJECT, OpenAPIType.ARRAY], additionalProperties=Schema() ), - ), + }, description=pascal_case_to_text(get_name(exc)), examples=[{"status_code": status_code, "detail": HTTPStatus(status_code).phrase, "extra": {}}], ) diff --git a/starlite/params.py b/starlite/params.py index cc1803e026..1c63de572c 100644 --- a/starlite/params.py +++ b/starlite/params.py @@ -79,7 +79,7 @@ def Parameter( regex: A string representing a regex against which the given string will be matched. Equivalent to pattern in the OpenAPI specification. """ - extra: Dict[str, Any] = dict(is_parameter=True) + extra: Dict[str, Any] = {"is_parameter": True} extra.update(header=header) extra.update(cookie=cookie) extra.update(query=query) diff --git a/starlite/parsers.py b/starlite/parsers.py index 6fd019c1d5..1cfd955f49 100644 --- a/starlite/parsers.py +++ b/starlite/parsers.py @@ -1,7 +1,8 @@ from contextlib import suppress from functools import reduce +from http.cookies import _unquote as unquote_cookie from typing import TYPE_CHECKING, Any, Dict, List, Tuple -from urllib.parse import parse_qsl +from urllib.parse import parse_qsl, unquote from orjson import JSONDecodeError, loads from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON @@ -78,3 +79,19 @@ def parse_form_data(media_type: "RequestEncodingType", form_data: "FormMultiDict if field.shape is SHAPE_SINGLETON and field.type_ in (UploadFile, MultipartUploadFile) and values_dict: return list(values_dict.values())[0] return values_dict + + +def parse_cookie_string(cookie_string: str) -> Dict[str, str]: + """ + Parses a cookie string into a dictionary of values. + Args: + cookie_string: A cookie string. + + Returns: + A string keyed dictionary of values + """ + output: Dict[str, str] = {} + cookies = [cookie.split("=", 1) if "=" in cookie else ("", cookie) for cookie in cookie_string.split(";")] + for k, v in filter(lambda x: x[0] or x[1], ((k.strip(), v.strip()) for k, v in cookies)): + output[k] = unquote(unquote_cookie(v)) + return output diff --git a/starlite/plugins/tortoise_orm.py b/starlite/plugins/tortoise_orm.py index db0a6cdb73..d1114fc069 100644 --- a/starlite/plugins/tortoise_orm.py +++ b/starlite/plugins/tortoise_orm.py @@ -8,8 +8,11 @@ from starlite.plugins.base import PluginProtocol try: - from tortoise import Model, ModelMeta - from tortoise.contrib.pydantic import PydanticModel, pydantic_model_creator + from tortoise import Model, ModelMeta # type:ignore [attr-defined] + from tortoise.contrib.pydantic import ( # type:ignore [attr-defined] + PydanticModel, + pydantic_model_creator, + ) except ImportError as e: raise MissingDependencyException("tortoise-orm is not installed") from e diff --git a/starlite/response.py b/starlite/response.py deleted file mode 100644 index 215b054b33..0000000000 --- a/starlite/response.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar, Union, cast - -import yaml -from orjson import OPT_INDENT_2, OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps -from pydantic_openapi_schema.v3_1_0.open_api import OpenAPI -from starlette.responses import Response as StarletteResponse -from starlette.status import HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED - -from starlite.enums import MediaType, OpenAPIMediaType -from starlite.exceptions import ImproperlyConfiguredException -from starlite.utils.serialization import default_serializer - -T = TypeVar("T") - -if TYPE_CHECKING: - from starlite.datastructures.background_tasks import BackgroundTask, BackgroundTasks - from starlite.template import TemplateEngineProtocol - from starlite.types import ResponseCookies - - -class Response(StarletteResponse, Generic[T]): - def __init__( - self, - content: T, - *, - status_code: int, - media_type: Union["MediaType", "OpenAPIMediaType", str], - background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, - headers: Optional[Dict[str, Any]] = None, - cookies: Optional["ResponseCookies"] = None, - ) -> None: - """The response class is used to return an HTTP response. - - Args: - content: A value for the response body that will be rendered into bytes string. - status_code: A value for the response HTTP status code. - media_type: A value for the response 'Content-Type' header. - background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or - [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. - Defaults to None. - headers: A string keyed dictionary of response headers. Header keys are insensitive. - cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. - """ - super().__init__( - content=content, - status_code=status_code, - headers=headers or {}, - media_type=media_type, - background=cast("BackgroundTask", background), - ) - self.cookies = cookies or [] - - @staticmethod - def serializer(value: Any) -> Any: - """Serializer hook for orjson to handle pydantic models. - - Args: - value: A value to serialize - Returns: - A serialized value - Raises: - TypeError: if value is not supported - """ - return default_serializer(value) - - def render(self, content: Any) -> bytes: - """ - Handles the rendering of content T into a bytes string. - Args: - content: An arbitrary value of type T - - Returns: - An encoded bytes string - """ - try: - if content is None and ( - self.status_code < 100 or self.status_code in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} - ): - return b"" - if self.media_type == MediaType.JSON: - return dumps(content, default=self.serializer, option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS) - if isinstance(content, OpenAPI): - content_dict = content.dict(by_alias=True, exclude_none=True) - if self.media_type == OpenAPIMediaType.OPENAPI_YAML: - encoded = yaml.dump(content_dict, default_flow_style=False).encode("utf-8") - return cast("bytes", encoded) - return dumps(content_dict, option=OPT_INDENT_2 | OPT_OMIT_MICROSECONDS) - return super().render(content) - except (AttributeError, ValueError, TypeError) as e: - raise ImproperlyConfiguredException("Unable to serialize response content") from e - - -class TemplateResponse(Response): - def __init__( - self, - template_name: str, - template_engine: "TemplateEngineProtocol", - status_code: int, - context: Dict[str, Any], - background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, - headers: Optional[Dict[str, Any]] = None, - cookies: Optional["ResponseCookies"] = None, - ) -> None: - """Handles the rendering of a given template into a bytes string. - - Args: - template_name: Path-like name for the template to be rendered, e.g. "index.html". - template_engine: The template engine class to use to render the response. - status_code: A value for the response HTTP status code. - context: A dictionary of key/value pairs to be passed to the temple engine's render method. - background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or - [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. - Defaults to None. - headers: A string keyed dictionary of response headers. Header keys are insensitive. - cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. - """ - template = template_engine.get_template(template_name) - content = template.render(**context) - super().__init__( - content=content, - status_code=status_code, - headers=headers, - media_type=MediaType.HTML, - background=background, - cookies=cookies, - ) diff --git a/starlite/response/__init__.py b/starlite/response/__init__.py new file mode 100644 index 0000000000..fc0a322a59 --- /dev/null +++ b/starlite/response/__init__.py @@ -0,0 +1,7 @@ +from .base import Response +from .file import FileResponse +from .redirect import RedirectResponse +from .streaming import StreamingResponse +from .template import TemplateResponse + +__all__ = ["Response", "RedirectResponse", "StreamingResponse", "TemplateResponse", "FileResponse"] diff --git a/starlite/response/base.py b/starlite/response/base.py new file mode 100644 index 0000000000..77f76f32d2 --- /dev/null +++ b/starlite/response/base.py @@ -0,0 +1,328 @@ +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) + +from orjson import OPT_INDENT_2, OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps +from pydantic_openapi_schema.v3_1_0 import OpenAPI +from yaml import dump as dump_yaml + +from starlite.datastructures import Cookie +from starlite.enums import MediaType, OpenAPIMediaType +from starlite.exceptions import ImproperlyConfiguredException +from starlite.status_codes import ( + HTTP_200_OK, + HTTP_204_NO_CONTENT, + HTTP_304_NOT_MODIFIED, +) +from starlite.utils.helpers import get_enum_string_value +from starlite.utils.serialization import default_serializer + +if TYPE_CHECKING: + from typing_extensions import Literal + + from starlite.datastructures import BackgroundTask, BackgroundTasks + from starlite.types import ( + HTTPResponseBodyEvent, + HTTPResponseStartEvent, + Receive, + ResponseCookies, + Scope, + Send, + ) + +T = TypeVar("T") + + +class Response(Generic[T]): + __slots__ = ( + "status_code", + "media_type", + "background", + "headers", + "cookies", + "encoding", + "body", + "status_allows_body", + "is_head_response", + ) + + def __init__( + self, + content: T, + *, + status_code: int = HTTP_200_OK, + media_type: Union[MediaType, "OpenAPIMediaType", str] = MediaType.JSON, + background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional["ResponseCookies"] = None, + encoding: str = "utf-8", + is_head_response: bool = False, + ) -> None: + """This is the base Starlite HTTP response class, used as the basis for + all other response classes. + + Args: + content: A value for the response body that will be rendered into bytes string. + status_code: An HTTP status code. + media_type: A value for the response 'Content-Type' header. + background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or + [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. + Defaults to None. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. + encoding: The encoding to be used for the response headers. + is_head_response: Whether the response should send only the headers ("head" request) or also the content. + """ + self.status_code = status_code + self.media_type = get_enum_string_value(media_type) + self.background = background + self.headers = headers or {} + self.cookies = cookies or [] + self.encoding = encoding + self.is_head_response = is_head_response + self.status_allows_body = not ( + self.status_code in {HTTP_204_NO_CONTENT, HTTP_304_NOT_MODIFIED} or self.status_code < HTTP_200_OK + ) + self.body = self.render(content) if not self.is_head_response else b"" + + def set_cookie( + self, + key: str, + value: Optional[str] = None, + max_age: Optional[int] = None, + expires: Optional[int] = None, + path: str = "/", + domain: Optional[str] = None, + secure: bool = False, + httponly: bool = False, + samesite: 'Literal["lax", "strict", "none"]' = "lax", + ) -> None: + """Sets a cookie on the response. + + Args: + key: Key for the cookie. + value: Value for the cookie, if none given defaults to empty string. + max_age: Maximal age of the cookie before its invalidated. + expires: Expiration date as unix MS timestamp. + path: Path fragment that must exist in the request url for the cookie to be valid. Defaults to '/'. + domain: Domain for which the cookie is valid. + secure: Https is required for the cookie. + httponly: Forbids javascript to access the cookie via 'Document.cookie'. + samesite: Controls whether a cookie is sent with cross-site requests. Defaults to 'lax'. + + Returns: + None. + """ + self.cookies.append( + Cookie( + domain=domain, + expires=expires, + httponly=httponly, + key=key, + max_age=max_age, + path=path, + samesite=samesite, + secure=secure, + value=value, + ) + ) + + def set_header(self, key: str, value: str) -> None: + """Sets a header on the response. + + Args: + key: Header key. + value: Header value. + + Returns: + None. + """ + self.headers[key] = value + + def set_etag(self, etag: str) -> None: + """Sets an etag header. + + Args: + etag: An etag value. + + Returns: + None + """ + self.headers["etag"] = etag + + def delete_cookie( + self, + key: str, + path: str = "/", + domain: Optional[str] = None, + ) -> None: + """Deletes a cookie. + + Args: + key: Key of the cookie. + path: Path of the cookie. + domain: Domain of the cookie. + + Returns: + None. + """ + cookie = Cookie(key=key, path=path, domain=domain, expires=0, max_age=0) + self.cookies = [c for c in self.cookies if c != cookie] + self.cookies.append(cookie) + + @staticmethod + def serializer(value: Any) -> Any: + """Serializer hook for orjson to handle pydantic models. + + Args: + value: A value to serialize + Returns: + A serialized value + Raises: + TypeError: if value is not supported + """ + return default_serializer(value) + + def render(self, content: Any) -> bytes: + """ + Handles the rendering of content T into a bytes string. + Args: + content: A value for the response body that will be rendered into bytes string. + + Returns: + An encoded bytes string + """ + if self.status_allows_body: + if isinstance(content, bytes): + return content + if isinstance(content, str): + return content.encode(self.encoding) + if self.media_type == MediaType.JSON: + try: + return dumps(content, default=self.serializer, option=OPT_SERIALIZE_NUMPY | OPT_OMIT_MICROSECONDS) + except (AttributeError, ValueError, TypeError) as e: + raise ImproperlyConfiguredException("Unable to serialize response content") from e + if isinstance(content, OpenAPI): + content_dict = content.dict(by_alias=True, exclude_none=True) + if self.media_type == OpenAPIMediaType.OPENAPI_YAML: + return cast("bytes", dump_yaml(content_dict, default_flow_style=False).encode("utf-8")) + return dumps(content_dict, option=OPT_INDENT_2 | OPT_OMIT_MICROSECONDS) + if content is None: + return b"" + raise ImproperlyConfiguredException( + f"unable to render response body for the given {content} with media_type {self.media_type}" + ) + if content is not None: + raise ImproperlyConfiguredException( + f"status_code {self.status_code} does not support a response body value" + ) + return b"" + + @property + def content_length(self) -> Optional[int]: + """ + + Returns: + The content length of the body (e.g. for use in a "Content-Length" header). + If the response does not have a body, this value is `None` + """ + if self.status_allows_body: + return len(self.body) + return None + + @property + def encoded_headers(self) -> List[Tuple[bytes, bytes]]: + """ + Notes: + - A 'Content-Length' header will be added if appropriate and not provided by the user. + + Returns: + A list of tuples containing the headers and cookies of the request in a format ready for ASGI transmission. + """ + + if self.media_type.startswith("text/"): + content_type = f"{self.media_type}; charset={self.encoding}" + else: + content_type = self.media_type + + encoded_headers = [ + *((k.lower().encode("latin-1"), str(v).lower().encode("latin-1")) for k, v in self.headers.items()), + *((b"set-cookie", cookie.to_header(header="").encode("latin-1")) for cookie in self.cookies), + (b"content-type", content_type.encode("latin-1")), + ] + + if self.content_length and not any(key == b"content-length" for key, _ in encoded_headers): + encoded_headers.append((b"content-length", str(self.content_length).encode("latin-1"))) + return encoded_headers + + async def after_response(self) -> None: + """Executed after the response is sent. + + Returns: + None + """ + if self.background is not None: + await self.background() + + async def start_response(self, send: "Send") -> None: + """ + Emits the start event of the response. This event includes the headers and status codes. + Args: + send: The ASGI send function. + + Returns: + None + """ + event: "HTTPResponseStartEvent" = { + "type": "http.response.start", + "status": self.status_code, + "headers": self.encoded_headers, + } + + await send(event) + + async def send_body(self, send: "Send", receive: "Receive") -> None: # pylint: disable=unused-argument + """Emits the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Notes: + - Response subclasses should customize this method if there is a need to customize sending data. + + Returns: + None + """ + event: "HTTPResponseBodyEvent" = {"type": "http.response.body", "body": self.body, "more_body": False} + await send(event) + + async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: + """The call method of the response is an "ASGIApp". + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + + Returns: + None + """ + await self.start_response(send=send) + + if self.is_head_response: + event: "HTTPResponseBodyEvent" = {"type": "http.response.body", "body": b"", "more_body": False} + await send(event) + else: + await self.send_body(send=send, receive=receive) + + await self.after_response() diff --git a/starlite/response/file.py b/starlite/response/file.py new file mode 100644 index 0000000000..f764f5f376 --- /dev/null +++ b/starlite/response/file.py @@ -0,0 +1,165 @@ +from email.utils import formatdate +from mimetypes import guess_type +from os.path import basename +from pathlib import Path +from stat import S_ISREG +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Union, cast +from urllib.parse import quote +from zlib import adler32 + +from anyio import open_file + +from starlite.enums import MediaType +from starlite.exceptions import ImproperlyConfiguredException +from starlite.response.streaming import StreamingResponse +from starlite.status_codes import HTTP_200_OK + +if TYPE_CHECKING: + + from os import PathLike + from os import stat_result as stat_result_type + + from typing_extensions import Literal + + from starlite.datastructures import BackgroundTask, BackgroundTasks + from starlite.types import ResponseCookies + +ONE_MEGA_BYTE: int = 1024 * 1024 + + +async def async_file_iterator(file_path: Union[str, "PathLike", Path], chunk_size: int) -> AsyncGenerator[bytes, None]: + """ + A generator function that asynchronously reads a file and yields its chunks. + Args: + file_path: A path to a file. + chunk_size: The chunk file to use. + + Returns: + An async generator. + """ + async with await open_file(file_path, mode="rb") as file: + yield await file.read(chunk_size) + + +class FileResponse(StreamingResponse): + __slots__ = ("stat_result", "filename", "chunk_size") + + def __init__( + self, + path: Union[str, "PathLike", "Path"], + *, + status_code: int = HTTP_200_OK, + media_type: Optional[Union["Literal[MediaType.TEXT]", str]] = None, + background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional["ResponseCookies"] = None, + encoding: str = "utf-8", + is_head_response: bool = False, + filename: Optional[str] = None, + stat_result: Optional["stat_result_type"] = None, + chunk_size: int = ONE_MEGA_BYTE, + content_disposition_type: "Literal['attachment', 'inline']" = "attachment", + ) -> None: + """This class allows streaming a file as response body. + + Notes: + - This class extends the [StreamingReesponse][starlite.response.StreamingResponse] class. + + Args: + path: A file path in one of the supported formats. + status_code: An HTTP status code. + media_type: A value for the response 'Content-Type' header. If not provided, the value will be either + derived from the filename if provided and supported by the stdlib, or will default to + 'application/octet-stream'. + background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or + [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. + Defaults to None. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. + encoding: The encoding to be used for the response headers. + is_head_response: Whether the response should send only the headers ("head" request) or also the content. + filename: An optional filename to set in the header. + stat_result: An optional result of calling 'os.stat'. If not provided, this will be done by the response + constructor. + chunk_size: The chunk sizes to use when streaming the file. Defaults to 1MB. + content_disposition_type: The type of the 'Content-Disposition'. Either 'inline' or 'attachment'. + """ + if not media_type: + mimetype, _ = guess_type(filename) if filename else (None, None) + media_type = mimetype or "application/octet-stream" + + super().__init__( + content=async_file_iterator(file_path=path, chunk_size=chunk_size), + status_code=status_code, + media_type=media_type, + background=background, + headers=headers, + cookies=cookies, + encoding=encoding, + is_head_response=is_head_response, + ) + self.stat_result = cast("stat_result_type", self._get_stat_result(path=path, stat_result=stat_result)) + self.set_header("last-modified", formatdate(self.stat_result.st_mtime, usegmt=True)) + self.set_header( + "content-disposition", + self._get_content_disposition( + filename=filename or basename(path), content_disposition_type=content_disposition_type + ), + ) + self.set_etag(self._create_etag(path=path)) + + def _create_etag(self, path: Union[str, "PathLike"]) -> str: + """Creates an etag. + + Notes: + - Function is derived from flask. + + Returns: + An etag. + """ + check = adler32(str(path).encode("utf-8")) & 0xFFFFFFFF + return f'"{self.stat_result.st_mtime}-{self.stat_result.st_size}-{check}"' + + @staticmethod + def _get_stat_result(path: Union[str, "PathLike"], stat_result: Optional["stat_result_type"]) -> "stat_result_type": + """ + + Args: + stat_result: An optional [stat_result][os.stat_result] instance. + + Returns: + An [stat_result][os.stat_result] instance. + """ + try: + if stat_result is None: + stat_result = Path(path).stat() + if not S_ISREG(stat_result.st_mode): + raise ImproperlyConfiguredException(f"{path} is not a file") + return stat_result + except FileNotFoundError as e: + raise ImproperlyConfiguredException(f"file {path} doesn't exist") from e + + @staticmethod + def _get_content_disposition(filename: str, content_disposition_type: "Literal['attachment', 'inline']") -> str: + """ + + Args: + content_disposition_type: The Content-Disposition type of the file. + + Returns: + A value for the 'Content-Disposition' header. + """ + quoted_filename = quote(filename) + is_utf8 = quoted_filename == filename + if is_utf8: + return f'{content_disposition_type}; filename="{filename}"' + return f"{content_disposition_type}; filename*=utf-8''{quoted_filename}" + + @property + def content_length(self) -> Optional[int]: + """ + + Returns: + Returns the value of 'self.stat_result.st_size' to populate the 'Content-Length' header. + """ + return self.stat_result.st_size diff --git a/starlite/response/redirect.py b/starlite/response/redirect.py new file mode 100644 index 0000000000..bb3d3bc27c --- /dev/null +++ b/starlite/response/redirect.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from urllib.parse import quote + +from starlite.constants import REDIRECT_STATUS_CODES +from starlite.enums import MediaType +from starlite.exceptions import ImproperlyConfiguredException +from starlite.response.base import Response +from starlite.status_codes import HTTP_307_TEMPORARY_REDIRECT + +if TYPE_CHECKING: + from typing_extensions import Literal + + from starlite.datastructures import BackgroundTask, BackgroundTasks + from starlite.types import ResponseCookies + + +class RedirectResponse(Response[Any]): + def __init__( + self, + url: str, + *, + status_code: "Literal[301, 302, 303, 307, 308]" = HTTP_307_TEMPORARY_REDIRECT, + background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional["ResponseCookies"] = None, + encoding: str = "utf-8", + ) -> None: + """This class is used to send redirect responses. + + Args: + url: A url to redirect to. + status_code: An HTTP status code. The status code should be one of 301, 302, 303, 307 or 308, + otherwise an exception will be raised. . + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. + encoding: The encoding to be used for the response headers. + Raises: + [ImproperlyConfiguredException][starlite.exceptions.ImproperlyConfiguredException]: If status code is not a redirect status code. + """ + if status_code not in REDIRECT_STATUS_CODES: + raise ImproperlyConfiguredException( + f"{status_code} is not a valid for this response. " + f"Redirect responses should have one of " + f"the following status codes: {', '.join([str(s) for s in REDIRECT_STATUS_CODES])}" + ) + super().__init__( + background=background, + content=b"", + cookies=cookies, + headers={**(headers or {}), "location": quote(url, safe="/#%[]=:;$&()+,!?*@'~")}, + media_type=MediaType.TEXT, + status_code=status_code, + encoding=encoding, + ) diff --git a/starlite/response/streaming.py b/starlite/response/streaming.py new file mode 100644 index 0000000000..dc10547e94 --- /dev/null +++ b/starlite/response/streaming.py @@ -0,0 +1,124 @@ +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Dict, + Optional, + Union, +) + +from anyio import CancelScope, create_task_group + +from starlite.enums import MediaType +from starlite.response.base import Response +from starlite.status_codes import HTTP_200_OK +from starlite.types.composite import StreamType +from starlite.utils.sync import AsyncIteratorWrapper + +if TYPE_CHECKING: + from starlite.datastructures import BackgroundTask, BackgroundTasks + from starlite.enums import OpenAPIMediaType + from starlite.types import HTTPResponseBodyEvent, Receive, ResponseCookies, Send + + +class StreamingResponse(Response[StreamType[Union[str, bytes]]]): + __slots__ = ("iterator",) + + def __init__( + self, + content: StreamType[Union[str, bytes]], + *, + status_code: int = HTTP_200_OK, + media_type: Union[MediaType, "OpenAPIMediaType", str] = MediaType.JSON, + background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional["ResponseCookies"] = None, + encoding: str = "utf-8", + is_head_response: bool = False, + ): + """This class is an HTTP response that streams the response data as a + series of ASGI 'http.response.body' events. + + Args: + content: A sync or async iterator or iterable. + status_code: An HTTP status code. + media_type: A value for the response 'Content-Type' header. + background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or + [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. + Defaults to None. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. + encoding: The encoding to be used for the response headers. + is_head_response: Whether the response should send only the headers ("head" request) or also the content. + """ + super().__init__( + background=background, + content=b"", # type: ignore[arg-type] + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=media_type, + status_code=status_code, + is_head_response=is_head_response, + ) + self.iterator: Union[AsyncIterable[Union[str, bytes]], AsyncGenerator[Union[str, bytes], None]] = ( + content if isinstance(content, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(content) + ) + + async def _listen_for_disconnect(self, cancel_scope: "CancelScope", receive: "Receive") -> None: + """ + Listens for a cancellation message, and if received - calls cancel on the cancel scope. + + Args: + cancel_scope: A task group cancel scope instance. + receive: The ASGI receive function. + + Returns: + None + """ + if not cancel_scope.cancel_called: + message = await receive() + if message["type"] == "http.disconnect": + # despite the IDE warning, this is not a coroutine because anyio 3+ changed this. + # therefore make sure not to await this. + cancel_scope.cancel() + else: + await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive) + + async def _stream(self, send: "Send") -> None: + """Sends the chunks from the iterator as a stream of ASGI + 'http.response.body' events. + + Args: + send: The ASGI Send function. + + Returns: + None + """ + async for chunk in self.iterator: + stream_event: "HTTPResponseBodyEvent" = { + "type": "http.response.body", + "body": chunk if isinstance(chunk, bytes) else chunk.encode(self.encoding), + "more_body": True, + } + await send(stream_event) + terminus_event: "HTTPResponseBodyEvent" = {"type": "http.response.body", "body": b"", "more_body": False} + await send(terminus_event) + + async def send_body(self, send: "Send", receive: "Receive") -> None: + """Emits a stream of events correlating with the response body. + + Args: + send: The ASGI send function. + receive: The ASGI receive function. + + Returns: + None + """ + + async with create_task_group() as task_group: + task_group.start_soon(partial(self._stream, send)) + await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) diff --git a/starlite/response/template.py b/starlite/response/template.py new file mode 100644 index 0000000000..d82d9e842f --- /dev/null +++ b/starlite/response/template.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, Any, Dict, Optional, Union + +from starlite.enums import MediaType +from starlite.response.base import Response +from starlite.status_codes import HTTP_200_OK + +if TYPE_CHECKING: + from starlite.datastructures import BackgroundTask, BackgroundTasks + from starlite.template import TemplateEngineProtocol + from starlite.types import ResponseCookies + + +class TemplateResponse(Response[bytes]): + def __init__( + self, + template_name: str, + *, + template_engine: "TemplateEngineProtocol", + context: Dict[str, Any], + status_code: int = HTTP_200_OK, + background: Optional[Union["BackgroundTask", "BackgroundTasks"]] = None, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional["ResponseCookies"] = None, + encoding: str = "utf-8", + ) -> None: + """Handles the rendering of a given template into a bytes string. + + Args: + template_name: Path-like name for the template to be rendered, e.g. "index.html". + template_engine: The template engine class to use to render the response. + status_code: A value for the response HTTP status code. + context: A dictionary of key/value pairs to be passed to the temple engine's render method. + background: A [BackgroundTask][starlite.datastructures.BackgroundTask] instance or + [BackgroundTasks][starlite.datastructures.BackgroundTasks] to execute after the response is finished. + Defaults to None. + headers: A string keyed dictionary of response headers. Header keys are insensitive. + cookies: A list of [Cookie][starlite.datastructures.Cookie] instances to be set under the response 'Set-Cookie' header. + """ + template = template_engine.get_template(template_name) + super().__init__( + background=background, + content=template.render(**context), + cookies=cookies, + encoding=encoding, + headers=headers, + media_type=MediaType.HTML, + status_code=status_code, + ) diff --git a/starlite/router.py b/starlite/router.py index 8001323b67..58cda3aded 100644 --- a/starlite/router.py +++ b/starlite/router.py @@ -171,7 +171,7 @@ def register(self, value: ControllerRouterHandler) -> List["BaseRoute"]: if existing_handlers: route_handlers.extend(unique(existing_handlers)) existing_route_index = find_index( - self.routes, lambda x: x.path == path # pylint: disable=cell-var-from-loop + self.routes, lambda x: x.path == path # pylint: disable=cell-var-from-loop # noqa: B023 ) if existing_route_index == -1: # pragma: no cover diff --git a/starlite/routes/asgi.py b/starlite/routes/asgi.py index 47684a9d21..0bd7d869a7 100644 --- a/starlite/routes/asgi.py +++ b/starlite/routes/asgi.py @@ -1,11 +1,10 @@ from typing import TYPE_CHECKING, Any, cast -from starlette.routing import get_name - from starlite.connection import ASGIConnection from starlite.controller import Controller from starlite.enums import ScopeType from starlite.routes.base import BaseRoute +from starlite.utils import get_name if TYPE_CHECKING: from starlite.handlers.asgi import ASGIRouteHandler diff --git a/starlite/routes/http.py b/starlite/routes/http.py index d2672ee134..8719bdadf3 100644 --- a/starlite/routes/http.py +++ b/starlite/routes/http.py @@ -4,24 +4,29 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from anyio.to_thread import run_sync -from starlette.responses import RedirectResponse -from starlette.routing import get_name from starlite.connection import Request from starlite.controller import Controller from starlite.enums import ScopeType from starlite.exceptions import ImproperlyConfiguredException +from starlite.response import RedirectResponse from starlite.routes.base import BaseRoute from starlite.signature import get_signature_model -from starlite.utils import is_async_callable +from starlite.utils import get_name, is_async_callable if TYPE_CHECKING: - from starlette.responses import Response as StarletteResponse - from starlite.handlers.http import HTTPRouteHandler from starlite.kwargs import KwargsModel from starlite.response import Response - from starlite.types import AnyCallable, HTTPScope, Method, Receive, Scope, Send + from starlite.types import ( + AnyCallable, + ASGIApp, + HTTPScope, + Method, + Receive, + Scope, + Send, + ) class HTTPRoute(BaseRoute): @@ -73,7 +78,7 @@ async def handle(self, scope: "HTTPScope", receive: "Receive", send: "Send") -> scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model ) - await response(scope, receive, send) # type: ignore[arg-type] + await response(scope, receive, send) after_response_handler = route_handler.resolve_after_response() if after_response_handler: await after_response_handler(request) # type: ignore @@ -96,7 +101,7 @@ async def _get_response_for_request( request: Request[Any, Any], route_handler: "HTTPRouteHandler", parameter_model: "KwargsModel", - ) -> "StarletteResponse": + ) -> "ASGIApp": """Handles creating a response instance and/or using cache. Args: @@ -106,9 +111,9 @@ async def _get_response_for_request( parameter_model: The Handler's KwargsModel Returns: - An instance of StarletteResponse or a subclass of it + An instance of Response or a compatible ASGIApp or a subclass of it """ - response: Optional["StarletteResponse"] = None + response: Optional["ASGIApp"] = None if route_handler.cache: response = await self._get_cached_response(request=request, route_handler=route_handler) @@ -127,7 +132,7 @@ async def _get_response_for_request( async def _call_handler_function( self, scope: "Scope", request: Request, parameter_model: "KwargsModel", route_handler: "HTTPRouteHandler" - ) -> "StarletteResponse": + ) -> "ASGIApp": """Calls the before request handlers, retrieves any data required for the route handler, and calls the route handler's to_response method. @@ -185,9 +190,7 @@ async def _get_response_data( return fn() @staticmethod - async def _get_cached_response( - request: Request, route_handler: "HTTPRouteHandler" - ) -> Optional["StarletteResponse"]: + async def _get_cached_response(request: Request, route_handler: "HTTPRouteHandler") -> Optional["ASGIApp"]: """Retrieves and un-pickles the cached response, if existing. Args: @@ -203,13 +206,13 @@ async def _get_cached_response( cached_response = await cache.get(key=cache_key) if cached_response: - return cast("StarletteResponse", pickle.loads(cached_response)) # nosec # noqa: SCS113 + return cast("ASGIApp", pickle.loads(cached_response)) # nosec # noqa: SCS113 return None @staticmethod async def _set_cached_response( - response: Union["Response", "StarletteResponse"], request: Request, route_handler: "HTTPRouteHandler" + response: Union["Response", "ASGIApp"], request: Request, route_handler: "HTTPRouteHandler" ) -> None: """Pickles and caches a response object.""" cache = request.app.cache diff --git a/starlite/routes/websocket.py b/starlite/routes/websocket.py index 3eb477369a..113dcbfe0a 100644 --- a/starlite/routes/websocket.py +++ b/starlite/routes/websocket.py @@ -1,12 +1,11 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, cast -from starlette.routing import get_name - from starlite.controller import Controller from starlite.enums import ScopeType from starlite.exceptions import ImproperlyConfiguredException from starlite.routes.base import BaseRoute from starlite.signature import get_signature_model +from starlite.utils import get_name if TYPE_CHECKING: from starlite.connection import WebSocket diff --git a/starlite/signature.py b/starlite/signature.py index f33fccd7aa..729db1b452 100644 --- a/starlite/signature.py +++ b/starlite/signature.py @@ -374,6 +374,6 @@ def get_signature_model(value: Any) -> Type[SignatureModel]: """Helper function to retrieve and validate the signature model from a provider or handler.""" try: - return cast("Type[SignatureModel]", getattr(value, "signature_model")) + return cast("Type[SignatureModel]", value.signature_model) except AttributeError as e: # pragma: no cover raise ImproperlyConfiguredException(f"The 'signature_model' attribute for {value} is not set") from e diff --git a/starlite/status_codes.py b/starlite/status_codes.py new file mode 100644 index 0000000000..2c9371d6dd --- /dev/null +++ b/starlite/status_codes.py @@ -0,0 +1,198 @@ +"""This file includes code adapted from +https://github.com/encode/starlette/blob/master/starlette/status.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from typing_extensions import Literal # noqa: TC002 + +# HTTP Status Codes + +HTTP_100_CONTINUE: Literal[100] = 100 +HTTP_101_SWITCHING_PROTOCOLS: Literal[101] = 101 +HTTP_102_PROCESSING: Literal[102] = 102 +HTTP_103_EARLY_HINTS: Literal[103] = 103 +HTTP_200_OK: Literal[200] = 200 +HTTP_201_CREATED: Literal[201] = 201 +HTTP_202_ACCEPTED: Literal[202] = 202 +HTTP_203_NON_AUTHORITATIVE_INFORMATION: Literal[203] = 203 +HTTP_204_NO_CONTENT: Literal[204] = 204 +HTTP_205_RESET_CONTENT: Literal[205] = 205 +HTTP_206_PARTIAL_CONTENT: Literal[206] = 206 +HTTP_207_MULTI_STATUS: Literal[207] = 207 +HTTP_208_ALREADY_REPORTED: Literal[208] = 208 +HTTP_226_IM_USED: Literal[226] = 226 +HTTP_300_MULTIPLE_CHOICES: Literal[300] = 300 +HTTP_301_MOVED_PERMANENTLY: Literal[301] = 301 +HTTP_302_FOUND: Literal[302] = 302 +HTTP_303_SEE_OTHER: Literal[303] = 303 +HTTP_304_NOT_MODIFIED: Literal[304] = 304 +HTTP_305_USE_PROXY: Literal[305] = 305 +HTTP_306_RESERVED: Literal[306] = 306 +HTTP_307_TEMPORARY_REDIRECT: Literal[307] = 307 +HTTP_308_PERMANENT_REDIRECT: Literal[308] = 308 +HTTP_400_BAD_REQUEST: Literal[400] = 400 +HTTP_401_UNAUTHORIZED: Literal[401] = 401 +HTTP_402_PAYMENT_REQUIRED: Literal[402] = 402 +HTTP_403_FORBIDDEN: Literal[403] = 403 +HTTP_404_NOT_FOUND: Literal[404] = 404 +HTTP_405_METHOD_NOT_ALLOWED: Literal[405] = 405 +HTTP_406_NOT_ACCEPTABLE: Literal[406] = 406 +HTTP_407_PROXY_AUTHENTICATION_REQUIRED: Literal[407] = 407 +HTTP_408_REQUEST_TIMEOUT: Literal[408] = 408 +HTTP_409_CONFLICT: Literal[409] = 409 +HTTP_410_GONE: Literal[410] = 410 +HTTP_411_LENGTH_REQUIRED: Literal[411] = 411 +HTTP_412_PRECONDITION_FAILED: Literal[412] = 412 +HTTP_413_REQUEST_ENTITY_TOO_LARGE: Literal[413] = 413 +HTTP_414_REQUEST_URI_TOO_LONG: Literal[414] = 414 +HTTP_415_UNSUPPORTED_MEDIA_TYPE: Literal[415] = 415 +HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE: Literal[416] = 416 +HTTP_417_EXPECTATION_FAILED: Literal[417] = 417 +HTTP_418_IM_A_TEAPOT: Literal[418] = 418 +HTTP_421_MISDIRECTED_REQUEST: Literal[421] = 421 +HTTP_422_UNPROCESSABLE_ENTITY: Literal[422] = 422 +HTTP_423_LOCKED: Literal[423] = 423 +HTTP_424_FAILED_DEPENDENCY: Literal[424] = 424 +HTTP_425_TOO_EARLY: Literal[425] = 425 +HTTP_426_UPGRADE_REQUIRED: Literal[426] = 426 +HTTP_428_PRECONDITION_REQUIRED: Literal[428] = 428 +HTTP_429_TOO_MANY_REQUESTS: Literal[429] = 429 +HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE: Literal[431] = 431 +HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS: Literal[451] = 451 +HTTP_500_INTERNAL_SERVER_ERROR: Literal[500] = 500 +HTTP_501_NOT_IMPLEMENTED: Literal[501] = 501 +HTTP_502_BAD_GATEWAY: Literal[502] = 502 +HTTP_503_SERVICE_UNAVAILABLE: Literal[503] = 503 +HTTP_504_GATEWAY_TIMEOUT: Literal[504] = 504 +HTTP_505_HTTP_VERSION_NOT_SUPPORTED: Literal[505] = 505 +HTTP_506_VARIANT_ALSO_NEGOTIATES: Literal[506] = 506 +HTTP_507_INSUFFICIENT_STORAGE: Literal[507] = 507 +HTTP_508_LOOP_DETECTED: Literal[508] = 508 +HTTP_510_NOT_EXTENDED: Literal[510] = 510 +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED: Literal[511] = 511 + +# Websocket Codes +WS_1000_NORMAL_CLOSURE: Literal[1000] = 1000 +WS_1001_GOING_AWAY: Literal[1001] = 1001 +WS_1002_PROTOCOL_ERROR: Literal[1002] = 1002 +WS_1003_UNSUPPORTED_DATA: Literal[1003] = 1003 +WS_1005_NO_STATUS_RECEIVED: Literal[1005] = 1005 +WS_1006_ABNORMAL_CLOSURE: Literal[1006] = 1006 +WS_1007_INVALID_FRAME_PAYLOAD_DATA: Literal[1007] = 1007 +WS_1008_POLICY_VIOLATION: Literal[1008] = 1008 +WS_1009_MESSAGE_TOO_BIG: Literal[1009] = 1009 +WS_1010_MANDATORY_EXT: Literal[1010] = 1010 +WS_1011_INTERNAL_ERROR: Literal[1011] = 1011 +WS_1012_SERVICE_RESTART: Literal[1012] = 1012 +WS_1013_TRY_AGAIN_LATER: Literal[1013] = 1013 +WS_1014_BAD_GATEWAY: Literal[1014] = 1014 +WS_1015_TLS_HANDSHAKE: Literal[1015] = 1015 + + +__all__ = [ + "HTTP_100_CONTINUE", + "HTTP_101_SWITCHING_PROTOCOLS", + "HTTP_102_PROCESSING", + "HTTP_103_EARLY_HINTS", + "HTTP_200_OK", + "HTTP_201_CREATED", + "HTTP_202_ACCEPTED", + "HTTP_203_NON_AUTHORITATIVE_INFORMATION", + "HTTP_204_NO_CONTENT", + "HTTP_205_RESET_CONTENT", + "HTTP_206_PARTIAL_CONTENT", + "HTTP_207_MULTI_STATUS", + "HTTP_208_ALREADY_REPORTED", + "HTTP_226_IM_USED", + "HTTP_300_MULTIPLE_CHOICES", + "HTTP_301_MOVED_PERMANENTLY", + "HTTP_302_FOUND", + "HTTP_303_SEE_OTHER", + "HTTP_304_NOT_MODIFIED", + "HTTP_305_USE_PROXY", + "HTTP_306_RESERVED", + "HTTP_307_TEMPORARY_REDIRECT", + "HTTP_308_PERMANENT_REDIRECT", + "HTTP_400_BAD_REQUEST", + "HTTP_401_UNAUTHORIZED", + "HTTP_402_PAYMENT_REQUIRED", + "HTTP_403_FORBIDDEN", + "HTTP_404_NOT_FOUND", + "HTTP_405_METHOD_NOT_ALLOWED", + "HTTP_406_NOT_ACCEPTABLE", + "HTTP_407_PROXY_AUTHENTICATION_REQUIRED", + "HTTP_408_REQUEST_TIMEOUT", + "HTTP_409_CONFLICT", + "HTTP_410_GONE", + "HTTP_411_LENGTH_REQUIRED", + "HTTP_412_PRECONDITION_FAILED", + "HTTP_413_REQUEST_ENTITY_TOO_LARGE", + "HTTP_414_REQUEST_URI_TOO_LONG", + "HTTP_415_UNSUPPORTED_MEDIA_TYPE", + "HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE", + "HTTP_417_EXPECTATION_FAILED", + "HTTP_418_IM_A_TEAPOT", + "HTTP_421_MISDIRECTED_REQUEST", + "HTTP_422_UNPROCESSABLE_ENTITY", + "HTTP_423_LOCKED", + "HTTP_424_FAILED_DEPENDENCY", + "HTTP_425_TOO_EARLY", + "HTTP_426_UPGRADE_REQUIRED", + "HTTP_428_PRECONDITION_REQUIRED", + "HTTP_429_TOO_MANY_REQUESTS", + "HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE", + "HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS", + "HTTP_500_INTERNAL_SERVER_ERROR", + "HTTP_501_NOT_IMPLEMENTED", + "HTTP_502_BAD_GATEWAY", + "HTTP_503_SERVICE_UNAVAILABLE", + "HTTP_504_GATEWAY_TIMEOUT", + "HTTP_505_HTTP_VERSION_NOT_SUPPORTED", + "HTTP_506_VARIANT_ALSO_NEGOTIATES", + "HTTP_507_INSUFFICIENT_STORAGE", + "HTTP_508_LOOP_DETECTED", + "HTTP_510_NOT_EXTENDED", + "HTTP_511_NETWORK_AUTHENTICATION_REQUIRED", + "WS_1000_NORMAL_CLOSURE", + "WS_1001_GOING_AWAY", + "WS_1002_PROTOCOL_ERROR", + "WS_1003_UNSUPPORTED_DATA", + "WS_1005_NO_STATUS_RECEIVED", + "WS_1006_ABNORMAL_CLOSURE", + "WS_1007_INVALID_FRAME_PAYLOAD_DATA", + "WS_1008_POLICY_VIOLATION", + "WS_1009_MESSAGE_TOO_BIG", + "WS_1010_MANDATORY_EXT", + "WS_1011_INTERNAL_ERROR", + "WS_1012_SERVICE_RESTART", + "WS_1013_TRY_AGAIN_LATER", + "WS_1014_BAD_GATEWAY", + "WS_1015_TLS_HANDSHAKE", +] diff --git a/starlite/template/base.py b/starlite/template/base.py index 1b9faefeaf..0d666d35b8 100644 --- a/starlite/template/base.py +++ b/starlite/template/base.py @@ -4,7 +4,7 @@ from typing_extensions import Protocol, TypedDict, runtime_checkable if TYPE_CHECKING: - from starlite import Request + from starlite.connection import Request class TemplateContext(TypedDict): @@ -52,6 +52,7 @@ def url_for_static_asset(context: TemplateContext, name: str, file_path: str) -> be used in templates. Args: + context: The template context object. name: A static handler unique name. file_path: a string containing path to an asset. diff --git a/starlite/testing/request_factory.py b/starlite/testing/request_factory.py index 50bf3dcf06..b9d04be177 100644 --- a/starlite/testing/request_factory.py +++ b/starlite/testing/request_factory.py @@ -34,10 +34,13 @@ def _default_route_handler() -> None: ... +default_app = Starlite(route_handlers=[_default_route_handler]) + + class RequestFactory: def __init__( self, - app: Starlite = Starlite(route_handlers=[_default_route_handler]), + app: Starlite = default_app, server: str = "test.org", port: int = 3000, root_path: str = "", @@ -265,7 +268,7 @@ def _create_request_with_data( if request_media_type == RequestEncodingType.JSON: encoding_headers, stream = encode_json(data) elif request_media_type == RequestEncodingType.MULTI_PART: - encoding_headers, stream = encode_multipart_data(data, files=files or []) + encoding_headers, stream = encode_multipart_data(data, files=files or []) # type: ignore[assignment] else: encoding_headers, stream = encode_urlencoded_data(loads(dumps(data, default=default_serializer))) headers.update(encoding_headers) diff --git a/starlite/testing/test_client.py b/starlite/testing/test_client.py index 1f812d180b..403f0a25d1 100644 --- a/starlite/testing/test_client.py +++ b/starlite/testing/test_client.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: from typing_extensions import Literal - from starlite import Request, WebSocket from starlite.config import ( BaseLoggingConfig, CacheConfig, @@ -18,12 +17,14 @@ StaticFilesConfig, TemplateConfig, ) + from starlite.connection import Request, WebSocket from starlite.middleware.session import SessionCookieConfig from starlite.plugins.base import PluginProtocol from starlite.types import ( AfterExceptionHookHandler, AfterRequestHookHandler, AfterResponseHookHandler, + ASGIApp, BeforeMessageSendHookHandler, BeforeRequestHookHandler, ControllerRouterHandler, @@ -54,7 +55,7 @@ class TestClient(StarletteTestClient): def __init__( self, - app: Starlite, + app: Union[Starlite, "ASGIApp"], base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", @@ -96,7 +97,7 @@ def __enter__(self) -> "TestClient": Returns: TestClient """ - return super().__enter__() # pyright: ignore + return cast("TestClient", super().__enter__()) def create_session_cookies(self, session_data: Dict[str, Any]) -> Dict[str, str]: """Creates raw session cookies that are loaded into the session by the diff --git a/starlite/types/callable_types.py b/starlite/types/callable_types.py index ef6d0276d0..f3b8b9f72c 100644 --- a/starlite/types/callable_types.py +++ b/starlite/types/callable_types.py @@ -1,12 +1,10 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, TypeVar, Union -from .asgi_types import Message, Scope +from .asgi_types import ASGIApp, Message, Scope from .helpers import SyncOrAsyncUnion from .internal_types import StarliteType if TYPE_CHECKING: - from starlette.responses import Response as StarletteResponse # noqa: TC004 - from starlite.config import AppConfig # noqa: TC004 from starlite.connection import Request, WebSocket # noqa: TC004 from starlite.datastructures.state import State # noqa: TC004 @@ -18,7 +16,6 @@ HTTPRouteHandler = Any Request = Any Response = Any - StarletteResponse = Any State = Any WebSocket = Any WebsocketRouteHandler = Any @@ -28,7 +25,7 @@ AfterExceptionHookHandler = Callable[[Exception, Scope, State], SyncOrAsyncUnion[None]] AfterRequestHookHandler = Union[ - Callable[[StarletteResponse], SyncOrAsyncUnion[StarletteResponse]], Callable[[Response], SyncOrAsyncUnion[Response]] + Callable[[ASGIApp], SyncOrAsyncUnion[ASGIApp]], Callable[[Response], SyncOrAsyncUnion[Response]] ] AfterResponseHookHandler = Callable[[Request], SyncOrAsyncUnion[None]] AsyncAnyCallable = Callable[..., Awaitable[Any]] @@ -38,7 +35,7 @@ ] BeforeRequestHookHandler = Callable[[Request], Union[Any, Awaitable[Any]]] CacheKeyBuilder = Callable[[Request], str] -ExceptionHandler = Callable[[Request, _ExceptionT], StarletteResponse] +ExceptionHandler = Callable[[Request, _ExceptionT], Response] Guard = Union[ Callable[[Request, HTTPRouteHandler], SyncOrAsyncUnion[None]], Callable[[WebSocket, WebsocketRouteHandler], SyncOrAsyncUnion[None]], diff --git a/starlite/types/composite.py b/starlite/types/composite.py index 065e4c01b2..57afef5cdc 100644 --- a/starlite/types/composite.py +++ b/starlite/types/composite.py @@ -1,4 +1,17 @@ -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Callable, + Dict, + Iterable, + Iterator, + List, + Type, + TypeVar, + Union, +) from .asgi_types import ASGIApp from .callable_types import ExceptionHandler @@ -26,6 +39,8 @@ ResponseHeader = Any StarletteMiddleware = Any +T = TypeVar("T") + Dependencies = Dict[str, Provide] ExceptionHandlersMap = Dict[Union[int, Type[Exception]], ExceptionHandler] @@ -36,3 +51,4 @@ ParametersMap = Dict[str, FieldInfo] ResponseCookies = List[Cookie] ResponseHeadersMap = Dict[str, ResponseHeader] +StreamType = Union[Iterable[T], Iterator[T], AsyncIterable[T], AsyncIterator[T]] diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index d375cf4222..6d254b30a9 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -6,6 +6,7 @@ get_exception_handler, ) from .extractors import ConnectionDataExtractor, ResponseDataExtractor, obfuscate +from .helpers import get_enum_string_value, get_name from .model import ( convert_dataclass_to_model, convert_typeddict_to_model, @@ -13,7 +14,6 @@ ) from .path import join_paths, normalize_path from .predicates import ( - is_async_callable, is_class_and_subclass, is_dataclass_class_or_instance_typeguard, is_dataclass_class_typeguard, @@ -23,10 +23,17 @@ from .scope import get_serializer_from_scope from .sequence import find_index, unique from .serialization import default_serializer -from .sync import AsyncCallable, as_async_callable_list, async_partial +from .sync import ( + AsyncCallable, + AsyncIteratorWrapper, + as_async_callable_list, + async_partial, + is_async_callable, +) __all__ = ( "AsyncCallable", + "AsyncIteratorWrapper", "ConnectionDataExtractor", "ExceptionResponseContent", "ResponseDataExtractor", @@ -38,7 +45,11 @@ "create_parsed_model_field", "default_serializer", "find_index", + "generate_csrf_hash", + "generate_csrf_token", + "get_enum_string_value", "get_exception_handler", + "get_name", "get_serializer_from_scope", "is_async_callable", "is_class_and_subclass", @@ -52,6 +63,4 @@ "obfuscate", "should_skip_dependency_validation", "unique", - "generate_csrf_hash", - "generate_csrf_token", ) diff --git a/starlite/utils/exception.py b/starlite/utils/exception.py index 676bc333f4..34183e2503 100644 --- a/starlite/utils/exception.py +++ b/starlite/utils/exception.py @@ -2,16 +2,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from pydantic import BaseModel -from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlite.enums import MediaType -from starlite.exceptions.http_exceptions import HTTPException -from starlite.response import Response +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR if TYPE_CHECKING: from typing import Type + from starlite.response import Response from starlite.types import ExceptionHandler, ExceptionHandlersMap @@ -36,15 +34,13 @@ def get_exception_handler(exception_handlers: "ExceptionHandlersMap", exc: Excep """ if not exception_handlers: return None - if isinstance(exc, (StarletteHTTPException, HTTPException)) and exc.status_code in exception_handlers: - return exception_handlers[exc.status_code] + status_code: Optional[int] = getattr(exc, "status_code", None) + if status_code in exception_handlers: + return exception_handlers[status_code] for cls in getmro(type(exc)): if cls in exception_handlers: return exception_handlers[cast("Type[Exception]", cls)] - if ( - not isinstance(exc, (StarletteHTTPException, HTTPException)) - and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers - ): + if not hasattr(exc, "status_code") and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers: return exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR] return None @@ -59,8 +55,25 @@ class ExceptionResponseContent(BaseModel): extra: Optional[Union[Dict[str, Any], List[Any]]] = None """An extra mapping to attach to the exception.""" + def to_response(self) -> "Response": + """Creates a response from the model attributes. -def create_exception_response(exc: Exception) -> Response: + Returns: + A response instance. + """ + from starlite.response import ( # pylint: disable=import-outside-toplevel + Response, + ) + + return Response( + content=self.dict(exclude_none=True, exclude={"headers"}), + headers=self.headers, + media_type=MediaType.JSON, + status_code=self.status_code, + ) + + +def create_exception_response(exc: Exception) -> "Response": """Constructs a response from an exception. For instances of either `starlite.exceptions.HTTPException` or `starlette.exceptions.HTTPException` the response @@ -72,15 +85,10 @@ def create_exception_response(exc: Exception) -> Response: Returns: Response: HTTP response constructed from exception details. """ - if isinstance(exc, (HTTPException, StarletteHTTPException)): - content = ExceptionResponseContent(detail=exc.detail, status_code=exc.status_code) - if isinstance(exc, HTTPException): - content.extra = exc.extra - else: - content = ExceptionResponseContent(detail=repr(exc), status_code=HTTP_500_INTERNAL_SERVER_ERROR) - return Response( - media_type=MediaType.JSON, - content=content.dict(exclude_none=True), - status_code=content.status_code, - headers=exc.headers if isinstance(exc, (HTTPException, StarletteHTTPException)) else None, + content = ExceptionResponseContent( + status_code=getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR), + detail=getattr(exc, "detail", repr(exc)), + headers=getattr(exc, "headers", None), + extra=getattr(exc, "extra", None), ) + return content.to_response() diff --git a/starlite/utils/extractors.py b/starlite/utils/extractors.py index 3a66d61a9f..eb5fc32187 100644 --- a/starlite/utils/extractors.py +++ b/starlite/utils/extractors.py @@ -11,12 +11,12 @@ cast, ) -from starlette.requests import cookie_parser from typing_extensions import Literal, TypedDict from starlite.connection import Request from starlite.datastructures.upload_file import UploadFile from starlite.enums import HttpMethod, RequestEncodingType +from starlite.parsers import parse_cookie_string if TYPE_CHECKING: from starlite.connection import ASGIConnection @@ -402,7 +402,7 @@ def extract_cookies(self, messages: Tuple["HTTPResponseStartEvent", "HTTPRespons The Response's cookies dict. """ cookie_string = ";".join( - list( + list( # noqa: C417 map( lambda x: x[1].decode("latin-1"), filter(lambda x: x[0].lower() == b"set-cookie", messages[0]["headers"]), @@ -410,6 +410,6 @@ def extract_cookies(self, messages: Tuple["HTTPResponseStartEvent", "HTTPRespons ) ) if cookie_string: - parsed_cookies = cookie_parser(cookie_string) + parsed_cookies = parse_cookie_string(cookie_string) return obfuscate(parsed_cookies, self.obfuscate_cookies) if self.obfuscate_cookies else parsed_cookies return {} diff --git a/starlite/utils/helpers.py b/starlite/utils/helpers.py new file mode 100644 index 0000000000..55feb5b464 --- /dev/null +++ b/starlite/utils/helpers.py @@ -0,0 +1,31 @@ +from enum import Enum +from typing import Any, Union, cast + + +def get_name(value: Any) -> str: + """Helper to get the '__name__' dunder of a value. + + Args: + value: An arbitrary value. + + Returns: + A name string. + """ + + if hasattr(value, "__name__"): + return cast("str", value.__name__) + return type(value).__name__ + + +def get_enum_string_value(value: Union[Enum, str]) -> str: + """A helper function to return the string value of a string enum. + + See: https://github.com/starlite-api/starlite/pull/633#issuecomment-1286519267 + + Args: + value: An enum or string. + + Returns: + A string. + """ + return cast("str", value.value) if isinstance(value, Enum) else value diff --git a/starlite/utils/predicates.py b/starlite/utils/predicates.py index 49d7895bd2..e47451384c 100644 --- a/starlite/utils/predicates.py +++ b/starlite/utils/predicates.py @@ -1,9 +1,7 @@ -import asyncio -import functools import sys from dataclasses import is_dataclass from inspect import isclass -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Type, TypeVar, Union from typing_extensions import ParamSpec, TypeGuard, get_args, get_origin, is_typeddict @@ -28,22 +26,6 @@ T = TypeVar("T") -def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[T]]]: - """Extends `asyncio.iscoroutinefunction()` to additionally detect async - `partial` objects and class instances with `async def __call__()` defined. - - Args: - value: Any - - Returns: - Bool determining if type of `value` is an awaitable. - """ - while isinstance(value, functools.partial): - value = value.func # type: ignore[unreachable] - - return asyncio.iscoroutinefunction(value) or (callable(value) and asyncio.iscoroutinefunction(value.__call__)) # type: ignore[operator] - - def is_class_and_subclass(value: Any, t_type: Type[T]) -> TypeGuard[Type[T]]: """Return `True` if `value` is a `class` and is a subtype of `t_type`. diff --git a/starlite/utils/sync.py b/starlite/utils/sync.py index f4e8827445..7282859ae8 100644 --- a/starlite/utils/sync.py +++ b/starlite/utils/sync.py @@ -1,16 +1,44 @@ +from asyncio import iscoroutinefunction from functools import partial from inspect import getfullargspec, ismethod -from typing import Any, Callable, Dict, Generic, List, TypeVar, Union, cast +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + TypeVar, + Union, + cast, +) from anyio.to_thread import run_sync -from typing_extensions import Literal, ParamSpec - -from starlite.utils.predicates import is_async_callable +from typing_extensions import Literal, ParamSpec, TypeGuard P = ParamSpec("P") T = TypeVar("T") +def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[T]]]: + """Extends `asyncio.iscoroutinefunction()` to additionally detect async + `partial` objects and class instances with `async def __call__()` defined. + + Args: + value: Any + + Returns: + Bool determining if type of `value` is an awaitable. + """ + while isinstance(value, partial): + value = value.func # type: ignore[unreachable] + + return iscoroutinefunction(value) or (callable(value) and iscoroutinefunction(value.__call__)) # type: ignore[operator] + + class AsyncCallable(Generic[P, T]): __slots__ = ("args", "kwargs", "wrapped_callable", "is_method", "num_expected_args") @@ -22,10 +50,10 @@ def __init__(self, fn: Callable[P, T]) -> None: fn: Callable to wrap - can be any sync or async callable. """ - self.is_method = ismethod(fn) + self.is_method = ismethod(fn) or (callable(fn) and ismethod(fn.__call__)) # type: ignore self.num_expected_args = len(getfullargspec(fn).args) - (1 if self.is_method else 0) self.wrapped_callable: Dict[Literal["fn"], Callable] = { - "fn": fn if is_async_callable(fn) else async_partial(fn) + "fn": fn if is_async_callable(fn) else async_partial(fn) # pyright: ignore } async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: @@ -71,3 +99,34 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return await run_sync(applied_kwarg, *args) return wrapper + + +class AsyncIteratorWrapper(Generic[T]): + def __init__(self, iterator: Union[Iterator[T], Iterable[T]]) -> None: + """Take a sync iterator or iterable and yields values from it + asynchronously. + + Args: + iterator: A sync iterator or iterable. + """ + self.iterator = iterator if isinstance(iterator, Iterator) else iter(iterator) + self.generator = self._async_generator() + + def _call_next(self) -> T: + try: + return next(self.iterator) + except StopIteration as e: + raise ValueError from e + + async def _async_generator(self) -> AsyncGenerator[T, None]: + while True: + try: + yield await run_sync(self._call_next) + except ValueError: + return + + def __aiter__(self) -> "AsyncIteratorWrapper[T]": + return self + + async def __anext__(self) -> T: + return await self.generator.__anext__() diff --git a/tests/__init__.py b/tests/__init__.py index bbacda1ae9..1097a41b47 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -72,7 +72,7 @@ class TypedDictPerson(TypedDict): pets: Optional[List[Pet]] -class Car(Base): +class Car(Base): # pyright: ignore __tablename__ = "db_cars" id = Column(INTEGER, primary_key=True) diff --git a/tests/app/test_app_config.py b/tests/app/test_app_config.py index 073cae177e..bb04b7b7db 100644 --- a/tests/app/test_app_config.py +++ b/tests/app/test_app_config.py @@ -48,6 +48,7 @@ def app_config_object() -> AppConfig: template_config=None, request_class=None, websocket_class=None, + etag=None, ) diff --git a/tests/app/test_before_send.py b/tests/app/test_before_send.py index 2ddcd2f231..0f83d53983 100644 --- a/tests/app/test_before_send.py +++ b/tests/app/test_before_send.py @@ -1,9 +1,9 @@ from typing import TYPE_CHECKING, Dict from starlette.datastructures import MutableHeaders -from starlette.status import HTTP_200_OK from starlite import get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/app/test_error_handling.py b/tests/app/test_error_handling.py index cf3d37cbd1..aa40eac698 100644 --- a/tests/app/test_error_handling.py +++ b/tests/app/test_error_handling.py @@ -1,7 +1,6 @@ -from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR - from starlite import MediaType, Request, Response, Starlite, get, post from starlite.exceptions import InternalServerException +from starlite.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from starlite.testing import TestClient, create_test_client from tests import Person diff --git a/tests/asgi_router/test_asgi_router.py b/tests/asgi_router/test_asgi_router.py new file mode 100644 index 0000000000..9e59f57a03 --- /dev/null +++ b/tests/asgi_router/test_asgi_router.py @@ -0,0 +1,39 @@ +import pytest + +from starlite.testing import create_test_client + + +class _LifeSpanCallable: + def __init__(self, should_raise: bool = False) -> None: + self.called = False + self.should_raise = should_raise + + def __call__(self) -> None: + self.called = True + if self.should_raise: + raise RuntimeError("damn") + + +def test_life_span_startup() -> None: + life_span_callable = _LifeSpanCallable() + with create_test_client([], on_startup=[life_span_callable]): + assert life_span_callable.called + + +def test_life_span_startup_error_handling() -> None: + life_span_callable = _LifeSpanCallable(should_raise=True) + with pytest.raises(RuntimeError), create_test_client([], on_startup=[life_span_callable]): + pass + + +def test_life_span_shutdown() -> None: + life_span_callable = _LifeSpanCallable() + with create_test_client([], on_shutdown=[life_span_callable]): + pass + assert life_span_callable.called + + +def test_life_span_shutdown_error_handling() -> None: + life_span_callable = _LifeSpanCallable(should_raise=True) + with pytest.raises(RuntimeError), create_test_client([], on_shutdown=[life_span_callable]): + pass diff --git a/tests/connection/request/test_request.py b/tests/connection/request/test_request.py index e7613fb278..0f01bc1ed3 100644 --- a/tests/connection/request/test_request.py +++ b/tests/connection/request/test_request.py @@ -1,19 +1,31 @@ +"""A large part of the tests in this file were adapted from: + +https://github.com/encode/starlette/blob/master/tests/test_requests.py. And are +meant to ensure our compatibility with their API. +""" + import sys -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from unittest.mock import patch import pytest from orjson import JSONDecodeError +from starlette.datastructures import Address -from starlite import StaticFilesConfig, get -from starlite.connection import Request -from starlite.testing import create_test_client +from starlite import InternalServerException, MediaType, StaticFilesConfig, get +from starlite.connection import Request, empty_send +from starlite.datastructures import State +from starlite.response import Response +from starlite.status_codes import HTTP_200_OK +from starlite.testing import TestClient, create_test_client if TYPE_CHECKING: from pathlib import Path + from starlite.types import Receive, Scope, Send + -@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") # type: ignore[misc] +@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") async def test_request_empty_body_to_json(anyio_backend: str) -> None: with patch.object(Request, "body", return_value=b""): request_empty_payload: Request = Request(scope={"type": "http"}) # type: ignore @@ -21,14 +33,14 @@ async def test_request_empty_body_to_json(anyio_backend: str) -> None: assert request_json is None -@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") # type: ignore[misc] +@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") async def test_request_invalid_body_to_json(anyio_backend: str) -> None: with patch.object(Request, "body", return_value=b"invalid"), pytest.raises(JSONDecodeError): request_empty_payload: Request = Request(scope={"type": "http"}) # type: ignore await request_empty_payload.json() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") # type: ignore[misc] +@pytest.mark.skipif(sys.version_info < (3, 8), reason="skipping due to python 3.7 async failures") async def test_request_valid_body_to_json(anyio_backend: str) -> None: with patch.object(Request, "body", return_value=b'{"test": "valid"}'): request_empty_payload: Request = Request(scope={"type": "http"}) # type: ignore @@ -104,3 +116,341 @@ def handler(request: MyRequest) -> None: with create_test_client(route_handlers=[handler], request_class=MyRequest) as client: client.get("/") assert value["called"] + + +def test_request_url() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + data = {"method": request.method, "url": str(request.url)} + response = Response(content=data, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/123?a=abc") + assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} + + response = client.get("https://example.org:123/") + assert response.json() == {"method": "GET", "url": "https://example.org:123/"} + + +def test_request_query_params() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + params = dict(request.query_params) + response = Response(content={"params": params}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/?a=123&b=456") + assert response.json() == {"params": {"a": ["123"], "b": ["456"]}} + + +def test_request_headers() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + headers = dict(request.headers) + response = Response(content={"headers": headers}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/", headers={"host": "example.org"}) + assert response.json() == { + "headers": { + "host": "example.org", + "user-agent": "testclient", + "accept-encoding": "gzip, deflate, br", + "accept": "*/*", + "connection": "keep-alive", + } + } + + +@pytest.mark.parametrize( + "scope,expected_client", + ( + ({"type": "http", "client": ["client", 42]}, Address("client", 42)), + ({"type": "http", "client": None}, None), + ({"type": "http"}, None), + ), +) +def test_request_client(scope: "Scope", expected_client: Optional[Address]) -> None: + client = Request[Any, Any](scope).client + assert client == expected_client + + +def test_request_body() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + body = await request.body() + response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.get("/") + assert response.json() == {"body": ""} + + response = client.post("/", json={"a": "123"}) + assert response.json() == {"body": '{"a": "123"}'} + + response = client.post("/", content="abc") + assert response.json() == {"body": "abc"} + + +def test_request_stream() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + body = b"" + async for chunk in request.stream(): + body += chunk + response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.get("/") + assert response.json() == {"body": ""} + + response = client.post("/", json={"a": "123"}) + assert response.json() == {"body": '{"a": "123"}'} + + response = client.post("/", content="abc") + assert response.json() == {"body": "abc"} + + +def test_request_form_urlencoded() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + form = await request.form() + response = Response(content={"form": dict(form)}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.post("/", data={"abc": "123 @"}) + assert response.json() == {"form": {"abc": "123 @"}} + + +def test_request_body_then_stream() -> None: + async def app(scope: "Any", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + body = await request.body() + chunks = b"" + async for chunk in request.stream(): + chunks += chunk + response = Response( + content={"body": body.decode(), "stream": chunks.decode()}, + status_code=HTTP_200_OK, + media_type=MediaType.JSON, + ) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.post("/", content="abc") + assert response.json() == {"body": "abc", "stream": "abc"} + + +def test_request_stream_then_body() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + chunks = b"" + async for chunk in request.stream(): + chunks += chunk + try: + body = await request.body() + except InternalServerException: + body = b"" + response = Response( + content={"body": body.decode(), "stream": chunks.decode()}, + status_code=HTTP_200_OK, + media_type=MediaType.JSON, + ) + await response(scope, receive, send) + + client = TestClient(app) + + response = client.post("/", content="abc") + assert response.json() == {"body": "", "stream": "abc"} + + +def test_request_json() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + data = await request.json() + response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.post("/", json={"a": "123"}) + assert response.json() == {"json": {"a": "123"}} + + +def test_request_raw_path() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + path = str(request.scope["path"]) + raw_path = str(request.scope["raw_path"]) + response = Response(content=f"{path}, {raw_path}", status_code=HTTP_200_OK, media_type=MediaType.TEXT) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/he%2Fllo") + assert response.text == "/he/llo, b'/he%2Fllo'" + + +def test_request_without_setting_receive() -> None: + """If Request is instantiated without the 'receive' channel, then .body() + is not available.""" + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope) + try: + data = await request.json() + except RuntimeError: + data = "Receive channel not available" + response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.post("/", json={"a": "123"}) + assert response.json() == {"json": "Receive channel not available"} + + +async def test_request_disconnect() -> None: + """If a client disconnect occurs while reading request body then + InternalServerException should be raised.""" + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + await request.body() + + async def receiver() -> dict: + return {"type": "http.disconnect"} + + with pytest.raises(InternalServerException): + await app({"type": "http", "method": "POST", "path": "/"}, receiver, empty_send) # type: ignore + + +def test_request_state_object() -> None: + scope = {"state": {"old": "foo"}} + + s = State(scope["state"]) + + s.new = "value" + assert s.new == "value" + + del s.new + + with pytest.raises(AttributeError): + s.new + + +def test_request_state() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + scope["state"] = {} + request = Request[Any, Any](scope, receive) + request.state.example = 123 + response = Response( + content={"state.example": request.state.example}, status_code=HTTP_200_OK, media_type=MediaType.JSON + ) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/123?a=abc") + assert response.json() == {"state.example": 123} + + +def test_request_cookies() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + + request = Request[Any, Any](scope, receive) + mycookie = request.cookies.get("mycookie") + if mycookie: + response = Response(content=mycookie, media_type="text/plain", status_code=HTTP_200_OK) + else: + response = Response(content="Hello, world!", media_type=MediaType.TEXT, status_code=HTTP_200_OK) + response.set_cookie("mycookie", "Hello, cookies!") + + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "Hello, world!" + response = client.get("/") + assert response.text == "Hello, cookies!" + + +def test_chunked_encoding() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope, receive) + body = await request.body() + response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + + def post_body() -> Generator[bytes, None, None]: + yield b"foo" + yield b"bar" + + response = client.post("/", content=post_body()) + assert response.json() == {"body": "foobar"} + + +def test_request_send_push_promise() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + # the server is push-enabled + scope["extensions"]["http.response.push"] = {} # type: ignore + + request = Request[Any, Any](scope, receive, send) + await request.send_push_promise("/style.css") + + response = Response(content={"json": "OK"}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.json() == {"json": "OK"} + + +def test_request_send_push_promise_without_push_extension() -> None: + """If server does not support the `http.response.push` extension, + + .send_push_promise() does nothing. + """ + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + request = Request[Any, Any](scope) + await request.send_push_promise("/style.css") + + response = Response(content={"json": "OK"}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.json() == {"json": "OK"} + + +def test_request_send_push_promise_without_setting_send() -> None: + """If Request is instantiated without the send channel, then. + + .send_push_promise() is not available. + """ + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + # the server is push-enabled + scope["extensions"]["http.response.push"] = {} # type: ignore + + data = "OK" + request = Request[Any, Any](scope) + try: + await request.send_push_promise("/style.css") + except RuntimeError: + data = "Send channel not available" + response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.json() == {"json": "Send channel not available"} diff --git a/tests/connection/request/test_starlette_tests.py b/tests/connection/request/test_starlette_tests.py deleted file mode 100644 index e6115c7357..0000000000 --- a/tests/connection/request/test_starlette_tests.py +++ /dev/null @@ -1,357 +0,0 @@ -"""The tests in this file were adapted from: - -https://github.com/encode/starlette/blob/master/tests/test_requests.py. -""" - -from typing import TYPE_CHECKING, Any, Optional - -import pytest -from starlette.datastructures import Address, State -from starlette.status import HTTP_200_OK -from starlette.testclient import TestClient - -from starlite import InternalServerException, MediaType -from starlite.connection import Request, empty_send -from starlite.response import Response - -if TYPE_CHECKING: - from starlite.types import Receive, Send - - -def test_request_url() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - data = {"method": request.method, "url": str(request.url)} - response = Response(content=data, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/123?a=abc") - assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} - - response = client.get("https://example.org:123/") - assert response.json() == {"method": "GET", "url": "https://example.org:123/"} - - -def test_request_query_params() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - params = dict(request.query_params) - response = Response(content={"params": params}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/?a=123&b=456") - assert response.json() == {"params": {"a": ["123"], "b": ["456"]}} - - -def test_request_headers() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - headers = dict(request.headers) - response = Response(content={"headers": headers}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/", headers={"host": "example.org"}) - assert response.json() == { - "headers": { - "host": "example.org", - "user-agent": "testclient", - "accept-encoding": "gzip, deflate, br", - "accept": "*/*", - "connection": "keep-alive", - } - } - - -@pytest.mark.parametrize( - "scope,expected_client", - [ - ({"client": ["client", 42]}, Address("client", 42)), - ({"client": None}, None), - ({}, None), - ], -) -def test_request_client(scope: Any, expected_client: Optional[Address]) -> None: - scope.update({"type": "http"}) # required by Request's constructor - client = Request(scope).client - assert client == expected_client - - -def test_request_body() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - body = await request.body() - response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - response = client.get("/") - assert response.json() == {"body": ""} - - response = client.post("/", json={"a": "123"}) - assert response.json() == {"body": '{"a": "123"}'} - - response = client.post("/", content="abc") - assert response.json() == {"body": "abc"} - - -def test_request_stream() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - body = b"" - async for chunk in request.stream(): - body += chunk - response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - response = client.get("/") - assert response.json() == {"body": ""} - - response = client.post("/", json={"a": "123"}) - assert response.json() == {"body": '{"a": "123"}'} - - response = client.post("/", content="abc") - assert response.json() == {"body": "abc"} - - -def test_request_form_urlencoded() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - form = await request.form() - response = Response(content={"form": dict(form)}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - response = client.post("/", data={"abc": "123 @"}) - assert response.json() == {"form": {"abc": "123 @"}} - - -def test_request_body_then_stream() -> None: - async def app(scope: "Any", receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - body = await request.body() - chunks = b"" - async for chunk in request.stream(): - chunks += chunk - response = Response( - content={"body": body.decode(), "stream": chunks.decode()}, - status_code=HTTP_200_OK, - media_type=MediaType.JSON, - ) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - response = client.post("/", content="abc") - assert response.json() == {"body": "abc", "stream": "abc"} - - -def test_request_stream_then_body() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - chunks = b"" - async for chunk in request.stream(): - chunks += chunk - try: - body = await request.body() - except InternalServerException: - body = b"" - response = Response( - content={"body": body.decode(), "stream": chunks.decode()}, - status_code=HTTP_200_OK, - media_type=MediaType.JSON, - ) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - response = client.post("/", content="abc") - assert response.json() == {"body": "", "stream": "abc"} - - -def test_request_json() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - data = await request.json() - response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.post("/", json={"a": "123"}) - assert response.json() == {"json": {"a": "123"}} - - -def test_request_raw_path() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - path = request.scope["path"] - raw_path = request.scope["raw_path"] - response = Response(content=f"{path}, {raw_path}", status_code=HTTP_200_OK, media_type=MediaType.TEXT) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/he%2Fllo") - assert response.text == "/he/llo, b'/he%2Fllo'" - - -def test_request_without_setting_receive() -> None: - """If Request is instantiated without the 'receive' channel, then .body() - is not available.""" - - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope) - try: - data = await request.json() - except RuntimeError: - data = "Receive channel not available" - response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.post("/", json={"a": "123"}) - assert response.json() == {"json": "Receive channel not available"} - - -async def test_request_disconnect() -> None: - """If a client disconnect occurs while reading request body then - InternalServerException should be raised.""" - - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - await request.body() - - async def receiver(): - return {"type": "http.disconnect"} - - scope = {"type": "http", "method": "POST", "path": "/"} - with pytest.raises(InternalServerException): - await app(scope, receiver, empty_send) - - -def test_request_state_object() -> None: - scope = {"state": {"old": "foo"}} - - s = State(scope["state"]) - - s.new = "value" - assert s.new == "value" - - del s.new - - with pytest.raises(AttributeError): - s.new - - -def test_request_state() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - scope["state"] = {} - request = Request(scope, receive) - request.state.example = 123 - response = Response( - content={"state.example": request.state.example}, status_code=HTTP_200_OK, media_type=MediaType.JSON - ) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/123?a=abc") - assert response.json() == {"state.example": 123} - - -def test_request_cookies() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - mycookie = request.cookies.get("mycookie") - if mycookie: - response = Response(content=mycookie, media_type="text/plain", status_code=HTTP_200_OK) - else: - response = Response(content="Hello, world!", media_type=MediaType.TEXT, status_code=HTTP_200_OK) - response.set_cookie("mycookie", "Hello, cookies!") - - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/") - assert response.text == "Hello, world!" - response = client.get("/") - assert response.text == "Hello, cookies!" - - -def test_chunked_encoding() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope, receive) - body = await request.body() - response = Response(content={"body": body.decode()}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - - def post_body(): - yield b"foo" - yield b"bar" - - response = client.post("/", content=post_body()) - assert response.json() == {"body": "foobar"} - - -def test_request_send_push_promise() -> None: - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - # the server is push-enabled - scope["extensions"]["http.response.push"] = {} - - request = Request(scope, receive, send) - await request.send_push_promise("/style.css") - - response = Response(content={"json": "OK"}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/") - assert response.json() == {"json": "OK"} - - -def test_request_send_push_promise_without_push_extension() -> None: - """If server does not support the `http.response.push` extension, - - .send_push_promise() does nothing. - """ - - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - request = Request(scope) - await request.send_push_promise("/style.css") - - response = Response(content={"json": "OK"}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/") - assert response.json() == {"json": "OK"} - - -def test_request_send_push_promise_without_setting_send() -> None: - """If Request is instantiated without the send channel, then. - - .send_push_promise() is not available. - """ - - async def app(scope: Any, receive: "Receive", send: "Send") -> None: - # the server is push-enabled - scope["extensions"]["http.response.push"] = {} - - data = "OK" - request = Request(scope) - try: - await request.send_push_promise("/style.css") - except RuntimeError: - data = "Send channel not available" - response = Response(content={"json": data}, status_code=HTTP_200_OK, media_type=MediaType.JSON) - await response(scope, receive, send) - - client = TestClient(app) # type: ignore - response = client.get("/") - assert response.json() == {"json": "Send channel not available"} diff --git a/tests/connection/websocket/test_starlette_tests.py b/tests/connection/websocket/test_starlette_tests.py deleted file mode 100644 index 5087faadf6..0000000000 --- a/tests/connection/websocket/test_starlette_tests.py +++ /dev/null @@ -1,309 +0,0 @@ -"""The tests in this file were adapted from: - -https://github.com/encode/starlette/blob/master/tests/test_websockets.py -""" - -from typing import TYPE_CHECKING - -import anyio -import pytest -from starlette import status -from starlette.testclient import TestClient -from starlette.websockets import WebSocketDisconnect - -from starlite import WebSocketException -from starlite.connection import WebSocket - -if TYPE_CHECKING: - from starlite.types import Receive, Scope, Send - - -def test_websocket_url(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - await websocket.send_json({"url": str(websocket.url)}) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/123?a=abc") as websocket: - data = websocket.receive_json() - assert data == {"url": "ws://testserver/123?a=abc"} - - -def test_websocket_binary_json(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - message = await websocket.receive_json(mode="binary") - await websocket.send_json(message, mode="binary") - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/123?a=abc") as websocket: - websocket.send_json({"test": "data"}, mode="binary") - data = websocket.receive_json(mode="binary") - assert data == {"test": "data"} - - -def test_websocket_query_params(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - query_params = dict(websocket.query_params) - await websocket.accept() - await websocket.send_json({"params": query_params}) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/?a=abc&b=456") as websocket: - data = websocket.receive_json() - assert data == {"params": {"a": ["abc"], "b": ["456"]}} - - -def test_websocket_headers(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - headers = dict(websocket.headers) - await websocket.accept() - await websocket.send_json({"headers": headers}) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - expected_headers = { - "accept": "*/*", - "accept-encoding": "gzip, deflate, br", - "connection": "upgrade", - "host": "testserver", - "user-agent": "testclient", - "sec-websocket-key": "testserver==", - "sec-websocket-version": "13", - } - data = websocket.receive_json() - assert data == {"headers": expected_headers} - - -def test_websocket_port(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - await websocket.send_json({"port": websocket.url.port}) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: - data = websocket.receive_json() - assert data == {"port": 123} - - -def test_websocket_send_and_receive_text(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - data = await websocket.receive_text() - await websocket.send_text("Message was: " + data) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - websocket.send_text("Hello, world!") - data = websocket.receive_text() - assert data == "Message was: Hello, world!" - - -def test_websocket_send_and_receive_bytes(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - data = await websocket.receive_bytes() - await websocket.send_bytes(b"Message was: " + data) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - websocket.send_bytes(b"Hello, world!") - data = websocket.receive_bytes() - assert data == b"Message was: Hello, world!" - - -def test_websocket_send_and_receive_json(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - data = await websocket.receive_json() - await websocket.send_json({"message": data}) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - websocket.send_json({"hello": "world"}) - data = websocket.receive_json() - assert data == {"message": {"hello": "world"}} - - -def test_websocket_concurrency_pattern(): - stream_send, stream_receive = anyio.create_memory_object_stream() - - async def reader(websocket): - async with stream_send: - data = await websocket.receive_json() - await stream_send.send(data) - - async def writer(websocket): - async with stream_receive: - async for message in stream_receive: - await websocket.send_json(message) - - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - async with anyio.create_task_group() as task_group: - task_group.start_soon(reader, websocket) - await writer(websocket) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - websocket.send_json({"hello": "world"}) - data = websocket.receive_json() - assert data == {"hello": "world"} - - -def test_client_close(): - close_code = None - - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - nonlocal close_code - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - try: - await websocket.receive_text() - except WebSocketException as exc: - close_code = exc.code - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - websocket.close(code=status.WS_1001_GOING_AWAY) - assert close_code == status.WS_1001_GOING_AWAY - - -def test_application_close(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - await websocket.close(status.WS_1001_GOING_AWAY) - - client = TestClient(app) - with client.websocket_connect("/") as websocket, pytest.raises(WebSocketDisconnect) as exc: - websocket.receive_text() - assert exc.value.code == status.WS_1001_GOING_AWAY - - -def test_rejected_connection(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.close(status.WS_1001_GOING_AWAY) - - client = TestClient(app) - with pytest.raises(WebSocketDisconnect) as exc, client.websocket_connect("/"): - pass # pragma: nocover - assert exc.value.code == status.WS_1001_GOING_AWAY - - -def test_subprotocol(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - assert websocket.scope["subprotocols"] == ["soap", "wamp"] - await websocket.accept(subprotocols="wamp") - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: - assert websocket.accepted_subprotocol == "wamp" - - -def test_additional_headers(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept(headers=[(b"additional", b"header")]) - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - assert websocket.extra_headers == [(b"additional", b"header")] - - -def test_no_additional_headers(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - await websocket.close() - - client = TestClient(app) - with client.websocket_connect("/") as websocket: - assert websocket.extra_headers == [] - - -def test_websocket_exception(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - raise RuntimeError - - client = TestClient(app) - with pytest.raises(RuntimeError), client.websocket_connect("/123?a=abc"): - pass # pragma: nocover - - -def test_duplicate_disconnect(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - message = await websocket.receive() - assert message["type"] == "websocket.disconnect" - message = await websocket.receive() - - client = TestClient(app) - with pytest.raises(WebSocketException), client.websocket_connect("/") as websocket: - websocket.close() - - -def test_websocket_close_reason() -> None: - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away") - - client = TestClient(app) - with client.websocket_connect("/") as websocket, pytest.raises(WebSocketDisconnect) as exc: - websocket.receive_text() - assert exc.value.code == status.WS_1001_GOING_AWAY - assert exc.value.reason == "Going Away" - - -def test_receive_text_before_accept(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.receive_text() - - client = TestClient(app) - with pytest.raises(WebSocketException), client.websocket_connect("/"): - pass # pragma: nocover - - -def test_receive_bytes_before_accept(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.receive_bytes() - - client = TestClient(app) - with pytest.raises(WebSocketException), client.websocket_connect("/"): - pass # pragma: nocover - - -def test_receive_json_before_accept(): - async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.receive_json() - - client = TestClient(app) - with pytest.raises(WebSocketException), client.websocket_connect("/"): - pass # pragma: nocover diff --git a/tests/connection/websocket/test_websocket.py b/tests/connection/websocket/test_websocket.py index 619673d4dc..3c96bc2188 100644 --- a/tests/connection/websocket/test_websocket.py +++ b/tests/connection/websocket/test_websocket.py @@ -1,15 +1,26 @@ +"""A large part of the tests in this file were adapted from: + +https://github.com/encode/starlette/blob/master/tests/test_websockets.py And are +meant to ensure our compatibility with their API. +""" + from typing import TYPE_CHECKING, Any +import anyio import pytest from starlette.datastructures import Headers +from starlette.websockets import WebSocketDisconnect -from starlite import websocket +from starlite import WebSocketException, websocket from starlite.connection import WebSocket -from starlite.testing import create_test_client +from starlite.status_codes import WS_1001_GOING_AWAY +from starlite.testing import TestClient, create_test_client if TYPE_CHECKING: from typing_extensions import Literal + from starlite.types import Receive, Scope, Send + @pytest.mark.parametrize("mode", ["text", "binary"]) def test_websocket_send_receive_json(mode: "Literal['text', 'binary']") -> None: @@ -55,7 +66,7 @@ async def handler(socket: WebSocket) -> None: async def test_custom_request_class() -> None: value: Any = {} - class MyWebSocket(WebSocket): + class MyWebSocket(WebSocket[Any, Any]): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.scope["called"] = True # type: ignore @@ -69,3 +80,273 @@ async def handler(socket: MyWebSocket) -> None: with create_test_client(route_handlers=[handler], websocket_class=MyWebSocket).websocket_connect("/") as ws: ws.receive() assert value["called"] + + +def test_websocket_url() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + with TestClient(app).websocket_connect("/123?a=abc") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/123?a=abc"} + + +def test_websocket_binary_json() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + message = await websocket.receive_json(mode="binary") + await websocket.send_json(message, mode="binary") + await websocket.close() + + with TestClient(app).websocket_connect("/123?a=abc") as websocket: + websocket.send_json({"test": "data"}, mode="binary") + data = websocket.receive_json(mode="binary") + assert data == {"test": "data"} + + +def test_websocket_query_params() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + query_params = dict(websocket.query_params) + await websocket.accept() + await websocket.send_json({"params": query_params}) + await websocket.close() + + with TestClient(app).websocket_connect("/?a=abc&b=456") as websocket: + data = websocket.receive_json() + assert data == {"params": {"a": ["abc"], "b": ["456"]}} + + +def test_websocket_headers() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + headers = dict(websocket.headers) + await websocket.accept() + await websocket.send_json({"headers": headers}) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + expected_headers = { + "accept": "*/*", + "accept-encoding": "gzip, deflate, br", + "connection": "upgrade", + "host": "testserver", + "user-agent": "testclient", + "sec-websocket-key": "testserver==", + "sec-websocket-version": "13", + } + data = websocket.receive_json() + assert data == {"headers": expected_headers} + + +def test_websocket_port() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_json({"port": websocket.url.port}) + await websocket.close() + + with TestClient(app).websocket_connect("ws://example.com:123/123?a=abc") as websocket: + data = websocket.receive_json() + assert data == {"port": 123} + + +def test_websocket_send_and_receive_text() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + data = await websocket.receive_text() + await websocket.send_text("Message was: " + data) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + websocket.send_text("Hello, world!") + data = websocket.receive_text() + assert data == "Message was: Hello, world!" + + +def test_websocket_send_and_receive_bytes() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + data = await websocket.receive_bytes() + await websocket.send_bytes(b"Message was: " + data) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + websocket.send_bytes(b"Hello, world!") + data = websocket.receive_bytes() + assert data == b"Message was: Hello, world!" + + +def test_websocket_send_and_receive_json() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + data = await websocket.receive_json() + await websocket.send_json({"message": data}) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + websocket.send_json({"hello": "world"}) + data = websocket.receive_json() + assert data == {"message": {"hello": "world"}} + + +def test_websocket_concurrency_pattern() -> None: + stream_send, stream_receive = anyio.create_memory_object_stream() + + async def reader(websocket: WebSocket[Any, Any]) -> None: + async with stream_send: + data = await websocket.receive_json() + await stream_send.send(data) + + async def writer(websocket: WebSocket[Any, Any]) -> None: + async with stream_receive: + async for message in stream_receive: + await websocket.send_json(message) + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + async with anyio.create_task_group() as task_group: + task_group.start_soon(reader, websocket) + await writer(websocket) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + websocket.send_json({"hello": "world"}) + data = websocket.receive_json() + assert data == {"hello": "world"} + + +def test_client_close() -> None: + close_code = None + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + nonlocal close_code + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + try: + await websocket.receive_text() + except WebSocketException as exc: + close_code = exc.code + + with TestClient(app).websocket_connect("/") as websocket: + websocket.close(code=WS_1001_GOING_AWAY) + assert close_code == WS_1001_GOING_AWAY + + +def test_application_close() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + await websocket.close(WS_1001_GOING_AWAY) + + with TestClient(app).websocket_connect("/") as websocket, pytest.raises(WebSocketDisconnect) as exc: + websocket.receive_text() + assert exc.value.code == WS_1001_GOING_AWAY + + +def test_rejected_connection() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.close(WS_1001_GOING_AWAY) + + with pytest.raises(WebSocketDisconnect) as exc, TestClient(app).websocket_connect("/"): + pass + assert exc.value.code == WS_1001_GOING_AWAY + + +def test_subprotocol() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + assert websocket.scope["subprotocols"] == ["soap", "wamp"] + await websocket.accept(subprotocols="wamp") + await websocket.close() + + with TestClient(app).websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: + assert websocket.accepted_subprotocol == "wamp" + + +def test_additional_headers() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept(headers=[(b"additional", b"header")]) + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + assert websocket.extra_headers == [(b"additional", b"header")] + + +def test_no_additional_headers() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + await websocket.close() + + with TestClient(app).websocket_connect("/") as websocket: + assert websocket.extra_headers == [] + + +def test_websocket_exception() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + raise RuntimeError + + with pytest.raises(RuntimeError), TestClient(app).websocket_connect("/123?a=abc"): + pass + + +def test_duplicate_disconnect() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + message = await websocket.receive() + assert message["type"] == "websocket.disconnect" + message = await websocket.receive() + + with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/") as websocket: + websocket.close() + + +def test_websocket_close_reason() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.accept() + await websocket.close(code=WS_1001_GOING_AWAY, reason="Going Away") + + with TestClient(app).websocket_connect("/") as websocket, pytest.raises(WebSocketDisconnect) as exc: + websocket.receive_text() + assert exc.value.code == WS_1001_GOING_AWAY + assert exc.value.reason == "Going Away" + + +def test_receive_text_before_accept() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.receive_text() + + with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): + pass + + +def test_receive_bytes_before_accept() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.receive_bytes() + + with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): + pass + + +def test_receive_json_before_accept() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + websocket = WebSocket[Any, Any](scope, receive=receive, send=send) + await websocket.receive_json() + + with pytest.raises(WebSocketException), TestClient(app).websocket_connect("/"): + pass diff --git a/tests/datastructures/test_background_task.py b/tests/datastructures/test_background_task.py index ba5c914712..b75110220b 100644 --- a/tests/datastructures/test_background_task.py +++ b/tests/datastructures/test_background_task.py @@ -1,12 +1,11 @@ from typing import List -from starlette.status import HTTP_200_OK - from starlite import BackgroundTask, BackgroundTasks, get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client -async def test_background_tasks() -> None: +async def test_background_tasks_regular_execution() -> None: values: List[int] = [] def extend_values(values_to_extend: List[int]) -> None: @@ -24,3 +23,24 @@ def handler() -> None: response = client.get("/") assert response.status_code == HTTP_200_OK assert values == [1, 2, 3, 4, 5, 6] + + +async def test_background_tasks_task_group_execution() -> None: + values: List[int] = [] + + def extend_values(values_to_extend: List[int]) -> None: + values.extend(values_to_extend) + + tasks = BackgroundTasks( + [BackgroundTask(extend_values, [1, 2, 3]), BackgroundTask(extend_values, values_to_extend=[4, 5, 6])], + run_in_task_group=True, + ) + + @get("/", background=tasks) + def handler() -> None: + return None + + with create_test_client(handler) as client: + response = client.get("/") + assert response.status_code == HTTP_200_OK + assert set(values) == {1, 2, 3, 4, 5, 6} diff --git a/tests/datastructures/test_cookie.py b/tests/datastructures/test_cookie.py index 192a0ee515..48270868c6 100644 --- a/tests/datastructures/test_cookie.py +++ b/tests/datastructures/test_cookie.py @@ -40,3 +40,11 @@ def test_cookie_with_max_age_as_header() -> None: def test_cookie_as_header_without_header_name() -> None: cookie = Cookie(key="key") assert cookie.to_header(header="") == 'key=""; Path=/; SameSite=lax' + + +def test_equality() -> None: + assert Cookie(key="key") == Cookie(key="key") + assert Cookie(key="key") != Cookie(key="key", path="/test") + assert Cookie(key="key", path="/test") != Cookie(key="key", path="/test", domain="localhost") + assert Cookie(key="key", path="/test", domain="localhost") == Cookie(key="key", path="/test", domain="localhost") + assert Cookie(key="key") != "key" diff --git a/tests/datastructures/test_headers.py b/tests/datastructures/test_headers.py index 8fd8069877..70ac6866e5 100644 --- a/tests/datastructures/test_headers.py +++ b/tests/datastructures/test_headers.py @@ -42,7 +42,7 @@ def test_cache_control_from_header_single_value() -> None: assert header_dict == {"no-cache": True} -@pytest.mark.parametrize("invalid_value", ["x=y=z", "x, ", "no-cache=10"]) # type: ignore[misc] +@pytest.mark.parametrize("invalid_value", ["x=y=z", "x, ", "no-cache=10"]) def test_cache_control_from_header_invalid_value(invalid_value: str) -> None: with pytest.raises(ImproperlyConfiguredException): CacheControlHeader.from_header(invalid_value) @@ -77,20 +77,20 @@ def test_etag_from_header() -> None: assert etag.weak is False -@pytest.mark.parametrize("value", ['W/"foo"', 'w/"foo"']) # type: ignore[misc] +@pytest.mark.parametrize("value", ['W/"foo"', 'w/"foo"']) def test_etag_from_header_weak(value: str) -> None: etag = ETag.from_header(value) assert etag.value == "foo" assert etag.weak is True -@pytest.mark.parametrize("value", ['"føo"', 'W/"føo"']) # type: ignore[misc] +@pytest.mark.parametrize("value", ['"føo"', 'W/"føo"']) def test_etag_from_header_non_ascii_value(value: str) -> None: with pytest.raises(ImproperlyConfiguredException): ETag.from_header(value) -@pytest.mark.parametrize("value", ["foo", "W/foo"]) # type: ignore[misc] +@pytest.mark.parametrize("value", ["foo", "W/foo"]) def test_etag_from_header_missing_quotes(value: str) -> None: with pytest.raises(ImproperlyConfiguredException): ETag.from_header(value) diff --git a/tests/dependency_injection/test_http_handler_dependency_injection.py b/tests/dependency_injection/test_http_handler_dependency_injection.py index ec0826248a..74bc30e1ee 100644 --- a/tests/dependency_injection/test_http_handler_dependency_injection.py +++ b/tests/dependency_injection/test_http_handler_dependency_injection.py @@ -1,9 +1,8 @@ from asyncio import sleep from typing import TYPE_CHECKING, Any, Dict -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST - from starlite import Controller, Provide, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/dependency_injection/test_injection_of_generic_models.py b/tests/dependency_injection/test_injection_of_generic_models.py index 82e5d336e7..db87b29867 100644 --- a/tests/dependency_injection/test_injection_of_generic_models.py +++ b/tests/dependency_injection/test_injection_of_generic_models.py @@ -2,9 +2,9 @@ from pydantic import BaseModel from pydantic.generics import GenericModel -from starlette.status import HTTP_200_OK from starlite import Provide, get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client T = TypeVar("T") diff --git a/tests/dependency_injection/test_inter_dependencies.py b/tests/dependency_injection/test_inter_dependencies.py index aff73a3f98..dbb034e5f3 100644 --- a/tests/dependency_injection/test_inter_dependencies.py +++ b/tests/dependency_injection/test_inter_dependencies.py @@ -1,8 +1,7 @@ from random import randint -from starlette.status import HTTP_200_OK - from starlite import Controller, MediaType, Provide, get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client diff --git a/tests/dto_factory/test_dto_factory_integration.py b/tests/dto_factory/test_dto_factory_integration.py index 6e58048de3..d27b4c4f0d 100644 --- a/tests/dto_factory/test_dto_factory_integration.py +++ b/tests/dto_factory/test_dto_factory_integration.py @@ -4,10 +4,10 @@ import pytest from pydantic import BaseModel from pydantic_factories import ModelFactory -from starlette.status import HTTP_201_CREATED from starlite import DTOFactory, post from starlite.plugins.sql_alchemy import SQLAlchemyPlugin +from starlite.status_codes import HTTP_201_CREATED from starlite.testing import create_test_client from tests import Person, TypedDictPerson, VanillaDataClassPerson from tests.plugins.sql_alchemy_plugin import Pet, WildAnimal diff --git a/tests/dto_factory/test_dto_factory_model_conversion.py b/tests/dto_factory/test_dto_factory_model_conversion.py index 372e3e68ea..9bf6b8588a 100644 --- a/tests/dto_factory/test_dto_factory_model_conversion.py +++ b/tests/dto_factory/test_dto_factory_model_conversion.py @@ -77,7 +77,7 @@ def test_conversion_from_model_instance( pets=None, ) else: - model_instance = cast("Type[Pet]", model)( + model_instance = cast("Type[Pet]", model)( # pyright: ignore id=1, species=Species.MONKEY, name="Mike", diff --git a/tests/dto_factory/test_dto_factory_partials.py b/tests/dto_factory/test_dto_factory_partials.py index e43370fa51..9c4019f95d 100644 --- a/tests/dto_factory/test_dto_factory_partials.py +++ b/tests/dto_factory/test_dto_factory_partials.py @@ -41,5 +41,5 @@ def test_partial_dto_sqlalchemy_model() -> None: # Test for partial DTO partial_dto_car = Partial[dto_car] # type: ignore - ford = partial_dto_car(**car_two) + ford = partial_dto_car(**car_two) # pyright: ignore assert ford.make == "Ford" # type: ignore diff --git a/tests/handlers/asgi/test_handle_asgi.py b/tests/handlers/asgi/test_handle_asgi.py index c15c883eba..786b9d694d 100644 --- a/tests/handlers/asgi/test_handle_asgi.py +++ b/tests/handlers/asgi/test_handle_asgi.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from starlite import Controller, MediaType, Response, ScopeType, asgi +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client from starlite.types import Receive, Scope, Send @@ -11,7 +10,7 @@ async def root_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: assert scope["type"] == ScopeType.HTTP assert scope["method"] == "GET" response = Response("Hello World", media_type=MediaType.TEXT, status_code=HTTP_200_OK) - await response(scope, receive, send) # type: ignore[arg-type] + await response(scope, receive, send) class MyController(Controller): path = "/asgi" @@ -21,7 +20,7 @@ async def root_asgi_handler(self, scope: Scope, receive: Receive, send: Send) -> assert scope["type"] == ScopeType.HTTP assert scope["method"] == "GET" response = Response("Hello World", media_type=MediaType.TEXT, status_code=HTTP_200_OK) - await response(scope, receive, send) # type: ignore[arg-type] + await response(scope, receive, send) with create_test_client([root_asgi_handler, MyController]) as client: response = client.get("/") diff --git a/tests/handlers/http/test_defaults.py b/tests/handlers/http/test_defaults.py index 665aeba868..603e08827c 100644 --- a/tests/handlers/http/test_defaults.py +++ b/tests/handlers/http/test_defaults.py @@ -1,10 +1,10 @@ from typing import Any import pytest -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite import HttpMethod from starlite.handlers import HTTPRouteHandler +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT @pytest.mark.parametrize( diff --git a/tests/handlers/http/test_kwarg_handling.py b/tests/handlers/http/test_kwarg_handling.py index 7b8c9859a9..2402801869 100644 --- a/tests/handlers/http/test_kwarg_handling.py +++ b/tests/handlers/http/test_kwarg_handling.py @@ -4,7 +4,6 @@ from hypothesis import given from hypothesis import strategies as st from pydantic.main import BaseModel -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite import ( HttpMethod, @@ -18,6 +17,7 @@ post, put, ) +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite.types import ResponseType from starlite.utils import normalize_path diff --git a/tests/handlers/http/test_to_response.py b/tests/handlers/http/test_to_response.py index 562895c6c4..ffb639a149 100644 --- a/tests/handlers/http/test_to_response.py +++ b/tests/handlers/http/test_to_response.py @@ -2,20 +2,12 @@ from json import loads from pathlib import Path from time import sleep -from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterator import pytest from pydantic import ValidationError -from starlette.responses import ( - FileResponse, - HTMLResponse, - JSONResponse, - PlainTextResponse, - RedirectResponse, -) +from starlette.responses import HTMLResponse, JSONResponse, PlainTextResponse from starlette.responses import Response as StarletteResponse -from starlette.responses import StreamingResponse -from starlette.status import HTTP_200_OK, HTTP_308_PERMANENT_REDIRECT from starlite import ( Cookie, @@ -34,8 +26,14 @@ route, ) from starlite.datastructures import BackgroundTask -from starlite.response import TemplateResponse +from starlite.response import ( + FileResponse, + RedirectResponse, + StreamingResponse, + TemplateResponse, +) from starlite.signature import SignatureModelFactory +from starlite.status_codes import HTTP_200_OK, HTTP_308_PERMANENT_REDIRECT from starlite.testing import RequestFactory, create_test_client from tests import Person, PersonFactory @@ -82,7 +80,7 @@ def __init__(self) -> None: self.i = 0 self.to = 0.1 - def __aiter__(self) -> AsyncIterator[str]: + async def __aiter__(self) -> "MyAsyncIterator": return self async def __anext__(self) -> str: @@ -112,7 +110,7 @@ async def test_function(data: Person) -> Person: app=Starlite(route_handlers=[]), request=RequestFactory().get("/"), ) - assert loads(response.body) == person_instance.dict() + assert loads(response.body) == person_instance.dict() # type: ignore async def test_to_response_returning_starlite_response() -> None: @@ -132,20 +130,16 @@ def test_function() -> Response: @pytest.mark.parametrize( "expected_response", [ - Response(status_code=HTTP_200_OK, content=b"abc", media_type=MediaType.TEXT), StarletteResponse(status_code=HTTP_200_OK, content=b"abc"), PlainTextResponse(content="abc"), HTMLResponse(content="
None: - @get(path="/test") + @get(path="/test", response_cookies=[Cookie(key="my-cookies", value="abc", path="/test")]) def test_function() -> StarletteResponse: return expected_response @@ -156,7 +150,8 @@ def test_function() -> StarletteResponse: data=route_handler.fn(), plugins=[], app=client.app, request=RequestFactory().get("/") # type: ignore ) assert isinstance(response, StarletteResponse) - assert response is expected_response + assert response is expected_response # type: ignore[unreachable] + assert response.headers["set-cookie"] == "my-cookies=abc; Path=/test; SameSite=lax" async def test_to_response_returning_redirect_response(anyio_backend: str) -> None: @@ -186,10 +181,10 @@ def test_function() -> Redirect: assert response.headers["location"] == "/somewhere-else" assert response.headers["local-header"] == "123" assert response.headers["response-header"] == "abc" - cookies = response.headers.getlist("set-cookie") + cookies = response.cookies assert len(cookies) == 2 - assert cookies[0] == "redirect-cookie=xyz; Path=/; SameSite=lax" - assert cookies[1] == "general-cookie=xxx; Path=/; SameSite=lax" + assert cookies[0].to_header(header="") == "redirect-cookie=xyz; Path=/; SameSite=lax" + assert cookies[1].to_header(header="") == "general-cookie=xxx; Path=/; SameSite=lax" assert response.background == background_task @@ -244,14 +239,12 @@ def test_function() -> File: ) assert isinstance(response, FileResponse) assert response.stat_result - assert response.path == current_file_path - assert response.filename == filename assert response.headers["local-header"] == "123" assert response.headers["response-header"] == "abc" - cookies = response.headers.getlist("set-cookie") + cookies = response.cookies assert len(cookies) == 3 - assert cookies[0] == "file-cookie=xyz; Path=/; SameSite=lax" - assert cookies[1] == "general-cookie=xxx; Path=/; SameSite=lax" + assert cookies[0].to_header(header="") == "file-cookie=xyz; Path=/; SameSite=lax" + assert cookies[1].to_header(header="") == "general-cookie=xxx; Path=/; SameSite=lax" assert response.background == background_task @@ -266,7 +259,13 @@ def test_function() -> File: [my_async_generator, False], [MyAsyncIterator, False], [MySyncIterator, False], - [{"key": 1}, True], + [[1, 2, 3, 4], False], + ["abc", False], + [b"abc", False], + [{"key": 1}, False], + [[{"key": 1}], False], + [1, True], + [None, True], ], ) async def test_to_response_streaming_response(iterator: Any, should_raise: bool, anyio_backend: str) -> None: @@ -295,10 +294,10 @@ def test_function() -> Stream: assert isinstance(response, StreamingResponse) assert response.headers["local-header"] == "123" assert response.headers["response-header"] == "abc" - cookies = response.headers.getlist("set-cookie") + cookies = response.cookies assert len(cookies) == 3 - assert cookies[0] == "streaming-cookie=xyz; Path=/; SameSite=lax" - assert cookies[1] == "general-cookie=xxx; Path=/; SameSite=lax" + assert cookies[0].to_header(header="") == "streaming-cookie=xyz; Path=/; SameSite=lax" + assert cookies[1].to_header(header="") == "general-cookie=xxx; Path=/; SameSite=lax" assert response.background == background_task else: with pytest.raises(ValidationError): @@ -331,8 +330,8 @@ def test_function() -> Template: assert isinstance(response, TemplateResponse) assert response.headers["local-header"] == "123" assert response.headers["response-header"] == "abc" - cookies = response.headers.getlist("set-cookie") + cookies = response.cookies assert len(cookies) == 2 - assert cookies[0] == "template-cookie=xyz; Path=/; SameSite=lax" - assert cookies[1] == "general-cookie=xxx; Path=/; SameSite=lax" + assert cookies[0].to_header(header="") == "template-cookie=xyz; Path=/; SameSite=lax" + assert cookies[1].to_header(header="") == "general-cookie=xxx; Path=/; SameSite=lax" assert response.background == background_task diff --git a/tests/handlers/http/test_validations.py b/tests/handlers/http/test_validations.py index 94b791bb31..75fdd97bb7 100644 --- a/tests/handlers/http/test_validations.py +++ b/tests/handlers/http/test_validations.py @@ -3,12 +3,6 @@ import pytest from pydantic import ValidationError -from starlette.status import ( - HTTP_100_CONTINUE, - HTTP_200_OK, - HTTP_304_NOT_MODIFIED, - HTTP_307_TEMPORARY_REDIRECT, -) from starlite import ( File, @@ -23,12 +17,18 @@ ) from starlite.exceptions import ImproperlyConfiguredException, ValidationException from starlite.handlers import HTTPRouteHandler +from starlite.status_codes import ( + HTTP_100_CONTINUE, + HTTP_200_OK, + HTTP_304_NOT_MODIFIED, + HTTP_307_TEMPORARY_REDIRECT, +) from tests import Person def test_route_handler_validation_http_method() -> None: # doesn't raise for http methods - for value in (*list(HttpMethod), *list(map(lambda x: x.upper(), list(HttpMethod)))): + for value in (*list(HttpMethod), *list(map(lambda x: x.upper(), list(HttpMethod)))): # noqa: C417 assert route(http_method=value) # type: ignore # raises for invalid values diff --git a/tests/handlers/websocket/test_kwarg_handling.py b/tests/handlers/websocket/test_kwarg_handling.py index 3dbd1897eb..00b18ab5d7 100644 --- a/tests/handlers/websocket/test_kwarg_handling.py +++ b/tests/handlers/websocket/test_kwarg_handling.py @@ -28,7 +28,7 @@ async def websocket_handler( client = create_test_client(route_handlers=websocket_handler) # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"cookie": "yum"} + client.cookies = {"cookie": "yum"} # type: ignore with client.websocket_connect("/1?qp=1", headers={"some-header": "abc"}) as ws: ws.send_json({"data": "123"}) diff --git a/tests/kwargs/test_cookie_params.py b/tests/kwargs/test_cookie_params.py index 8285c7368c..aac219d998 100644 --- a/tests/kwargs/test_cookie_params.py +++ b/tests/kwargs/test_cookie_params.py @@ -2,9 +2,9 @@ import pytest from pydantic.fields import FieldInfo -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import Parameter, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client @@ -41,6 +41,6 @@ def test_method(special_cookie: t_type = param) -> None: # type: ignore with create_test_client(test_method) as client: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = param_dict + client.cookies = param_dict # type: ignore response = client.get(test_path) assert response.status_code == expected_code diff --git a/tests/kwargs/test_defaults.py b/tests/kwargs/test_defaults.py index 1eb29c9bd4..e22f9fa397 100644 --- a/tests/kwargs/test_defaults.py +++ b/tests/kwargs/test_defaults.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK - from starlite import Parameter, get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client diff --git a/tests/kwargs/test_header_params.py b/tests/kwargs/test_header_params.py index 771c938085..80f9b98726 100644 --- a/tests/kwargs/test_header_params.py +++ b/tests/kwargs/test_header_params.py @@ -4,9 +4,9 @@ import pytest from pydantic import UUID4 from pydantic.fields import FieldInfo -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import Parameter, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client diff --git a/tests/kwargs/test_json_data.py b/tests/kwargs/test_json_data.py index ff8d22c0a0..869038373d 100644 --- a/tests/kwargs/test_json_data.py +++ b/tests/kwargs/test_json_data.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_201_CREATED - from starlite import Body, post +from starlite.status_codes import HTTP_201_CREATED from starlite.testing import create_test_client from tests.kwargs import Form diff --git a/tests/kwargs/test_layered_params.py b/tests/kwargs/test_layered_params.py index 3a5a2e0736..fa3c94a073 100644 --- a/tests/kwargs/test_layered_params.py +++ b/tests/kwargs/test_layered_params.py @@ -1,9 +1,9 @@ from typing import List import pytest -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import Controller, Parameter, Router, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client @@ -50,7 +50,7 @@ def my_handler( }, ) as client: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"app4": "jeronimo"} + client.cookies = {"app4": "jeronimo"} # type: ignore query = {"controller1": "99", "controller3": "tuna", "router1": "albatross", "app2": ["x", "y"]} headers = {"router3": "10"} @@ -99,7 +99,7 @@ def my_handler(self) -> dict: query.pop(parameter) # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = cookies + client.cookies = cookies # type: ignore response = client.get("/router/controller/1", params=query, headers=headers) diff --git a/tests/kwargs/test_multipart_data.py b/tests/kwargs/test_multipart_data.py index 1ac19568b9..33fc55078b 100644 --- a/tests/kwargs/test_multipart_data.py +++ b/tests/kwargs/test_multipart_data.py @@ -6,10 +6,10 @@ import pytest from pydantic import BaseConfig, BaseModel -from starlette.status import HTTP_201_CREATED from starlite_multipart.datastructures import UploadFile from starlite import Body, Request, RequestEncodingType, post +from starlite.status_codes import HTTP_201_CREATED from starlite.testing import create_test_client from tests import Person, PersonFactory from tests.kwargs import Form diff --git a/tests/kwargs/test_path_params.py b/tests/kwargs/test_path_params.py index d458520d55..ad586ab7a2 100644 --- a/tests/kwargs/test_path_params.py +++ b/tests/kwargs/test_path_params.py @@ -6,9 +6,9 @@ import pytest from pydantic import UUID4 -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import ImproperlyConfiguredException, Parameter, Starlite, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client diff --git a/tests/kwargs/test_query_params.py b/tests/kwargs/test_query_params.py index ea04e87f4d..451a093f10 100644 --- a/tests/kwargs/test_query_params.py +++ b/tests/kwargs/test_query_params.py @@ -14,9 +14,9 @@ from urllib.parse import urlencode import pytest -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite import Parameter, get +from starlite.status_codes import HTTP_200_OK, HTTP_400_BAD_REQUEST from starlite.testing import create_test_client diff --git a/tests/kwargs/test_reserved_kwargs_injection.py b/tests/kwargs/test_reserved_kwargs_injection.py index a84d405b0a..d3581d0088 100644 --- a/tests/kwargs/test_reserved_kwargs_injection.py +++ b/tests/kwargs/test_reserved_kwargs_injection.py @@ -3,7 +3,6 @@ import pytest from pydantic import BaseModel, Field from pydantic_factories import ModelFactory -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite import ( Controller, @@ -17,6 +16,7 @@ post, put, ) +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite.testing import create_test_client from starlite.types import Scope from tests import Person, PersonFactory diff --git a/tests/kwargs/test_url_encoded_data.py b/tests/kwargs/test_url_encoded_data.py index 0b2a948bdf..d0304faabe 100644 --- a/tests/kwargs/test_url_encoded_data.py +++ b/tests/kwargs/test_url_encoded_data.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_201_CREATED - from starlite import Body, RequestEncodingType, post +from starlite.status_codes import HTTP_201_CREATED from starlite.testing import create_test_client from tests.kwargs import Form diff --git a/tests/life_cycle_hooks/test_after_response.py b/tests/life_cycle_hooks/test_after_response.py index be956863d7..4461147a3b 100644 --- a/tests/life_cycle_hooks/test_after_response.py +++ b/tests/life_cycle_hooks/test_after_response.py @@ -2,9 +2,9 @@ from typing import TYPE_CHECKING, Dict import pytest -from starlette.status import HTTP_200_OK from starlite import Controller, Request, Router, get +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client state: Dict[str, str] = {} diff --git a/tests/life_cycle_hooks/test_before_request.py b/tests/life_cycle_hooks/test_before_request.py index d9271a4990..32a86d03bb 100644 --- a/tests/life_cycle_hooks/test_before_request.py +++ b/tests/life_cycle_hooks/test_before_request.py @@ -43,13 +43,13 @@ async def async_after_request_handler(response: Response) -> Response: @pytest.mark.parametrize( "handler, expected", - [ - [get(path="/")(greet), {"hello": "world"}], - [get(path="/", before_request=sync_before_request_handler_with_return_value)(greet), {"hello": "moon"}], - [get(path="/", before_request=async_before_request_handler_with_return_value)(greet), {"hello": "moon"}], - [get(path="/", before_request=sync_before_request_handler_without_return_value)(greet), {"hello": "world"}], - [get(path="/", before_request=async_before_request_handler_without_return_value)(greet), {"hello": "world"}], - ], + ( + (get(path="/")(greet), {"hello": "world"}), + (get(path="/", before_request=sync_before_request_handler_with_return_value)(greet), {"hello": "moon"}), + (get(path="/", before_request=async_before_request_handler_with_return_value)(greet), {"hello": "moon"}), + (get(path="/", before_request=sync_before_request_handler_without_return_value)(greet), {"hello": "world"}), + (get(path="/", before_request=async_before_request_handler_without_return_value)(greet), {"hello": "world"}), + ), ) def test_before_request_handler_called(handler: HTTPRouteHandler, expected: dict) -> None: with create_test_client(route_handlers=handler) as client: diff --git a/tests/logging_config/test_logging_config.py b/tests/logging_config/test_logging_config.py index 2a9d1d8597..da4799464b 100644 --- a/tests/logging_config/test_logging_config.py +++ b/tests/logging_config/test_logging_config.py @@ -2,7 +2,6 @@ from unittest.mock import Mock, patch import pytest -from starlette.status import HTTP_200_OK from starlite import Request, get from starlite.config import LoggingConfig @@ -17,6 +16,7 @@ from starlite.logging.standard import ( QueueListenerHandler as StandardQueueListenerHandler, ) +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/middleware/test_authentication.py b/tests/middleware/test_authentication.py index 0ec3c73dfe..1724318240 100644 --- a/tests/middleware/test_authentication.py +++ b/tests/middleware/test_authentication.py @@ -2,11 +2,6 @@ import pytest from pydantic import BaseModel -from starlette.status import ( - HTTP_200_OK, - HTTP_403_FORBIDDEN, - HTTP_500_INTERNAL_SERVER_ERROR, -) from starlette.websockets import WebSocketDisconnect from starlite import Starlite, get, websocket @@ -17,6 +12,11 @@ AuthenticationResult, ) from starlite.middleware.base import DefineMiddleware +from starlite.status_codes import ( + HTTP_200_OK, + HTTP_403_FORBIDDEN, + HTTP_500_INTERNAL_SERVER_ERROR, +) from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/middleware/test_compression_middleware.py b/tests/middleware/test_compression_middleware.py index 4cd59b5882..7b6b52b967 100644 --- a/tests/middleware/test_compression_middleware.py +++ b/tests/middleware/test_compression_middleware.py @@ -1,12 +1,10 @@ -from typing import Any, cast +from typing import Any, AsyncIterator, cast import brotli import pytest -from starlette.datastructures import Headers -from starlette.responses import PlainTextResponse from typing_extensions import Literal -from starlite import WebSocket, get, websocket +from starlite import MediaType, WebSocket, get, websocket from starlite.config import CompressionConfig from starlite.datastructures import Stream from starlite.middleware.compression.brotli import BrotliMiddleware, CompressionEncoding @@ -16,17 +14,17 @@ BrotliMode = Literal["text", "generic", "font"] -@get(path="/") -def handler() -> PlainTextResponse: - return PlainTextResponse("_starlite_" * 4000) +@get(path="/", media_type=MediaType.TEXT) +def handler() -> str: + return "_starlite_" * 4000 -@get(path="/no-compression") -def no_compress_handler() -> PlainTextResponse: - return PlainTextResponse("_starlite_") +@get(path="/no-compression", media_type=MediaType.TEXT) +def no_compress_handler() -> str: + return "_starlite_" -async def streaming_iter(content: bytes, count: int) -> Any: +async def streaming_iter(content: bytes, count: int) -> AsyncIterator[bytes]: for _ in range(count): yield content @@ -258,5 +256,4 @@ async def websocket_handler(socket: WebSocket) -> None: route_handlers=[websocket_handler], compression_config=CompressionConfig(backend="brotli", brotli_gzip_fallback=False), ).websocket_connect("/") as ws: - headers = Headers(scope=ws.scope) - assert "Content-Encoding" not in headers + assert "Content-Encoding" not in ws.scope["headers"] diff --git a/tests/middleware/test_csrf.py b/tests/middleware/test_csrf.py index 61a0d4c1a3..8e7a71330b 100644 --- a/tests/middleware/test_csrf.py +++ b/tests/middleware/test_csrf.py @@ -5,8 +5,6 @@ import pytest from bs4 import BeautifulSoup -from starlette import status -from starlette.status import HTTP_200_OK, HTTP_201_CREATED from starlite import ( Body, @@ -23,6 +21,7 @@ put, websocket, ) +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_403_FORBIDDEN from starlite.template.jinja import JinjaTemplateEngine from starlite.template.mako import MakoTemplateEngine from starlite.testing import create_test_client @@ -72,7 +71,7 @@ def test_csrf_successful_flow() -> None: ] response = client.post("/", headers={"x-csrftoken": csrf_token}) - assert response.status_code == status.HTTP_201_CREATED + assert response.status_code == HTTP_201_CREATED @pytest.mark.parametrize( @@ -91,7 +90,7 @@ def test_unsafe_method_fails_without_csrf_header(method: str) -> None: assert csrf_token is not None response = client.request(method, "/") - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == HTTP_403_FORBIDDEN assert response.json() == {"detail": "CSRF token verification failed", "status_code": 403} @@ -106,7 +105,7 @@ def test_invalid_csrf_token() -> None: assert csrf_token is not None response = client.post("/", headers={"x-csrftoken": csrf_token + "invalid"}) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == HTTP_403_FORBIDDEN assert response.json() == {"detail": "CSRF token verification failed", "status_code": 403} @@ -120,7 +119,7 @@ def test_csrf_token_too_short() -> None: assert "csrftoken" in response.cookies response = client.post("/", headers={"x-csrftoken": "too-short"}) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == HTTP_403_FORBIDDEN assert response.json() == {"detail": "CSRF token verification failed", "status_code": 403} @@ -197,7 +196,7 @@ def form_handler(data: dict = Body(media_type=RequestEncodingType.URL_ENCODED)) _ = client.get("/") response = client.get("/") html_soup = BeautifulSoup(html.unescape(response.text), features="html.parser") - data = {"_csrf_token": html_soup.body.div.form.input.attrs.get("value")} + data = {"_csrf_token": html_soup.body.div.form.input.attrs.get("value")} # type: ignore response = client.post("/", data=data) assert response.status_code == HTTP_201_CREATED assert response.json() == data @@ -233,7 +232,7 @@ def post_handler2(data: dict = Body(media_type=RequestEncodingType.URL_ENCODED)) ) as client: data = {"field": "value"} response = client.post("/protected-handler", data=data) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == HTTP_403_FORBIDDEN response = client.post("/unprotected-handler", data=data) assert response.status_code == HTTP_201_CREATED @@ -255,7 +254,7 @@ def post_handler2(data: dict = Body(media_type=RequestEncodingType.URL_ENCODED)) ) as client: data = {"field": "value"} response = client.post("/handler", data=data) - assert response.status_code == status.HTTP_403_FORBIDDEN + assert response.status_code == HTTP_403_FORBIDDEN data = {"field": "value"} response = client.post("/handler2", data=data) diff --git a/tests/middleware/test_exception_handler_middleware.py b/tests/middleware/test_exception_handler_middleware.py index 0964eabff8..2f422f0890 100644 --- a/tests/middleware/test_exception_handler_middleware.py +++ b/tests/middleware/test_exception_handler_middleware.py @@ -2,11 +2,11 @@ from typing import TYPE_CHECKING, Any from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlite import HTTPException, Request, Response, Starlite, get from starlite.enums import MediaType from starlite.middleware.exceptions import ExceptionHandlerMiddleware +from starlite.status_codes import HTTP_500_INTERNAL_SERVER_ERROR from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index 4c8e85a627..dd0fc32d05 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -1,7 +1,6 @@ from logging import INFO from typing import TYPE_CHECKING -from starlette.status import HTTP_200_OK from structlog.testing import capture_logs from starlite import ( @@ -14,6 +13,7 @@ ) from starlite.config.logging import default_handlers from starlite.middleware import LoggingMiddlewareConfig +from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client if TYPE_CHECKING: @@ -40,7 +40,7 @@ def test_logging_middleware_regular_logger(caplog: "LogCaptureFixture") -> None: route_handlers=[handler], middleware=[LoggingMiddlewareConfig().middleware] ) as client, caplog.at_level(INFO): # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} + client.cookies = {"request-cookie": "abc"} # type: ignore client.app.get_logger = get_logger response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK @@ -54,9 +54,9 @@ def test_logging_middleware_regular_logger(caplog: "LogCaptureFixture") -> None: 'cookies={"request-cookie":"abc"}, query={}, path_params={}, body=None' ) assert ( - caplog.messages[1] == 'HTTP Response: status_code=200, cookies={"first-cookie":"abc","Path":"/",' - '"SameSite":"lax","second-cookie":"xxx"}, headers={"token":"123","regular":"abc",' - '"content-length":"17","content-type":"application/json"}, body={"hello":"world"}' + caplog.messages[1] == 'HTTP Response: status_code=200, cookies={"first-cookie":"abc","Path":"/","SameSite":' + '"lax","second-cookie":"xxx"}, headers={"token":"123","regular":"abc","content-type":' + '"application/json","content-length":"17"}, body={"hello":"world"}' ) @@ -67,7 +67,7 @@ def test_logging_middleware_struct_logger() -> None: logging_config=StructLoggingConfig(), ) as client, capture_logs() as cap_logs: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = {"request-cookie": "abc"} + client.cookies = {"request-cookie": "abc"} # type: ignore response = client.get("/", headers={"request-header": "1"}) assert response.status_code == HTTP_200_OK assert len(cap_logs) == 2 diff --git a/tests/middleware/test_rate_limit.py b/tests/middleware/test_rate_limit.py index aaf252bfe3..22d230bd1b 100644 --- a/tests/middleware/test_rate_limit.py +++ b/tests/middleware/test_rate_limit.py @@ -4,7 +4,6 @@ import pytest from freezegun import freeze_time from orjson import dumps, loads -from starlette.status import HTTP_200_OK, HTTP_429_TOO_MANY_REQUESTS from starlite import Request, get from starlite.middleware.rate_limit import ( @@ -13,6 +12,7 @@ DurationUnit, RateLimitConfig, ) +from starlite.status_codes import HTTP_200_OK, HTTP_429_TOO_MANY_REQUESTS from starlite.testing import create_test_client diff --git a/tests/middleware/test_session_middleware.py b/tests/middleware/test_session_middleware.py index 87392b5069..5619b7268f 100644 --- a/tests/middleware/test_session_middleware.py +++ b/tests/middleware/test_session_middleware.py @@ -9,7 +9,6 @@ from cryptography.exceptions import InvalidTag from orjson import dumps from pydantic import SecretBytes, ValidationError -from starlette.status import HTTP_201_CREATED, HTTP_500_INTERNAL_SERVER_ERROR from starlite import ( HttpMethod, @@ -27,6 +26,7 @@ SessionCookieConfig, SessionMiddleware, ) +from starlite.status_codes import HTTP_201_CREATED, HTTP_500_INTERNAL_SERVER_ERROR from starlite.testing import create_test_client @@ -149,7 +149,7 @@ def handler(request: Request) -> dict: middleware=[session_middleware.config.middleware], ) as client: # Set cookies on the client to avoid warnings about per-request cookies. - client.cookies = { + client.cookies = { # type: ignore f"{session_middleware.config.key}-{i}": text.decode("utf-8") for i, text in enumerate(ciphertext, start=0) } response = client.get( diff --git a/tests/openapi/test_controller.py b/tests/openapi/test_controller.py index db674610bf..2b187f412e 100644 --- a/tests/openapi/test_controller.py +++ b/tests/openapi/test_controller.py @@ -1,9 +1,8 @@ -from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND - from starlite import OpenAPIConfig from starlite.app import DEFAULT_OPENAPI_CONFIG from starlite.enums import MediaType from starlite.openapi.controller import OpenAPIController as _OpenAPIController +from starlite.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND from starlite.testing import create_test_client from tests.openapi.utils import PersonController, PetController diff --git a/tests/openapi/test_integration.py b/tests/openapi/test_integration.py index a369996449..66bd7a203c 100644 --- a/tests/openapi/test_integration.py +++ b/tests/openapi/test_integration.py @@ -1,10 +1,10 @@ import yaml from orjson import loads from pydantic_openapi_schema.utils import construct_open_api_with_schema_class -from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND from starlite.app import DEFAULT_OPENAPI_CONFIG from starlite.enums import OpenAPIMediaType +from starlite.status_codes import HTTP_200_OK, HTTP_404_NOT_FOUND from starlite.testing import create_test_client from tests.openapi.utils import PersonController, PetController diff --git a/tests/openapi/test_responses.py b/tests/openapi/test_responses.py index aed28da4c3..ff5cae3758 100644 --- a/tests/openapi/test_responses.py +++ b/tests/openapi/test_responses.py @@ -2,12 +2,6 @@ import pytest from pydantic import BaseModel -from starlette.status import ( - HTTP_200_OK, - HTTP_307_TEMPORARY_REDIRECT, - HTTP_400_BAD_REQUEST, - HTTP_406_NOT_ACCEPTABLE, -) from starlite import ( Cookie, @@ -35,6 +29,12 @@ create_responses, create_success_response, ) +from starlite.status_codes import ( + HTTP_200_OK, + HTTP_307_TEMPORARY_REDIRECT, + HTTP_400_BAD_REQUEST, + HTTP_406_NOT_ACCEPTABLE, +) from tests import Person from tests.openapi.utils import PersonController, PetController, PetException diff --git a/tests/plugins/piccolo_orm/test_piccolo_orm_plugin_integration.py b/tests/plugins/piccolo_orm/test_piccolo_orm_plugin_integration.py index e0b660fb88..8b52db9ef3 100644 --- a/tests/plugins/piccolo_orm/test_piccolo_orm_plugin_integration.py +++ b/tests/plugins/piccolo_orm/test_piccolo_orm_plugin_integration.py @@ -2,9 +2,9 @@ from orjson import dumps from piccolo.testing.model_builder import ModelBuilder -from starlette.status import HTTP_200_OK, HTTP_201_CREATED from starlite.plugins.piccolo_orm import PiccoloORMPlugin +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED from starlite.testing import create_test_client from .endpoints import create_concert, retrieve_studio, retrieve_venues, studio, venues diff --git a/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_config.py b/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_config.py index 14f1abda76..a35e7527ad 100644 --- a/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_config.py +++ b/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_config.py @@ -5,7 +5,6 @@ from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, sessionmaker -from starlette.status import HTTP_200_OK from starlite import LoggingConfig, Starlite, get from starlite.plugins.sql_alchemy import SQLAlchemyPlugin @@ -16,6 +15,7 @@ SQLAlchemyEngineConfig, serializer, ) +from starlite.status_codes import HTTP_200_OK from starlite.testing import RequestFactory, create_test_client from starlite.types import Scope diff --git a/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_plugin_integration.py b/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_plugin_integration.py index 69e1e8d60d..e8f59ffb73 100644 --- a/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_plugin_integration.py +++ b/tests/plugins/sql_alchemy_plugin/test_sql_alchemy_plugin_integration.py @@ -4,10 +4,10 @@ create_random_float, create_random_string, ) -from starlette.status import HTTP_200_OK, HTTP_201_CREATED from starlite import get, post from starlite.plugins.sql_alchemy import SQLAlchemyPlugin +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED from starlite.testing import create_test_client from tests.plugins.sql_alchemy_plugin import Company diff --git a/tests/plugins/tortoise_orm/__init__.py b/tests/plugins/tortoise_orm/__init__.py index dffaf4e589..07780115d8 100644 --- a/tests/plugins/tortoise_orm/__init__.py +++ b/tests/plugins/tortoise_orm/__init__.py @@ -1,6 +1,6 @@ from typing import List, cast -from tortoise import Model, Tortoise, fields +from tortoise import Model, Tortoise, fields # type: ignore from starlite.handlers.http import get, post diff --git a/tests/plugins/tortoise_orm/test_tortoise_orm_plugin_integration.py b/tests/plugins/tortoise_orm/test_tortoise_orm_plugin_integration.py index 82aae6e18a..e700ca3a8a 100644 --- a/tests/plugins/tortoise_orm/test_tortoise_orm_plugin_integration.py +++ b/tests/plugins/tortoise_orm/test_tortoise_orm_plugin_integration.py @@ -1,6 +1,5 @@ -from starlette.status import HTTP_200_OK, HTTP_201_CREATED - from starlite.plugins.tortoise_orm import TortoiseORMPlugin +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED from starlite.testing import create_test_client from tests.plugins.tortoise_orm import ( Tournament, diff --git a/tests/response/test_base_response.py b/tests/response/test_base_response.py new file mode 100644 index 0000000000..edee7cdf8f --- /dev/null +++ b/tests/response/test_base_response.py @@ -0,0 +1,178 @@ +from typing import Any, Optional + +import pytest +from pydantic_openapi_schema.v3_1_0 import Info, OpenAPI + +from starlite import MediaType, OpenAPIMediaType, get +from starlite.response import Response +from starlite.status_codes import ( + HTTP_100_CONTINUE, + HTTP_101_SWITCHING_PROTOCOLS, + HTTP_102_PROCESSING, + HTTP_103_EARLY_HINTS, + HTTP_200_OK, + HTTP_204_NO_CONTENT, + HTTP_500_INTERNAL_SERVER_ERROR, +) +from starlite.testing import create_test_client +from starlite.types import Empty + + +def test_response_headers() -> None: + @get("/") + def handler() -> Response: + return Response(content="hello world", media_type=MediaType.TEXT, headers={"first": "123", "second": 456}) + + with create_test_client(handler) as client: + response = client.get("/") + assert response.headers["first"] == "123" + assert response.headers["second"] == "456" + assert response.headers["content-length"] == "11" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + + +def test_set_cookie() -> None: + @get("/") + def handler() -> Response: + response = Response(content=None) + response.set_cookie( + "test", + "abc", + max_age=60, + expires=60, + path="/", + secure=True, + httponly=True, + samesite="lax", + ) + assert len(response.cookies) == 1 + return response + + with create_test_client(handler) as client: + response = client.get("/") + assert response.cookies.get("test") == "abc" + + +def test_delete_cookie() -> None: + @get("/create") + def create_cookie_handler() -> Response: + response = Response(content=None) + response.set_cookie( + "test", + "abc", + max_age=60, + expires=60, + path="/", + secure=True, + httponly=True, + samesite="lax", + ) + assert len(response.cookies) == 1 + return response + + @get("/delete") + def delete_cookie_handler() -> Response: + response = Response(content=None) + response.delete_cookie( + "test", + "abc", + ) + assert len(response.cookies) == 1 + return response + + with create_test_client(route_handlers=[create_cookie_handler, delete_cookie_handler]) as client: + response = client.get("/create") + assert response.cookies.get("test") == "abc" + assert client.cookies.get("test") == "abc" + response = client.get("/delete") + assert response.cookies.get("test") is None + # the commented out assert fails, because of the starlette test client's behaviour - which doesn't clear + # cookies. + # assert client.cookies.get("test") is None + + +@pytest.mark.parametrize( + "media_type, expected, should_have_content_length", + ((MediaType.TEXT, b"", False), (MediaType.HTML, b"", False), (MediaType.JSON, b"null", True)), +) +def test_empty_response(media_type: MediaType, expected: bytes, should_have_content_length: bool) -> None: + @get("/", media_type=media_type) + def handler() -> None: + return + + with create_test_client(handler) as client: + response = client.get("/") + assert response.content == expected + if should_have_content_length: + assert "content-length" in response.headers + else: + assert "content-length" not in response.headers + + +@pytest.mark.parametrize( + "status, body, should_raise", + ( + (HTTP_100_CONTINUE, None, False), + (HTTP_101_SWITCHING_PROTOCOLS, None, False), + (HTTP_102_PROCESSING, None, False), + (HTTP_103_EARLY_HINTS, None, False), + (HTTP_204_NO_CONTENT, None, False), + (HTTP_100_CONTINUE, "1", True), + (HTTP_101_SWITCHING_PROTOCOLS, "1", True), + (HTTP_102_PROCESSING, "1", True), + (HTTP_103_EARLY_HINTS, "1", True), + (HTTP_204_NO_CONTENT, "1", True), + ), +) +def test_statuses_without_body(status: int, body: Optional[str], should_raise: bool) -> None: + @get("/") + def handler() -> Response: + return Response(content=body, status_code=status) + + with create_test_client(handler) as client: + response = client.get("/") + if should_raise: + assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR + else: + assert response.status_code == status + assert "content-length" not in response.headers + + +@pytest.mark.parametrize( + "body, media_type, should_raise", + ( + ("", MediaType.TEXT, False), + ("abc", MediaType.TEXT, False), + (b"", MediaType.HTML, False), + (b"abc", MediaType.HTML, False), + ({}, MediaType.TEXT, True), + ([], MediaType.TEXT, True), + ({}, MediaType.HTML, True), + ([], MediaType.HTML, True), + ({"abc": "def"}, MediaType.JSON, False), + (Empty, MediaType.JSON, True), + (OpenAPI(info=Info(title="my-api", version="1")), OpenAPIMediaType.OPENAPI_JSON, False), + (OpenAPI(info=Info(title="my-api", version="1")), OpenAPIMediaType.OPENAPI_YAML, False), + ), +) +def test_render_method(body: Any, media_type: MediaType, should_raise: bool) -> None: + @get("/", media_type=media_type) + def handler() -> Any: + return body + + with create_test_client(handler) as client: + response = client.get("/") + if should_raise: + assert response.status_code == HTTP_500_INTERNAL_SERVER_ERROR + else: + assert response.status_code == HTTP_200_OK + + +def test_is_head_response_returns_no_body() -> None: + @get("/") + def handler() -> Response: + return Response(content="hello world", media_type=MediaType.TEXT, is_head_response=True) + + with create_test_client(handler) as client: + response = client.get("/") + assert response.text == "" diff --git a/tests/response/test_error_handling.py b/tests/response/test_error_handling.py index 0fc768e866..ddd0b71368 100644 --- a/tests/response/test_error_handling.py +++ b/tests/response/test_error_handling.py @@ -1,7 +1,7 @@ import pytest -from starlette.status import HTTP_200_OK from starlite import ImproperlyConfiguredException, MediaType, Response +from starlite.status_codes import HTTP_200_OK def test_response_error_handling() -> None: diff --git a/tests/response/test_file_response.py b/tests/response/test_file_response.py new file mode 100644 index 0000000000..f77b9c1119 --- /dev/null +++ b/tests/response/test_file_response.py @@ -0,0 +1,109 @@ +"""A large part of the tests in this file were adapted from: + +https://github.com/encode/starlette/blob/master/tests/test_responses.py And are +meant to ensure our compatibility with their API. +""" +from datetime import datetime +from email.utils import formatdate +from pathlib import Path +from typing import TYPE_CHECKING, AsyncIterator + +import anyio +import pytest + +from starlite import BackgroundTask, ImproperlyConfiguredException +from starlite.response import FileResponse +from starlite.status_codes import HTTP_200_OK +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite.types import Receive, Scope, Send + + +def test_file_response(tmpdir: Path) -> None: + date_string = formatdate(datetime.now().timestamp(), usegmt=True) + path = tmpdir / "xyz" + content = b"" * 1000 + Path(path).write_bytes(content) + + filled_by_bg_task = "" + + async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: + for i in range(minimum, maximum + 1): + yield str(i) + if i != maximum: + yield ", " + await anyio.sleep(0) + + async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: + nonlocal filled_by_bg_task + async for thing in numbers(start, stop): + filled_by_bg_task = filled_by_bg_task + thing + + cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + response = FileResponse(path=path, filename="example.png", background=cleanup_task) + await response(scope, receive, send) + + assert filled_by_bg_task == "" + client = TestClient(app) + response = client.get("/") + expected_disposition = 'attachment; filename="example.png"' + assert response.status_code == HTTP_200_OK + assert response.content == content + assert response.headers["content-type"] == "image/png" + assert response.headers["content-disposition"] == expected_disposition + assert response.headers["content-length"] == "14000" + assert response.headers["last-modified"].lower() == date_string.lower() + assert "etag" in response.headers + assert filled_by_bg_task == "6, 7, 8, 9" + + +def test_file_response_with_directory_raises_error(tmpdir: Path) -> None: + with pytest.raises(ImproperlyConfiguredException): + FileResponse(path=tmpdir, filename="example.png") + + +def test_file_response_with_missing_file_raises_error(tmpdir: Path) -> None: + path = tmpdir / "404.txt" + with pytest.raises(ImproperlyConfiguredException): + FileResponse(path=path, filename="404.txt") + + +def test_file_response_with_chinese_filename(tmpdir: Path) -> None: + content = b"file content" + filename = "你好.txt" + path = tmpdir / filename + Path(path).write_bytes(content) + app = FileResponse(path=path, filename=filename) + client = TestClient(app) + response = client.get("/") + expected_disposition = "attachment; filename*=utf-8''%e4%bd%a0%e5%a5%bd.txt" + assert response.status_code == HTTP_200_OK + assert response.content == content + assert response.headers["content-disposition"] == expected_disposition + + +def test_file_response_with_inline_disposition(tmpdir: Path) -> None: + content = b"file content" + filename = "hello.txt" + path = tmpdir / filename + Path(path).write_bytes(content) + app = FileResponse(path=path, filename=filename, content_disposition_type="inline") + client = TestClient(app) + response = client.get("/") + expected_disposition = 'inline; filename="hello.txt"' + assert response.status_code == HTTP_200_OK + assert response.content == content + assert response.headers["content-disposition"] == expected_disposition + + +def test_file_response_known_size(tmpdir: Path) -> None: + path = tmpdir / "xyz" + content = b"" * 1000 + Path(path).write_bytes(content) + app = FileResponse(path=path, filename="example.png") + client: TestClient = TestClient(app) + response = client.get("/") + assert response.headers["content-length"] == str(len(content)) diff --git a/tests/response/test_redirect_response.py b/tests/response/test_redirect_response.py new file mode 100644 index 0000000000..18ca33c30e --- /dev/null +++ b/tests/response/test_redirect_response.py @@ -0,0 +1,63 @@ +"""A large part of the tests in this file were adapted from: + +https://github.com/encode/starlette/blob/master/tests/test_responses.py And are +meant to ensure our compatibility with their API. +""" +from typing import TYPE_CHECKING + +import pytest + +from starlite import ImproperlyConfiguredException, Response +from starlite.response import RedirectResponse +from starlite.status_codes import HTTP_200_OK +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite.types import Receive, Scope, Send + + +def test_redirect_response() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + if scope["path"] == "/": + response = Response("hello, world", media_type="text/plain") + else: + response = RedirectResponse("/") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/redirect") + assert response.text == "hello, world" + assert response.url == "http://testserver/" + + +def test_quoting_redirect_response() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + if scope["path"] == "/test/": + response = Response("hello, world", media_type="text/plain") + else: + response = RedirectResponse(url="/test/") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/redirect", follow_redirects=True) + assert response.text == "hello, world" + assert str(response.url) == "http://testserver/test/" + + +def test_redirect_response_content_length_header() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + if scope["path"] == "/": + response = Response("hello", media_type="text/plain") # pragma: nocover + else: + response = RedirectResponse("/") + await response(scope, receive, send) + + client: TestClient = TestClient(app) + response = client.request("GET", "/redirect", follow_redirects=False) + assert response.url == "http://testserver/redirect" + assert "content-length" not in response.headers + + +def test_redirect_response_status_validation() -> None: + with pytest.raises(ImproperlyConfiguredException): + RedirectResponse("/", status_code=HTTP_200_OK) # type: ignore diff --git a/tests/response/test_response_headers.py b/tests/response/test_response_headers.py index 60648053c0..28640777ab 100644 --- a/tests/response/test_response_headers.py +++ b/tests/response/test_response_headers.py @@ -2,11 +2,11 @@ import pytest from pydantic import ValidationError -from starlette.status import HTTP_201_CREATED from starlite import Controller, HttpMethod, ResponseHeader, Router, Starlite, get, post from starlite.datastructures import CacheControlHeader, ETag from starlite.datastructures.headers import Header +from starlite.status_codes import HTTP_201_CREATED from starlite.testing import TestClient, create_test_client diff --git a/tests/response/test_serialization.py b/tests/response/test_serialization.py index 0bbbc9ecef..36679b61e4 100644 --- a/tests/response/test_serialization.py +++ b/tests/response/test_serialization.py @@ -5,9 +5,9 @@ import pytest from pydantic import SecretStr -from starlette.status import HTTP_200_OK from starlite import MediaType, Response +from starlite.status_codes import HTTP_200_OK from tests import Person, PersonFactory, PydanticDataClassPerson, VanillaDataClassPerson person = PersonFactory.build() diff --git a/tests/response/test_streaming_response.py b/tests/response/test_streaming_response.py new file mode 100644 index 0000000000..be0de8d380 --- /dev/null +++ b/tests/response/test_streaming_response.py @@ -0,0 +1,167 @@ +"""A large part of the tests in this file were adapted from: + +https://github.com/encode/starlette/blob/master/tests/test_responses.py And are +meant to ensure our compatibility with their API. +""" +from itertools import cycle +from typing import TYPE_CHECKING, AsyncIterator, Iterator + +import anyio + +from starlite import BackgroundTask +from starlite.response import StreamingResponse +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite.types import Message, Receive, Scope, Send + + +def test_streaming_response_unknown_size() -> None: + app = StreamingResponse(content=iter(["hello", "world"])) + client = TestClient(app) + response = client.get("/") + assert "content-length" not in response.headers + + +def test_streaming_response_known_size() -> None: + app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"}) + client = TestClient(app) + response = client.get("/") + assert response.headers["content-length"] == "10" + + +async def test_streaming_response_stops_if_receiving_http_disconnect_with_async_iterator(anyio_backend: str) -> None: + streamed = 0 + + disconnected = anyio.Event() + + async def receive_disconnect() -> dict: + await disconnected.wait() + return {"type": "http.disconnect"} + + async def send(message: "Message") -> None: + nonlocal streamed + if message["type"] == "http.response.body": + streamed += len(message.get("body", b"")) + # Simulate disconnection after download has started + if streamed >= 16: + await disconnected.set() + + async def stream_indefinitely() -> AsyncIterator[bytes]: + while True: + # Need a sleep for the event loop to switch to another task + await anyio.sleep(0) + yield b"chunk " + + response = StreamingResponse(content=stream_indefinitely()) + + with anyio.move_on_after(1) as cancel_scope: + await response({}, receive_disconnect, send) # type: ignore + assert not cancel_scope.cancel_called, "Content streaming should stop itself." + + +async def test_streaming_response_stops_if_receiving_http_disconnect_with_sync_iterator(anyio_backend: str) -> None: + streamed = 0 + + disconnected = anyio.Event() + + async def receive_disconnect() -> dict: + await disconnected.wait() + return {"type": "http.disconnect"} + + async def send(message: "Message") -> None: + nonlocal streamed + if message["type"] == "http.response.body": + streamed += len(message.get("body", b"")) + # Simulate disconnection after download has started + if streamed >= 16: + await disconnected.set() + + response = StreamingResponse(content=cycle(["1", "2", "3"])) + + with anyio.move_on_after(1) as cancel_scope: + await response({}, receive_disconnect, send) # type: ignore + assert not cancel_scope.cancel_called, "Content streaming should stop itself." + + +def test_streaming_response() -> None: + filled_by_bg_task = "" + + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: + for i in range(minimum, maximum + 1): + yield str(i) + if i != maximum: + yield ", " + await anyio.sleep(0) + + async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: + nonlocal filled_by_bg_task + async for thing in numbers(start, stop): + filled_by_bg_task = filled_by_bg_task + thing + + cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) + generator = numbers(1, 5) + response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task) + await response(scope, receive, send) + + assert filled_by_bg_task == "" + client = TestClient(app) + response = client.get("/") + assert response.text == "1, 2, 3, 4, 5" + assert filled_by_bg_task == "6, 7, 8, 9" + + +def test_streaming_response_custom_iterator() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + class CustomAsyncIterator: + def __init__(self) -> None: + self._called = 0 + + def __aiter__(self) -> "CustomAsyncIterator": + return self + + async def __anext__(self) -> str: + if self._called == 5: + raise StopAsyncIteration() + self._called += 1 + return str(self._called) + + response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "12345" + + +def test_streaming_response_custom_iterable() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + class CustomAsyncIterable: + async def __aiter__(self) -> AsyncIterator[str]: + for i in range(5): + yield str(i + 1) + + response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "12345" + + +def test_sync_streaming_response() -> None: + async def app(scope: "Scope", receive: "Receive", send: "Send") -> None: + def numbers(minimum: int, maximum: int) -> Iterator[str]: + for i in range(minimum, maximum + 1): + yield str(i) + if i != maximum: + yield ", " + + generator = numbers(1, 5) + response = StreamingResponse(generator, media_type="text/plain") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "1, 2, 3, 4, 5" diff --git a/tests/routing/test_path_resolution.py b/tests/routing/test_path_resolution.py index 237b846b28..420c1a4b05 100644 --- a/tests/routing/test_path_resolution.py +++ b/tests/routing/test_path_resolution.py @@ -1,13 +1,6 @@ from typing import Any, Callable, Optional, Type import pytest -from starlette.status import ( - HTTP_200_OK, - HTTP_204_NO_CONTENT, - HTTP_400_BAD_REQUEST, - HTTP_404_NOT_FOUND, - HTTP_405_METHOD_NOT_ALLOWED, -) from starlite import ( Controller, @@ -17,6 +10,13 @@ get, post, ) +from starlite.status_codes import ( + HTTP_200_OK, + HTTP_204_NO_CONTENT, + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, + HTTP_405_METHOD_NOT_ALLOWED, +) from starlite.testing import create_test_client from tests import Person, PersonFactory diff --git a/tests/static_files/test_static_files.py b/tests/static_files/test_static_files.py index 1335f98ee9..4f07ec7a3d 100644 --- a/tests/static_files/test_static_files.py +++ b/tests/static_files/test_static_files.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import TYPE_CHECKING import pytest from pydantic import ValidationError @@ -7,10 +7,13 @@ from starlite.config import StaticFilesConfig from starlite.testing import create_test_client +if TYPE_CHECKING: + from pathlib import Path -def test_staticfiles(tmpdir: Any) -> None: - path = tmpdir.join("test.txt") - path.write("content") + +def test_staticfiles(tmpdir: "Path") -> None: + path = tmpdir / "test.txt" + path.write_text("content", "utf-8") static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) with create_test_client([], static_files_config=static_files_config) as client: response = client.get("/static/test.txt") @@ -18,9 +21,9 @@ def test_staticfiles(tmpdir: Any) -> None: assert response.text == "content" -def test_staticfiles_html_mode(tmpdir: Any) -> None: - path = tmpdir.join("index.html") - path.write("content") +def test_staticfiles_html_mode(tmpdir: "Path") -> None: + path = tmpdir / "index.html" + path.write_text("content", "utf-8") static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir], html_mode=True) with create_test_client([], static_files_config=static_files_config) as client: response = client.get("/static") @@ -28,9 +31,9 @@ def test_staticfiles_html_mode(tmpdir: Any) -> None: assert response.text == "content" -def test_staticfiles_for_slash_path(tmpdir: Any) -> None: - path = tmpdir.join("text.txt") - path.write("content") +def test_staticfiles_for_slash_path(tmpdir: "Path") -> None: + path = tmpdir / "text.txt" + path.write_text("content", "utf-8") static_files_config = StaticFilesConfig(path="/", directories=[tmpdir]) with create_test_client([], static_files_config=static_files_config) as client: @@ -39,9 +42,9 @@ def test_staticfiles_for_slash_path(tmpdir: Any) -> None: assert response.text == "content" -def test_config_validation(tmpdir: Any) -> None: - path = tmpdir.join("text.txt") - path.write("content") +def test_config_validation(tmpdir: "Path") -> None: + path = tmpdir / "text.txt" + path.write_text("content", "utf-8") with pytest.raises(ValidationError): StaticFilesConfig(path="", directories=[tmpdir]) @@ -50,9 +53,9 @@ def test_config_validation(tmpdir: Any) -> None: StaticFilesConfig(path="/{param:int}", directories=[tmpdir]) -def test_path_inside_static(tmpdir: Any) -> None: - path = tmpdir.join("test.txt") - path.write("content") +def test_path_inside_static(tmpdir: "Path") -> None: + path = tmpdir / "test.txt" + path.write_text("content", "utf-8") @get("/static/strange/{f:str}") def handler(f: str) -> str: @@ -67,17 +70,17 @@ def handler(f: str) -> str: app.register(handler) -def test_multiple_configs(tmpdir: Any) -> None: - root1 = tmpdir.mkdir("1") - root2 = tmpdir.mkdir("2") - path1 = root1.join("test.txt") - path1.write("content1") - path2 = root2.join("test.txt") - path2.write("content2") +def test_multiple_configs(tmpdir: "Path") -> None: + root1 = tmpdir.mkdir("1") # type: ignore + root2 = tmpdir.mkdir("2") # type: ignore + path1 = root1 / "test.txt" # pyright: ignore + path1.write_text("content1", "utf-8") + path2 = root2 / "test.txt" # pyright: ignore + path2.write_text("content2", "utf-8") static_files_config = [ - StaticFilesConfig(path="/1", directories=[root1]), - StaticFilesConfig(path="/2", directories=[root2]), + StaticFilesConfig(path="/1", directories=[root1]), # pyright: ignore + StaticFilesConfig(path="/2", directories=[root2]), # pyright: ignore ] with create_test_client([], static_files_config=static_files_config) as client: response = client.get("/1/test.txt") @@ -89,10 +92,9 @@ def test_multiple_configs(tmpdir: Any) -> None: assert response.text == "content2" -def test_static_substring_of_self(tmpdir: Any) -> None: - path = tmpdir.mkdir("static_part").mkdir("static") - path = path.join("test.txt") - path.write("content") +def test_static_substring_of_self(tmpdir: "Path") -> None: + path = tmpdir.mkdir("static_part").mkdir("static") / "test.txt" # type: ignore + path.write_text("content", "utf-8") static_files_config = StaticFilesConfig(path="/static", directories=[tmpdir]) with create_test_client([], static_files_config=static_files_config) as client: diff --git a/tests/test_controller.py b/tests/test_controller.py index bb1b5fe458..31f311aac4 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -2,7 +2,6 @@ import pytest from pydantic import BaseModel -from starlette.status import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite import ( Controller, @@ -17,6 +16,7 @@ websocket, ) from starlite.connection import WebSocket +from starlite.status_codes import HTTP_200_OK, HTTP_201_CREATED, HTTP_204_NO_CONTENT from starlite.testing import create_test_client from tests import Person, PersonFactory diff --git a/tests/test_dependency.py b/tests/test_dependency.py index e2ebeb644a..e1847a719d 100644 --- a/tests/test_dependency.py +++ b/tests/test_dependency.py @@ -1,11 +1,11 @@ from typing import Any, Dict, Optional import pytest -from starlette.status import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR from starlite import Controller, Dependency, Provide, Starlite, get from starlite.constants import EXTRA_KEY_IS_DEPENDENCY from starlite.exceptions import ImproperlyConfiguredException +from starlite.status_codes import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR from starlite.testing import create_test_client diff --git a/tests/test_exception_handlers.py b/tests/test_exception_handlers.py index 60a340a74a..098a37d141 100644 --- a/tests/test_exception_handlers.py +++ b/tests/test_exception_handlers.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING, Type import pytest -from starlette.status import HTTP_400_BAD_REQUEST from starlite import ( Controller, @@ -15,6 +14,7 @@ ValidationException, get, ) +from starlite.status_codes import HTTP_400_BAD_REQUEST from starlite.testing import create_test_client if TYPE_CHECKING: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3cdd348c69..3a4d4f27c5 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,7 +3,6 @@ from hypothesis import given from hypothesis import strategies as st from starlette.exceptions import HTTPException as StarletteHTTPException -from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from starlite.enums import MediaType from starlite.exceptions import ( @@ -12,6 +11,7 @@ StarLiteException, ValidationException, ) +from starlite.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from starlite.utils.exception import create_exception_response diff --git a/tests/test_guards.py b/tests/test_guards.py index 78bbd4f354..c4ca52d4da 100644 --- a/tests/test_guards.py +++ b/tests/test_guards.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING import pytest -from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN from starlette.websockets import WebSocketDisconnect from starlite import ( @@ -15,6 +14,7 @@ ) from starlite.connection import WebSocket from starlite.exceptions import PermissionDeniedException +from starlite.status_codes import HTTP_200_OK, HTTP_403_FORBIDDEN from starlite.testing import create_test_client from starlite.types import Receive, Scope, Send @@ -53,7 +53,7 @@ def test_guards_with_asgi_handler() -> None: @asgi(path="/secret", guards=[local_guard]) async def my_asgi_handler(scope: Scope, receive: Receive, send: Send) -> None: response = Response(media_type=MediaType.JSON, status_code=HTTP_200_OK, content={"hello": "world"}) - await response(scope=scope, receive=receive, send=send) # type: ignore[arg-type] + await response(scope=scope, receive=receive, send=send) with create_test_client(guards=[app_guard], route_handlers=[my_asgi_handler]) as client: response = client.get("/secret") diff --git a/tests/test_parsers.py b/tests/test_parsers.py index eadce7d5a5..d98ca9a235 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -1,9 +1,12 @@ +from typing import Dict + +import pytest from pydantic import BaseConfig from pydantic.fields import ModelField -from starlite import RequestEncodingType +from starlite import Cookie, RequestEncodingType from starlite.datastructures import FormMultiDict -from starlite.parsers import parse_form_data, parse_query_params +from starlite.parsers import parse_cookie_string, parse_form_data, parse_query_params from starlite.testing import RequestFactory @@ -51,3 +54,25 @@ def test_parse_form_data() -> None: "healthy": True, "polluting": False, } + + +@pytest.mark.parametrize( + "cookie_string, expected", + ( + ("ABC = 123; efg = 456", {"ABC": "123", "efg": "456"}), + (("foo= ; bar="), {"foo": "", "bar": ""}), + ('foo="bar=123456789&name=moisheZuchmir"', {"foo": "bar=123456789&name=moisheZuchmir"}), + ("email=%20%22%2c%3b%2f", {"email": ' ",;/'}), + ("foo=%1;bar=bar", {"foo": "%1", "bar": "bar"}), + ("foo=bar;fizz ; buzz", {"": "buzz", "foo": "bar"}), + (" fizz; foo= bar", {"": "fizz", "foo": "bar"}), + ("foo=false;bar=bar;foo=true", {"bar": "bar", "foo": "true"}), + ("foo=;bar=bar;foo=boo", {"bar": "bar", "foo": "boo"}), + ( + Cookie(key="abc", value="123", path="/head", domain="localhost").to_header(header=""), + {"Domain": "localhost", "Path": "/head", "SameSite": "lax", "abc": "123"}, + ), + ), +) +def test_parse_cookie_string(cookie_string: str, expected: Dict[str, str]) -> None: + assert parse_cookie_string(cookie_string) == expected diff --git a/tests/test_signature.py b/tests/test_signature.py index 259d9b4c83..c8e76aa57e 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, ValidationError from pydantic.error_wrappers import ErrorWrapper from starlette.datastructures import URL -from starlette.status import HTTP_204_NO_CONTENT from starlite import HTTPException, Provide, get from starlite.connection import WebSocket @@ -17,6 +16,7 @@ ) from starlite.params import Dependency from starlite.signature import SignatureModel, SignatureModelFactory +from starlite.status_codes import HTTP_204_NO_CONTENT from starlite.testing import RequestFactory, create_test_client from tests.plugins.test_base import AModel, APlugin diff --git a/tests/test_typing.py b/tests/test_typing.py index c0b315767e..503e707232 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -119,7 +119,7 @@ class Foo: @pytest.mark.parametrize( "cls, should_raise", - [(Foo, True), (Person, False), (VanillaDataClassPerson, False), (PydanticDataClassPerson, False)], + ((Foo, True), (Person, False), (VanillaDataClassPerson, False), (PydanticDataClassPerson, False)), ) def test_validation(cls: Any, should_raise: bool) -> None: """Test that Partial returns no annotations for classes that don't inherit diff --git a/tests/test_utils.py b/tests/test_utils.py index c9a77a9bc1..5809ca9e6a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import pytest -from starlite.utils.predicates import is_async_callable +from starlite.utils import is_async_callable class AsyncTestCallable: diff --git a/tests/utils/test_exceptions.py b/tests/utils/test_exceptions.py index afd582ce87..6853f194e5 100644 --- a/tests/utils/test_exceptions.py +++ b/tests/utils/test_exceptions.py @@ -1,9 +1,9 @@ from typing import Any import pytest -from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from starlite import HTTPException, InternalServerException, ValidationException +from starlite.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from starlite.types import ExceptionHandlersMap from starlite.utils.exception import get_exception_handler diff --git a/tests/utils/test_extractors.py b/tests/utils/test_extractors.py index f49c0a49e4..6d7547a80d 100644 --- a/tests/utils/test_extractors.py +++ b/tests/utils/test_extractors.py @@ -1,10 +1,10 @@ from typing import Any, List import pytest -from starlette.status import HTTP_200_OK from starlite import Cookie, MediaType, Request, RequestEncodingType, Response from starlite.connection import empty_receive +from starlite.status_codes import HTTP_200_OK from starlite.testing import RequestFactory from starlite.utils import ConnectionDataExtractor from starlite.utils.extractors import ResponseDataExtractor @@ -23,15 +23,15 @@ async def test_connection_data_extractor() -> None: request.scope["path_params"] = {"first": "10", "second": "20", "third": "30"} extractor = ConnectionDataExtractor(parse_body=True, parse_query=True) extracted_data = extractor(request) - assert await extracted_data["body"] == await request.json() - assert extracted_data["content_type"] == request.content_type - assert extracted_data["headers"] == dict(request.headers) - assert extracted_data["headers"] == dict(request.headers) - assert extracted_data["path"] == request.scope["path"] - assert extracted_data["path"] == request.scope["path"] - assert extracted_data["path_params"] == request.scope["path_params"] - assert extracted_data["query"] == request.query_params - assert extracted_data["scheme"] == request.scope["scheme"] + assert await extracted_data.get("body") == await request.json() # type: ignore + assert extracted_data.get("content_type") == request.content_type + assert extracted_data.get("headers") == dict(request.headers) + assert extracted_data.get("headers") == dict(request.headers) + assert extracted_data.get("path") == request.scope["path"] + assert extracted_data.get("path") == request.scope["path"] + assert extracted_data.get("path_params") == request.scope["path_params"] + assert extracted_data.get("query") == request.query_params + assert extracted_data.get("scheme") == request.scope["scheme"] def test_parse_query() -> None: @@ -41,27 +41,27 @@ def test_parse_query() -> None: ) parsed_extracted_data = ConnectionDataExtractor(parse_query=True)(request) unparsed_extracted_data = ConnectionDataExtractor(parse_query=False)(request) - assert parsed_extracted_data["query"] == request.query_params - assert unparsed_extracted_data["query"] == request.scope["query_string"] + assert parsed_extracted_data.get("query") == request.query_params + assert unparsed_extracted_data.get("query") == request.scope["query_string"] # Close to avoid warnings about un-awaited coroutines. - parsed_extracted_data["body"].close() - unparsed_extracted_data["body"].close() + parsed_extracted_data.get("body").close() # type: ignore + unparsed_extracted_data.get("body").close() # type: ignore async def test_parse_json_data() -> None: request = factory.post(path="/a/b/c", data={"hello": "world"}) - assert await ConnectionDataExtractor(parse_body=True)(request)["body"] == await request.json() - assert await ConnectionDataExtractor(parse_body=False)(request)["body"] == await request.body() + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == await request.json() # type: ignore + assert await ConnectionDataExtractor(parse_body=False)(request).get("body") == await request.body() # type: ignore async def test_parse_form_data() -> None: request = factory.post(path="/a/b/c", data={"file": b"123"}, request_media_type=RequestEncodingType.MULTI_PART) - assert await ConnectionDataExtractor(parse_body=True)(request)["body"] == dict(await request.form()) + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore async def test_parse_url_encoded() -> None: request = factory.post(path="/a/b/c", data={"key": "123"}, request_media_type=RequestEncodingType.URL_ENCODED) - assert await ConnectionDataExtractor(parse_body=True)(request)["body"] == dict(await request.form()) + assert await ConnectionDataExtractor(parse_body=True)(request).get("body") == dict(await request.form()) # type: ignore @pytest.mark.parametrize( @@ -70,9 +70,9 @@ async def test_parse_url_encoded() -> None: def test_request_extraction_header_obfuscation(req: Request[Any, Any]) -> None: extractor = ConnectionDataExtractor(obfuscate_headers={"special"}) extracted_data = extractor(req) - assert extracted_data["headers"] == {"special": "*****"} + assert extracted_data.get("headers") == {"special": "*****"} # Close to avoid warnings about un-awaited coroutines. - extracted_data["body"].close() + extracted_data.get("body").close() # type: ignore @pytest.mark.parametrize( @@ -85,13 +85,13 @@ def test_request_extraction_header_obfuscation(req: Request[Any, Any]) -> None: def test_request_extraction_cookie_obfuscation(req: Request[Any, Any], key: str) -> None: extractor = ConnectionDataExtractor(obfuscate_cookies={"special"}) extracted_data = extractor(req) - assert extracted_data["cookies"] == {"Path": "/", "SameSite": "lax", key: "*****"} + assert extracted_data.get("cookies") == {"Path": "/", "SameSite": "lax", key: "*****"} # Close to avoid warnings about un-awaited coroutines. - extracted_data["body"].close() + extracted_data.get("body").close() # type: ignore async def test_response_data_extractor() -> None: - headers = {"common": "abc", "special": "123", "content-type": "application/json; charset=utf-8"} + headers = {"common": "abc", "special": "123", "content-type": "application/json"} cookies = [Cookie(key="regular"), Cookie(key="auth")] response = Response( media_type=MediaType.JSON, @@ -107,11 +107,11 @@ async def test_response_data_extractor() -> None: async def send(message: "Any") -> None: messages.append(message) - await response({}, empty_receive, send) + await response({}, empty_receive, send) # type: ignore[arg-type] assert len(messages) == 2 extracted_data = extractor(messages) # type: ignore - assert extracted_data["status_code"] == HTTP_200_OK - assert extracted_data["body"] == b'{"hello":"world"}' - assert extracted_data["headers"] == {**headers, "content-length": "17"} - assert extracted_data["cookies"] == {"Path": "/", "SameSite": "lax", "auth": "None", "regular": "None"} + assert extracted_data.get("status_code") == HTTP_200_OK + assert extracted_data.get("body") == b'{"hello":"world"}' + assert extracted_data.get("headers") == {**headers, "content-length": "17"} + assert extracted_data.get("cookies") == {"Path": "/", "SameSite": "lax", "auth": "", "regular": ""} diff --git a/tests/utils/test_model.py b/tests/utils/test_model.py index 711871bfeb..128aadc053 100644 --- a/tests/utils/test_model.py +++ b/tests/utils/test_model.py @@ -7,7 +7,7 @@ from starlite.utils import model if TYPE_CHECKING: - from pytest import MonkeyPatch # noqa: PT013 + from pytest import MonkeyPatch def test_convert_dataclass_to_model_cache(monkeypatch: "MonkeyPatch") -> None: diff --git a/tests/utils/test_url.py b/tests/utils/test_url.py index 0831819921..f421f94c1c 100644 --- a/tests/utils/test_url.py +++ b/tests/utils/test_url.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize( "base,fragment, expected", - [ + ( ("/path/", "sub", "/path/sub"), ("/path/", "/sub/", "/path/sub"), ("path/", "sub", "/path/sub"), @@ -14,7 +14,7 @@ ("path/", "sub/", "/path/sub"), ("path", "sub/", "/path/sub"), ("/", "/root/sub", "/root/sub"), - ], + ), ) def test_join_url_fragments(base: str, fragment: str, expected: str) -> None: assert join_paths([base, fragment]) == expected