diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a24c77b378..25e7e6f8a5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -105,6 +105,7 @@ repos: args: ["--unsafe-load-any-extension=y"] additional_dependencies: [ + multidict, orjson, piccolo, picologging, @@ -128,6 +129,7 @@ repos: [ aiomcache, httpx, + multidict, orjson, piccolo, picologging, @@ -160,6 +162,7 @@ repos: httpx, hypothesis, mako, + multidict, orjson, jinja2, piccolo, diff --git a/docs/reference/datastructures/3-header.md b/docs/reference/datastructures/3-headers.md similarity index 73% rename from docs/reference/datastructures/3-header.md rename to docs/reference/datastructures/3-headers.md index 4229beaf66..dbcc160094 100644 --- a/docs/reference/datastructures/3-header.md +++ b/docs/reference/datastructures/3-headers.md @@ -1,5 +1,19 @@ # Headers +::: starlite.datastructures.Headers + options: + members: + - __init__ + - from_scope + +::: starlite.datastructures.MutableScopeHeaders + options: + members: + - __init__ + - add + - getall + - extend_header_value + ::: starlite.datastructures.ResponseHeader options: members: diff --git a/docs/reference/datastructures/7-form-multi-dict.md b/docs/reference/datastructures/7-form-multi-dict.md index f42024e777..7ceb339f3e 100644 --- a/docs/reference/datastructures/7-form-multi-dict.md +++ b/docs/reference/datastructures/7-form-multi-dict.md @@ -4,3 +4,4 @@ options: members: - close + - multi_items diff --git a/docs/usage/2-route-handlers/0-route-handlers-concept.md b/docs/usage/2-route-handlers/0-route-handlers-concept.md index 7e8f0045c7..6721ccb1d0 100644 --- a/docs/usage/2-route-handlers/0-route-handlers-concept.md +++ b/docs/usage/2-route-handlers/0-route-handlers-concept.md @@ -99,7 +99,7 @@ Additionally, you can specify the following special kwargs, what's called "reser ```python from typing import Any, Dict from starlite import State, Request, get -from starlette.datastructures import Headers +from starlite.datastructures import Headers @get(path="/") diff --git a/mkdocs.yml b/mkdocs.yml index 0eeab93202..52944fc8ce 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -185,7 +185,7 @@ nav: - reference/datastructures/0-state.md - reference/datastructures/1-cookie.md - reference/datastructures/2-provide.md - - reference/datastructures/3-header.md + - reference/datastructures/3-headers.md - reference/datastructures/4-background.md - reference/datastructures/5-response-containers.md - reference/datastructures/6-upload-file.md diff --git a/poetry.lock b/poetry.lock index c47cd074ee..044fa5a68d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -510,6 +510,14 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "multidict" +version = "6.0.2" +description = "multidict implementation" +category = "main" +optional = false +python-versions = ">=3.7" + [[package]] name = "mypy-extensions" version = "0.4.3" @@ -1134,7 +1142,7 @@ structlog = ["structlog"] [metadata] lock-version = "1.1" python-versions = ">=3.7,<4.0" -content-hash = "416c4f61a33cfb6efde867abb686e9bb20a5300fef6faee8ba98d4c559acf34e" +content-hash = "d2d47a7d6b94922f3af7004d0c785ebed19108dcdf699b42e627430a00211421" [metadata.files] aiomcache = [ @@ -1650,6 +1658,67 @@ markupsafe = [ {file = "MarkupSafe-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247"}, {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"}, ] +multidict = [ + {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"}, + {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"}, + {file = "multidict-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:041b81a5f6b38244b34dc18c7b6aba91f9cdaf854d9a39e5ff0b58e2b5773b9c"}, + {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fdda29a3c7e76a064f2477c9aab1ba96fd94e02e386f1e665bca1807fc5386f"}, + {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3368bf2398b0e0fcbf46d85795adc4c259299fec50c1416d0f77c0a843a3eed9"}, + {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4f052ee022928d34fe1f4d2bc743f32609fb79ed9c49a1710a5ad6b2198db20"}, + {file = "multidict-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:225383a6603c086e6cef0f2f05564acb4f4d5f019a4e3e983f572b8530f70c88"}, + {file = "multidict-6.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50bd442726e288e884f7be9071016c15a8742eb689a593a0cac49ea093eef0a7"}, + {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:47e6a7e923e9cada7c139531feac59448f1f47727a79076c0b1ee80274cd8eee"}, + {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:0556a1d4ea2d949efe5fd76a09b4a82e3a4a30700553a6725535098d8d9fb672"}, + {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:626fe10ac87851f4cffecee161fc6f8f9853f0f6f1035b59337a51d29ff3b4f9"}, + {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:8064b7c6f0af936a741ea1efd18690bacfbae4078c0c385d7c3f611d11f0cf87"}, + {file = "multidict-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2d36e929d7f6a16d4eb11b250719c39560dd70545356365b494249e2186bc389"}, + {file = "multidict-6.0.2-cp310-cp310-win32.whl", hash = "sha256:fcb91630817aa8b9bc4a74023e4198480587269c272c58b3279875ed7235c293"}, + {file = "multidict-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:8cbf0132f3de7cc6c6ce00147cc78e6439ea736cee6bca4f068bcf892b0fd658"}, + {file = "multidict-6.0.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:05f6949d6169878a03e607a21e3b862eaf8e356590e8bdae4227eedadacf6e51"}, + {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2c2e459f7050aeb7c1b1276763364884595d47000c1cddb51764c0d8976e608"}, + {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d0509e469d48940147e1235d994cd849a8f8195e0bca65f8f5439c56e17872a3"}, + {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:514fe2b8d750d6cdb4712346a2c5084a80220821a3e91f3f71eec11cf8d28fd4"}, + {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19adcfc2a7197cdc3987044e3f415168fc5dc1f720c932eb1ef4f71a2067e08b"}, + {file = "multidict-6.0.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9d153e7f1f9ba0b23ad1568b3b9e17301e23b042c23870f9ee0522dc5cc79e8"}, + {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:aef9cc3d9c7d63d924adac329c33835e0243b5052a6dfcbf7732a921c6e918ba"}, + {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4571f1beddff25f3e925eea34268422622963cd8dc395bb8778eb28418248e43"}, + {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:d48b8ee1d4068561ce8033d2c344cf5232cb29ee1a0206a7b828c79cbc5982b8"}, + {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:45183c96ddf61bf96d2684d9fbaf6f3564d86b34cb125761f9a0ef9e36c1d55b"}, + {file = "multidict-6.0.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:75bdf08716edde767b09e76829db8c1e5ca9d8bb0a8d4bd94ae1eafe3dac5e15"}, + {file = "multidict-6.0.2-cp37-cp37m-win32.whl", hash = "sha256:a45e1135cb07086833ce969555df39149680e5471c04dfd6a915abd2fc3f6dbc"}, + {file = "multidict-6.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6f3cdef8a247d1eafa649085812f8a310e728bdf3900ff6c434eafb2d443b23a"}, + {file = "multidict-6.0.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:0327292e745a880459ef71be14e709aaea2f783f3537588fb4ed09b6c01bca60"}, + {file = "multidict-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e875b6086e325bab7e680e4316d667fc0e5e174bb5611eb16b3ea121c8951b86"}, + {file = "multidict-6.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:feea820722e69451743a3d56ad74948b68bf456984d63c1a92e8347b7b88452d"}, + {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cc57c68cb9139c7cd6fc39f211b02198e69fb90ce4bc4a094cf5fe0d20fd8b0"}, + {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:497988d6b6ec6ed6f87030ec03280b696ca47dbf0648045e4e1d28b80346560d"}, + {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:89171b2c769e03a953d5969b2f272efa931426355b6c0cb508022976a17fd376"}, + {file = "multidict-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:684133b1e1fe91eda8fa7447f137c9490a064c6b7f392aa857bba83a28cfb693"}, + {file = "multidict-6.0.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd9fc9c4849a07f3635ccffa895d57abce554b467d611a5009ba4f39b78a8849"}, + {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e07c8e79d6e6fd37b42f3250dba122053fddb319e84b55dd3a8d6446e1a7ee49"}, + {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4070613ea2227da2bfb2c35a6041e4371b0af6b0be57f424fe2318b42a748516"}, + {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:47fbeedbf94bed6547d3aa632075d804867a352d86688c04e606971595460227"}, + {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:5774d9218d77befa7b70d836004a768fb9aa4fdb53c97498f4d8d3f67bb9cfa9"}, + {file = "multidict-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2957489cba47c2539a8eb7ab32ff49101439ccf78eab724c828c1a54ff3ff98d"}, + {file = "multidict-6.0.2-cp38-cp38-win32.whl", hash = "sha256:e5b20e9599ba74391ca0cfbd7b328fcc20976823ba19bc573983a25b32e92b57"}, + {file = "multidict-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:8004dca28e15b86d1b1372515f32eb6f814bdf6f00952699bdeb541691091f96"}, + {file = "multidict-6.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2e4a0785b84fb59e43c18a015ffc575ba93f7d1dbd272b4cdad9f5134b8a006c"}, + {file = "multidict-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6701bf8a5d03a43375909ac91b6980aea74b0f5402fbe9428fc3f6edf5d9677e"}, + {file = "multidict-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a007b1638e148c3cfb6bf0bdc4f82776cef0ac487191d093cdc316905e504071"}, + {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07a017cfa00c9890011628eab2503bee5872f27144936a52eaab449be5eaf032"}, + {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c207fff63adcdf5a485969131dc70e4b194327666b7e8a87a97fbc4fd80a53b2"}, + {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:373ba9d1d061c76462d74e7de1c0c8e267e9791ee8cfefcf6b0b2495762c370c"}, + {file = "multidict-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfba7c6d5d7c9099ba21f84662b037a0ffd4a5e6b26ac07d19e423e6fdf965a9"}, + {file = "multidict-6.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19d9bad105dfb34eb539c97b132057a4e709919ec4dd883ece5838bcbf262b80"}, + {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:de989b195c3d636ba000ee4281cd03bb1234635b124bf4cd89eeee9ca8fcb09d"}, + {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7c40b7bbece294ae3a87c1bc2abff0ff9beef41d14188cda94ada7bcea99b0fb"}, + {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:d16cce709ebfadc91278a1c005e3c17dd5f71f5098bfae1035149785ea6e9c68"}, + {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:a2c34a93e1d2aa35fbf1485e5010337c72c6791407d03aa5f4eed920343dd360"}, + {file = "multidict-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:feba80698173761cddd814fa22e88b0661e98cb810f9f986c54aa34d281e4937"}, + {file = "multidict-6.0.2-cp39-cp39-win32.whl", hash = "sha256:23b616fdc3c74c9fe01d76ce0d1ce872d2d396d8fa8e4899398ad64fb5aa214a"}, + {file = "multidict-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:4bae31803d708f6f15fd98be6a6ac0b6958fcf68fda3c77a048a4f9073704aae"}, + {file = "multidict-6.0.2.tar.gz", hash = "sha256:5ff3bd75f38e4c43f1f470f2df7a4d430b821c4ce22be384e1459cb57d6bb013"}, +] mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, diff --git a/pyproject.toml b/pyproject.toml index f453095de1..86e62b055f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ starlette = ">=0.21" starlite-multipart = ">=1.2.0" structlog = { version = "*", optional = true } typing-extensions = "*" +multidict = ">=6.0.2" [tool.poetry.group.dev.dependencies] aiomcache = "*" diff --git a/starlite/connection/base.py b/starlite/connection/base.py index b622ecae15..e876d97784 100644 --- a/starlite/connection/base.py +++ b/starlite/connection/base.py @@ -10,8 +10,9 @@ cast, ) -from starlette.datastructures import URL, Address, Headers, URLPath +from starlette.datastructures import URL, Address, URLPath +from starlite.datastructures.headers import Headers from starlite.datastructures.state import State from starlite.exceptions import ImproperlyConfiguredException from starlite.parsers import parse_cookie_string, parse_query_params @@ -151,7 +152,7 @@ def headers(self) -> Headers: """ if self._headers is Empty: self.scope.setdefault("headers", []) - self._headers = self.scope["_headers"] = Headers(scope=self.scope) # type: ignore[typeddict-item] + self._headers = self.scope["_headers"] = Headers.from_scope(self.scope) # type: ignore[typeddict-item] return cast("Headers", self._headers) @property diff --git a/starlite/connection/request.py b/starlite/connection/request.py index 79106a745f..1e8592d758 100644 --- a/starlite/connection/request.py +++ b/starlite/connection/request.py @@ -197,6 +197,6 @@ async def send_push_promise(self, path: str) -> None: if "http.response.push" in extensions: raw_headers = [] for name in SERVER_PUSH_HEADERS: - for value in self.headers.getlist(name): + for value in self.headers.getall(name, []): raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) await self.send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/starlite/connection/websocket.py b/starlite/connection/websocket.py index d9b20eeba4..9093545800 100644 --- a/starlite/connection/websocket.py +++ b/starlite/connection/websocket.py @@ -12,7 +12,6 @@ ) from orjson import OPT_OMIT_MICROSECONDS, OPT_SERIALIZE_NUMPY, dumps, loads -from starlette.datastructures import Headers from starlite.connection.base import ( ASGIConnection, @@ -21,6 +20,7 @@ empty_receive, empty_send, ) +from starlite.datastructures.headers import Headers from starlite.exceptions import WebSocketDisconnect, WebSocketException from starlite.status_codes import WS_1000_NORMAL_CLOSURE from starlite.utils.serialization import default_serializer @@ -137,10 +137,10 @@ async def accept( _headers: List[Tuple[bytes, bytes]] = headers if isinstance(headers, list) else [] if isinstance(headers, dict): - _headers = Headers(headers=headers).raw + _headers = Headers(headers=headers).to_header_list() if isinstance(headers, Headers): - _headers = headers.raw + _headers = headers.to_header_list() event: "WebSocketAcceptEvent" = { "type": "websocket.accept", diff --git a/starlite/datastructures/__init__.py b/starlite/datastructures/__init__.py index 03ed427f4d..4036535933 100644 --- a/starlite/datastructures/__init__.py +++ b/starlite/datastructures/__init__.py @@ -1,7 +1,12 @@ from starlite.datastructures.background_tasks import BackgroundTask, BackgroundTasks from starlite.datastructures.cookie import Cookie from starlite.datastructures.form_multi_dict import FormMultiDict -from starlite.datastructures.headers import CacheControlHeader, ETag +from starlite.datastructures.headers import ( + CacheControlHeader, + ETag, + Headers, + MutableScopeHeaders, +) from starlite.datastructures.provide import Provide from starlite.datastructures.response_containers import ( File, @@ -17,11 +22,13 @@ __all__ = ( "BackgroundTask", "BackgroundTasks", - "Cookie", "CacheControlHeader", + "Cookie", "ETag", "File", "FormMultiDict", + "Headers", + "MutableScopeHeaders", "Provide", "Redirect", "ResponseContainer", diff --git a/starlite/datastructures/form_multi_dict.py b/starlite/datastructures/form_multi_dict.py index ae225d7231..57d2726d2c 100644 --- a/starlite/datastructures/form_multi_dict.py +++ b/starlite/datastructures/form_multi_dict.py @@ -1,11 +1,37 @@ -from typing import Any +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union -from starlette.datastructures import ImmutableMultiDict +from multidict import MultiDict, MultiDictProxy from starlite.datastructures.upload_file import UploadFile +from starlite.utils import deprecated -class FormMultiDict(ImmutableMultiDict[str, Any]): +class FormMultiDict(MultiDictProxy[Any]): + def __init__( + self, args: Optional[Union["FormMultiDict", Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None + ) -> None: + super().__init__(MultiDict(args or {})) + + def multi_items(self) -> List[Tuple[str, Any]]: + """Get all keys and values, including duplicates. + + Returns: + A list of tuples containing key-value pairs + """ + return [(key, value) for key in set(self) for value in self.getall(key)] + + @deprecated("1.36.0", alternative="FormMultiDict.getall") + def getlist(self, key: str) -> List[str]: + """Get all values. + + Args: + key: The key + + Returns: + A list of values + """ + return super().getall(key, []) + async def close(self) -> None: """Closes all files in the multi-dict. diff --git a/starlite/datastructures/headers.py b/starlite/datastructures/headers.py index 705b8095cd..1b852f7dcc 100644 --- a/starlite/datastructures/headers.py +++ b/starlite/datastructures/headers.py @@ -1,15 +1,182 @@ import re from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Optional - +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, + cast, +) + +from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping from pydantic import BaseModel, Extra, Field, ValidationError, validator from typing_extensions import Annotated from starlite.exceptions import ImproperlyConfiguredException +if TYPE_CHECKING: + from starlite.types.asgi_types import HeaderScope, Message, RawHeadersList + ETAG_RE = re.compile(r'([Ww]/)?"(.+)"') +def _encode_headers(headers: Iterable[Tuple[str, str]]) -> "RawHeadersList": + return [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers] + + +class Headers(CIMultiDictProxy[str]): + """An immutable, case-insensitive [multidict](https://multidict.aio- + libs.org/en/stable/multidict.html#cimultidictproxy) for HTTP headers.""" + + def __init__(self, headers: Optional[Union[Mapping[str, str], "RawHeadersList", MultiMapping]] = None) -> None: + if not isinstance(headers, MultiMapping): + headers_: Union[Mapping[str, str], List[Tuple[str, str]]] = {} + if isinstance(headers, list): + headers_ = [(key.decode("latin-1"), value.decode("latin-1")) for key, value in headers] + elif headers: + headers_ = headers + super().__init__(CIMultiDict(headers_)) + else: + super().__init__(headers) + + @classmethod + def from_scope(cls, scope: "HeaderScope") -> "Headers": + """ + Create headers from a send-message. + Args: + scope: An ASGI Scope + + Returns: + Headers + + Raises: + ValueError: If the message does not have a `headers` key + """ + return cls(scope["headers"]) + + def to_header_list(self) -> "RawHeadersList": + """Raw header value. + + Returns: + A list of tuples contain the header and header-value as bytes + """ + return _encode_headers((key, value) for key in set(self) for value in self.getall(key)) + + +class MutableScopeHeaders(MutableMapping): + """A case-insensitive, multidict-like structure that can be used to mutate + headers within a [Scope][starlite.types.Scope].""" + + def __init__(self, scope: Optional["HeaderScope"] = None) -> None: + self.headers: "RawHeadersList" + if scope is not None: + self.headers = scope["headers"] + else: + self.headers = [] + + @classmethod + def from_message(cls, message: "Message") -> "MutableScopeHeaders": + """Construct a header from a [Message][starlite.types.Message]. + + Raises: + ValueError: If the message does not have a `headers` key + """ + if "headers" not in message: + raise ValueError(f"Invalid message type: {message['type']!r}") + return cls(cast("HeaderScope", message)) + + def add(self, name: str, value: str) -> None: + """Add a header to the scope keeping duplicates.""" + self.headers.append((name.lower().encode("latin-1"), value.encode("latin-1"))) + + def getall(self, name: str, default: Optional[List[str]] = None) -> List[str]: + """Get all values of a header. + + Args: + name: Header name + default: Default value to return if `name` is not found + + Returns: + A list of strings + + Raises: + `KeyError` if no header for `name` was found and `default` is not given + """ + name = name.lower() + values = [ + header_value.decode("latin-1") + for header_name, header_value in self.headers + if header_name.decode("latin-1").lower() == name + ] + if not values: + if default: + return default + raise KeyError + return values + + def extend_header_value(self, name: str, value: str) -> None: + """Extend a multivalued header (that is, a header that can take a comma + separated list). If the header previously did not exist, it will be + added. + + Args: + name: Header name + value: Header value to add + + Returns: + None + """ + existing = self.get(name) + if existing is not None: + value = ", ".join([*existing.split(","), value]) + self[name] = value + + def __getitem__(self, name: str) -> str: + """Get the first header matching `name`""" + name = name.lower() + for header in self.headers: + if header[0].decode("latin-1").lower() == name: + return header[1].decode("latin-1") + raise KeyError + + def _find_indices(self, name: str) -> List[int]: + name = name.lower() + return [i for i, (name_, _) in enumerate(self.headers) if name_.decode("latin-1").lower() == name] + + def __setitem__(self, name: str, value: str) -> None: + """Set a header in the scope, overwriting duplicates.""" + name_encoded = name.lower().encode("latin-1") + value_encoded = value.encode("latin-1") + indices = self._find_indices(name) + if not indices: + self.headers.append((name_encoded, value_encoded)) + else: + for i in indices[1:]: + del self.headers[i] + self.headers[indices[0]] = (name_encoded, value_encoded) + + def __delitem__(self, name: str) -> None: + """Delete all headers matching `name`""" + indices = self._find_indices(name) + for i in indices[::-1]: + del self.headers[i] + + def __len__(self) -> int: + return len(self.headers) + + def __iter__(self) -> Iterator[str]: + """iter over header names including duplicates.""" + return iter(h[0].decode("latin-1") for h in self.headers) + + class Header(BaseModel, ABC): """An abstract type for HTTP headers.""" diff --git a/starlite/middleware/compression/brotli.py b/starlite/middleware/compression/brotli.py index 7a6e1c52c8..8b1d109aa0 100644 --- a/starlite/middleware/compression/brotli.py +++ b/starlite/middleware/compression/brotli.py @@ -2,10 +2,11 @@ from enum import Enum from typing import TYPE_CHECKING, Optional, cast -from starlette.datastructures import Headers, MutableHeaders from starlette.middleware.gzip import GZipResponder from typing_extensions import Literal +from starlite.datastructures import MutableScopeHeaders +from starlite.datastructures.headers import Headers from starlite.enums import ScopeType from starlite.exceptions import MissingDependencyException from starlite.middleware.base import MiddlewareProtocol @@ -75,7 +76,7 @@ def __init__( async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None: if scope["type"] == ScopeType.HTTP: - headers = Headers(scope=scope) + headers = Headers.from_scope(scope) if CompressionEncoding.BROTLI in headers.get("Accept-Encoding", ""): await self.brotli_responder(scope, receive, send) return @@ -171,19 +172,19 @@ async def send_wrapper(message: "Message") -> None: elif not more_body: # Standard Brotli response. body = self.br_file.process(body) + self.br_file.finish() - headers = MutableHeaders(raw=initial_message["headers"]) + headers = MutableScopeHeaders(initial_message) headers["Content-Encoding"] = CompressionEncoding.BROTLI headers["Content-Length"] = str(len(body)) - headers.add_vary_header("Accept-Encoding") + headers.extend_header_value("vary", "Accept-Encoding") message["body"] = body await send(initial_message) await send(message) else: # Initial body in streaming Brotli response. - headers = MutableHeaders(raw=initial_message["headers"]) + headers = MutableScopeHeaders(initial_message) headers["Content-Encoding"] = CompressionEncoding.BROTLI - headers.add_vary_header("Accept-Encoding") + headers.extend_header_value("vary", "Accept-Encoding") del headers["Content-Length"] self.br_buffer.write(self.br_file.process(body) + self.br_file.flush()) message["body"] = self.br_buffer.getvalue() diff --git a/starlite/middleware/csrf.py b/starlite/middleware/csrf.py index 2365dfde4d..02b52400b0 100644 --- a/starlite/middleware/csrf.py +++ b/starlite/middleware/csrf.py @@ -2,8 +2,7 @@ import secrets from typing import TYPE_CHECKING, Any, Optional, Pattern -from starlette.datastructures import MutableHeaders - +from starlite.datastructures import MutableScopeHeaders from starlite.datastructures.cookie import Cookie from starlite.enums import RequestEncodingType, ScopeType from starlite.exceptions import PermissionDeniedException @@ -112,7 +111,7 @@ async def send_wrapper(message: "Message") -> None: return send_wrapper def _set_cookie_if_needed(self, message: "HTTPSendMessage", token: str) -> None: - headers = MutableHeaders(scope=message) + headers = MutableScopeHeaders.from_message(message) cookie = Cookie( key=self.config.cookie_name, value=token, @@ -122,7 +121,7 @@ def _set_cookie_if_needed(self, message: "HTTPSendMessage", token: str) -> None: samesite=self.config.cookie_samesite, domain=self.config.cookie_domain, ) - headers.append("set-cookie", cookie.to_header(header="")) + headers.add("set-cookie", cookie.to_header(header="")) def _decode_csrf_token(self, token: str) -> Optional[str]: """Decode a CSRF token and validate its HMAC.""" diff --git a/starlite/middleware/rate_limit.py b/starlite/middleware/rate_limit.py index 9f8d7e84e1..70eb3e4e4b 100644 --- a/starlite/middleware/rate_limit.py +++ b/starlite/middleware/rate_limit.py @@ -16,10 +16,10 @@ from orjson import dumps, loads from pydantic import BaseModel, validator -from starlette.datastructures import MutableHeaders from typing_extensions import Literal from starlite.connection import Request +from starlite.datastructures import MutableScopeHeaders from starlite.enums import ScopeType from starlite.exceptions import TooManyRequestsException from starlite.middleware.base import DefineMiddleware @@ -128,9 +128,9 @@ async def send_wrapper(message: "Message") -> None: """ if message["type"] == "http.response.start": message.setdefault("headers", []) - headers = MutableHeaders(scope=message) + headers = MutableScopeHeaders(message) for key, value in self.create_response_headers(cache_object=cache_object).items(): - headers.append(key, value) + headers.add(key, value) await send(message) return send_wrapper diff --git a/starlite/middleware/session/base.py b/starlite/middleware/session/base.py index 3bbc1925a3..951e431ca4 100644 --- a/starlite/middleware/session/base.py +++ b/starlite/middleware/session/base.py @@ -19,10 +19,10 @@ from orjson import OPT_SERIALIZE_NUMPY, dumps, loads from pydantic import BaseConfig, BaseModel, PrivateAttr, conint, conlist, constr -from starlette.datastructures import MutableHeaders from typing_extensions import Literal from starlite import ASGIConnection, Cookie, DefineMiddleware +from starlite.datastructures import MutableScopeHeaders from starlite.middleware.base import MiddlewareProtocol from starlite.middleware.util import should_bypass_middleware from starlite.types import Empty @@ -273,7 +273,7 @@ async def store_in_message( None """ scope = connection.scope - headers = MutableHeaders(scope=message) + headers = MutableScopeHeaders.from_message(message) session_id = connection.cookies.get(self.config.key) if session_id == "null": session_id = None @@ -287,7 +287,7 @@ async def store_in_message( if scope_session is Empty: await self.delete(session_id) - headers.append( + headers.add( "Set-Cookie", Cookie(value="null", key=self.config.key, expires=0, **cookie_params).to_header(header=""), ) diff --git a/starlite/middleware/session/cookie_backend.py b/starlite/middleware/session/cookie_backend.py index 342c3e5b29..07903812cd 100644 --- a/starlite/middleware/session/cookie_backend.py +++ b/starlite/middleware/session/cookie_backend.py @@ -8,8 +8,8 @@ from orjson import dumps, loads from pydantic import SecretBytes, validator -from starlette.datastructures import MutableHeaders +from starlite.datastructures import MutableScopeHeaders from starlite.datastructures.cookie import Cookie from starlite.exceptions import MissingDependencyException from starlite.types import Empty @@ -128,14 +128,14 @@ async def store_in_message( """ scope = connection.scope - headers = MutableHeaders(scope=message) + headers = MutableScopeHeaders.from_message(message) cookie_keys = self.get_cookie_keys(connection) if scope_session and scope_session is not Empty: data = self.dump_data(scope_session, scope=scope) cookie_params = self.config.dict(exclude_none=True, exclude={"secret", "key"}) for cookie in self._create_session_cookies(data, cookie_params): - headers.append("Set-Cookie", cookie.to_header(header="")) + headers.add("Set-Cookie", cookie.to_header(header="")) # Cookies with the same key overwrite the earlier cookie with that key. To expire earlier session # cookies, first check how many session cookies will not be overwritten in this upcoming response. # If leftover cookies are greater than or equal to 1, that means older session cookies have to be @@ -146,7 +146,7 @@ async def store_in_message( for cookie_key in cookies_to_clear: cookie_params = self.config.dict(exclude_none=True, exclude={"secret", "max_age", "key"}) - headers.append( + headers.add( "Set-Cookie", Cookie(value="null", key=cookie_key, expires=0, **cookie_params).to_header(header=""), ) diff --git a/starlite/testing/test_client/client.py b/starlite/testing/test_client/client.py index b1c730d78a..4fbc869d13 100644 --- a/starlite/testing/test_client/client.py +++ b/starlite/testing/test_client/client.py @@ -16,9 +16,9 @@ from urllib.parse import urljoin from anyio.from_thread import BlockingPortal, start_blocking_portal -from starlette.datastructures import MutableHeaders from starlite import ASGIConnection, HttpMethod, ImproperlyConfiguredException +from starlite.datastructures import MutableScopeHeaders from starlite.exceptions import MissingDependencyException from starlite.testing.test_client.life_span_handler import LifeSpanHandler from starlite.testing.test_client.transport import ( @@ -26,6 +26,7 @@ TestClientTransport, ) from starlite.types import AnyIOBackend, ASGIApp, HTTPResponseStartEvent +from starlite.utils import deprecated try: from httpx import USE_CLIENT_DEFAULT, Client, Cookies, Request, Response @@ -57,9 +58,9 @@ T = TypeVar("T", bound=ASGIApp) -def fake_http_send_message(headers: MutableHeaders) -> HTTPResponseStartEvent: +def fake_http_send_message(headers: MutableScopeHeaders) -> HTTPResponseStartEvent: headers.setdefault("content-type", "application/text") - return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.raw) + return HTTPResponseStartEvent(type="http.response.start", status=200, headers=headers.headers) def fake_asgi_connection(app: ASGIApp, cookies: Dict[str, str]) -> ASGIConnection[Any, Any, Any]: @@ -146,12 +147,8 @@ def __init__( ) @property + @deprecated("1.34.0", alternative="session_backend", pending=True, kind="property") def session(self) -> "CookieBackend": - warnings.warn( - "Accessing the session via this property is deprecated and will be removed in future version." - "To access the session backend directly, use the session_backend attribute", - PendingDeprecationWarning, - ) from starlite.middleware.session.cookie_backend import CookieBackend if not isinstance(self.session_backend, CookieBackend): @@ -604,6 +601,7 @@ def websocket_connect( else: raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + @deprecated("1.34.0", alternative="set_session_data", pending=True) def create_session_cookies(self, session_data: Dict[str, Any]) -> Dict[str, str]: """Creates raw session cookies that are loaded into the session by the Session Middleware. It creates cookies the same way as if they are @@ -643,15 +641,11 @@ def test_something(self, test_client: TestClient) -> None: test_client.get(url="/my_route") ``` """ - warnings.warn( - "This method is deprecated and will be removed in a future version. Use" - "TestClient.set_session_data instead", - PendingDeprecationWarning, - ) if self._session_backend is None: return {} return self._create_session_cookies(self.session, session_data) + @deprecated("1.34.0", alternative="get_session_data", pending=True) def get_session_from_cookies(self) -> Dict[str, Any]: """Raw session cookies are a serialized image of session which are created by session middleware and sent with the response. To assert @@ -673,11 +667,6 @@ def test_something(self, test_client: TestClient) -> None: assert "user" in session ``` """ - warnings.warn( - "This method is deprecated and will be removed in a future version. Use" - "TestClient.get_session_data instead", - PendingDeprecationWarning, - ) if self._session_backend is None: return {} return self.get_session_data() @@ -689,7 +678,7 @@ def _create_session_cookies(backend: "CookieBackend", data: Dict[str, Any]) -> D async def _set_session_data_async(self, data: Dict[str, Any]) -> None: # TODO: Expose this in the async client - mutable_headers = MutableHeaders({}) + mutable_headers = MutableScopeHeaders() await self.session_backend.store_in_message( scope_session=data, message=fake_http_send_message(mutable_headers), @@ -698,7 +687,7 @@ async def _set_session_data_async(self, data: Dict[str, Any]) -> None: cookies=dict(self.cookies), ), ) - response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.raw) + response = Response(200, request=Request("GET", self.base_url), headers=mutable_headers.headers) cookies = Cookies(CookieJar()) cookies.extract_cookies(response) diff --git a/starlite/types/asgi_types.py b/starlite/types/asgi_types.py index cb91a34a7f..886aa01051 100644 --- a/starlite/types/asgi_types.py +++ b/starlite/types/asgi_types.py @@ -63,13 +63,16 @@ class ASGIVersion(TypedDict): version: Literal["3.0"] -class BaseScope(TypedDict): +class HeaderScope(TypedDict): + headers: "RawHeadersList" + + +class BaseScope(HeaderScope): app: "Starlite" asgi: ASGIVersion auth: Any client: Optional[Tuple[str, int]] extensions: Optional[Dict[str, Dict[object, object]]] - headers: List[Tuple[bytes, bytes]] http_version: str path: str path_params: Dict[str, str] @@ -106,10 +109,9 @@ class HTTPRequestEvent(TypedDict): more_body: bool -class HTTPResponseStartEvent(TypedDict): +class HTTPResponseStartEvent(HeaderScope): type: Literal["http.response.start"] status: int - headers: List[Tuple[bytes, bytes]] class HTTPResponseBodyEvent(TypedDict): @@ -118,10 +120,9 @@ class HTTPResponseBodyEvent(TypedDict): more_body: bool -class HTTPServerPushEvent(TypedDict): +class HTTPServerPushEvent(HeaderScope): type: Literal["http.response.push"] path: str - headers: List[Tuple[bytes, bytes]] class HTTPDisconnectEvent(TypedDict): @@ -132,10 +133,9 @@ class WebSocketConnectEvent(TypedDict): type: Literal["websocket.connect"] -class WebSocketAcceptEvent(TypedDict): +class WebSocketAcceptEvent(HeaderScope): type: Literal["websocket.accept"] subprotocol: Optional[str] - headers: List[Tuple[bytes, bytes]] class WebSocketReceiveEvent(TypedDict): @@ -150,10 +150,9 @@ class WebSocketSendEvent(TypedDict): text: Optional[str] -class WebSocketResponseStartEvent(TypedDict): +class WebSocketResponseStartEvent(HeaderScope): type: Literal["websocket.http.response.start"] status: int - headers: List[Tuple[bytes, bytes]] class WebSocketResponseBodyEvent(TypedDict): @@ -239,3 +238,4 @@ class LifeSpanShutdownFailedEvent(TypedDict): Receive = Callable[..., Awaitable[Union[HTTPReceiveMessage, WebSocketReceiveMessage]]] Send = Callable[[Message], Awaitable[None]] ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]] +RawHeadersList = List[Tuple[bytes, bytes]] diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index 6d254b30a9..83391dcc3b 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -1,3 +1,5 @@ +from starlite.utils.deprecation import deprecated, warn_deprecation + from .csrf import generate_csrf_hash, generate_csrf_token from .dependency import is_dependency_field, should_skip_dependency_validation from .exception import ( @@ -44,6 +46,8 @@ "create_exception_response", "create_parsed_model_field", "default_serializer", + "deprecated", + "warn_deprecation", "find_index", "generate_csrf_hash", "generate_csrf_token", diff --git a/starlite/utils/deprecation.py b/starlite/utils/deprecation.py new file mode 100644 index 0000000000..c3fdc30809 --- /dev/null +++ b/starlite/utils/deprecation.py @@ -0,0 +1,96 @@ +import inspect +from functools import wraps +from typing import Callable, Optional, TypeVar +from warnings import warn + +from typing_extensions import Literal, ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") +DeprecatedKind = Literal["function", "method", "attribute", "property", "class", "parameter"] + + +def warn_deprecation( + version: str, + deprecated_name: str, + kind: DeprecatedKind, + *, + removal_in: Optional[str] = None, + alternative: Optional[str] = None, + info: Optional[str] = None, + pending: bool = False, +) -> None: + """Warn about a call to a (soon to be) deprecated function. + + Args: + version: Starlite version where the deprecation will occur + deprecated_name: Name of the deprecated function + removal_in: Starlite version where the deprecated function will be removed + alternative: Name of a function that should be used instead + info: Additional information + pending: Use `PendingDeprecationWarning` instead of `DeprecationWarning` + kind: Type of the deprecated thing + """ + parts = [] + access_type = "Call to" if kind in {"function", "method"} else "Use of" + removal_in = removal_in or "the next major version" + if pending: + parts.append(f"{access_type} {kind} awaiting deprecation {deprecated_name!r}") + else: + parts.append(f"{access_type} deprecated {kind} {deprecated_name!r}") + parts.append(f"Deprecated in starlite {version}") + if removal_in: + parts.append(f"This {kind} will be removed in {removal_in}") + if alternative: + parts.append(f"Use {alternative!r} instead") + if info: + parts.append(info) + + text = ". ".join(parts) + warning_class = PendingDeprecationWarning if pending else DeprecationWarning + + warn(text, warning_class) + + +def deprecated( + version: str, + *, + removal_in: Optional[str] = None, + alternative: Optional[str] = None, + info: Optional[str] = None, + pending: bool = False, + kind: Optional[Literal["function", "method", "property"]] = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + """Create a decorator wrapping a function, method or property with a + warning call about a (pending) deprecation. + + Args: + version: Starlite version where the deprecation will occur + removal_in: Starlite version where the deprecated function will be removed + alternative: Name of a function that should be used instead + info: Additional information + pending: Use `PendingDeprecationWarning` instead of `DeprecationWarning` + kind: Type of the deprecated callable. If `None`, will use `inspect` to figure + out if it's a function or method + + Returns: + A decorator wrapping the function call with a warning + """ + + def decorator(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + warn_deprecation( + version=version, + deprecated_name=func.__name__, + info=info, + alternative=alternative, + pending=pending, + removal_in=removal_in, + kind=kind or ("method" if inspect.ismethod(func) else "function"), + ) + return func(*args, **kwargs) + + return wrapped + + return decorator diff --git a/tests/app/test_before_send.py b/tests/app/test_before_send.py index 0f83d53983..7adb45ba4f 100644 --- a/tests/app/test_before_send.py +++ b/tests/app/test_before_send.py @@ -1,8 +1,7 @@ from typing import TYPE_CHECKING, Dict -from starlette.datastructures import MutableHeaders - from starlite import get +from starlite.datastructures import MutableScopeHeaders from starlite.status_codes import HTTP_200_OK from starlite.testing import create_test_client @@ -18,8 +17,8 @@ def handler() -> Dict[str, str]: async def before_send_hook_handler(message: "Message", state: "State") -> None: if message["type"] == "http.response.start": - headers = MutableHeaders(scope=message) - headers.append("My Header", state.message) + headers = MutableScopeHeaders(message) + headers.add("My Header", state.message) def on_startup(state: "State") -> None: state.message = "value injected during send" diff --git a/tests/connection/websocket/test_websocket.py b/tests/connection/websocket/test_websocket.py index 62d5f727ec..d36aabce9c 100644 --- a/tests/connection/websocket/test_websocket.py +++ b/tests/connection/websocket/test_websocket.py @@ -8,9 +8,9 @@ import anyio import pytest -from starlette.datastructures import Headers from starlite.connection import WebSocket +from starlite.datastructures.headers import Headers from starlite.exceptions import WebSocketDisconnect, WebSocketException from starlite.handlers.websocket import websocket from starlite.status_codes import WS_1001_GOING_AWAY diff --git a/tests/datastructures/test_headers.py b/tests/datastructures/test_headers.py index 70ac6866e5..d71062fa70 100644 --- a/tests/datastructures/test_headers.py +++ b/tests/datastructures/test_headers.py @@ -1,8 +1,160 @@ +from typing import TYPE_CHECKING + import pytest from pydantic import ValidationError +from pytest import FixtureRequest -from starlite.datastructures import CacheControlHeader, ETag +from starlite.datastructures import ( + CacheControlHeader, + ETag, + Headers, + MutableScopeHeaders, +) from starlite.exceptions import ImproperlyConfiguredException +from starlite.types.asgi_types import HTTPResponseBodyEvent, HTTPResponseStartEvent + +if TYPE_CHECKING: + from starlite.types.asgi_types import RawHeadersList + + +@pytest.fixture +def raw_headers() -> "RawHeadersList": + return [(b"foo", b"bar")] + + +@pytest.fixture +def mutable_headers(raw_headers: "RawHeadersList") -> MutableScopeHeaders: + return MutableScopeHeaders({"headers": raw_headers}) + + +@pytest.fixture(params=[True, False]) +def existing_headers_key(request: FixtureRequest) -> str: + return "Foo" if request.param else "foo" + + +def test_headers_from_mapping() -> None: + headers = Headers({"foo": "bar", "baz": "zab"}) + assert headers["foo"] == "bar" + assert headers["baz"] == "zab" + + +def test_headers_from_raw_list() -> None: + headers = Headers([(b"foo", b"bar"), (b"foo", b"baz")]) + assert headers.getall("foo") == ["bar", "baz"] + + +def test_headers_from_scope(raw_headers: "RawHeadersList") -> None: + headers = Headers.from_scope( + HTTPResponseStartEvent(type="http.response.start", status=200, headers=[(b"foo", b"bar"), (b"foo", b"baz")]) + ) + assert headers.getall("foo") == ["bar", "baz"] + + +def test_headers_to_header_list() -> None: + raw = [(b"foo", b"bar"), (b"foo", b"baz")] + headers = Headers(raw) + assert headers.to_header_list() == raw + + +def test_mutable_scope_headers_from_message(raw_headers: "RawHeadersList") -> None: + headers = MutableScopeHeaders.from_message( + HTTPResponseStartEvent(type="http.response.start", status=200, headers=raw_headers) + ) + assert headers.headers == raw_headers + + +def test_mutable_scope_headers_from_message_invalid_type() -> None: + with pytest.raises(ValueError): + MutableScopeHeaders.from_message(HTTPResponseBodyEvent(type="http.response.body", body=b"", more_body=False)) + + +def test_mutable_scope_headers_add( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders, existing_headers_key: str +) -> None: + mutable_headers.add(existing_headers_key, "baz") + assert raw_headers == [(b"foo", b"bar"), (b"foo", b"baz")] + + +def test_mutable_scope_headers_getall_singular_value( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders, existing_headers_key: str +) -> None: + assert mutable_headers.getall(existing_headers_key) == ["bar"] + + +def test_mutable_scope_headers_getall_multi_value( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders, existing_headers_key: str +) -> None: + mutable_headers.add(existing_headers_key, "baz") + assert mutable_headers.getall("foo") == ["bar", "baz"] + + +def test_mutable_scope_headers_getall_not_found_no_default(mutable_headers: MutableScopeHeaders) -> None: + with pytest.raises(KeyError): + mutable_headers.getall("bar") + + +def test_mutable_scope_headers_getall_not_found_default(mutable_headers: MutableScopeHeaders) -> None: + assert mutable_headers.getall("bar", ["default"]) == ["default"] + + +def test_mutable_scope_headers_extend_header_value( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders +) -> None: + mutable_headers.extend_header_value("foo", "baz") + assert raw_headers == [(b"foo", b"bar, baz")] + + +def test_mutable_scope_headers_extend_header_value_new_header( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders +) -> None: + mutable_headers.extend_header_value("bar", "baz") + assert raw_headers == [(b"foo", b"bar"), (b"bar", b"baz")] + + +def test_mutable_scope_headers_getitem(mutable_headers: MutableScopeHeaders, existing_headers_key: str) -> None: + assert mutable_headers[existing_headers_key] == "bar" + + +def test_mutable_scope_headers_getitem_not_found(mutable_headers: MutableScopeHeaders) -> None: + with pytest.raises(KeyError): + mutable_headers["bar"] + + +def test_mutable_scope_headers_setitem_existing_key( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders, existing_headers_key: str +) -> None: + mutable_headers[existing_headers_key] = "baz" + assert raw_headers == [(b"foo", b"baz")] + + +def test_mutable_scope_headers_setitem_new_key( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders +) -> None: + mutable_headers["bar"] = "baz" + assert raw_headers == [(b"foo", b"bar"), (b"bar", b"baz")] + + +def test_mutable_scope_headers_setitem_delitem( + raw_headers: "RawHeadersList", mutable_headers: MutableScopeHeaders, existing_headers_key: str +) -> None: + mutable_headers.add("foo", "baz") + mutable_headers["bar"] = "baz" + del mutable_headers[existing_headers_key] + assert raw_headers == [(b"bar", b"baz")] + + +def test_mutable_scope_header_len(mutable_headers: MutableScopeHeaders) -> None: + assert len(mutable_headers) == 1 + mutable_headers.add("foo", "bar") + assert len(mutable_headers) == 2 + mutable_headers["bar"] = "baz" + assert len(mutable_headers) == 3 + + +def test_mutable_scope_header_iter(mutable_headers: MutableScopeHeaders) -> None: + mutable_headers.add("foo", "baz") + mutable_headers["bar"] = "zab" + assert list(mutable_headers) == ["foo", "foo", "bar"] def test_cache_control_to_header() -> None: