Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add QueryMultiDict #759

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions docs/reference/datastructures/7-form-multi-dict.md

This file was deleted.

13 changes: 13 additions & 0 deletions docs/reference/datastructures/7-multi-dicts.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Multi-Dicts

::: starlite.datastructures.FormMultiDict
options:
members:
- close
- multi_items

::: starlite.datastructures.QueryMultiDict
options:
members:
- from_query_string
- dict
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ nav:
- reference/datastructures/4-background.md
- reference/datastructures/5-response-containers.md
- reference/datastructures/6-upload-file.md
- reference/datastructures/7-form-multi-dict.md
- reference/datastructures/7-multi-dicts.md
- Exceptions:
- reference/exceptions/0-base-exceptions.md
- reference/exceptions/1-http-exceptions.md
Expand Down
2 changes: 1 addition & 1 deletion starlite/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def default_cache_key_builder(request: "Request[Any, Any]") -> str:
Returns:
str: combination of url path and query parameters
"""
query_params: List[Tuple[str, Any]] = list(request.query_params.items())
query_params: List[Tuple[str, Any]] = list(request.query_params.dict().items())
query_params.sort(key=lambda x: x[0])
return request.url.path + urlencode(query_params, doseq=True)

Expand Down
21 changes: 6 additions & 15 deletions starlite/connection/base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar, Union, cast

from starlette.datastructures import URL, Address, URLPath

from starlite.datastructures.headers import Headers
from starlite.datastructures.multi_dicts import QueryMultiDict
from starlite.datastructures.state import State
from starlite.exceptions import ImproperlyConfiguredException
from starlite.parsers import parse_cookie_string, parse_query_params
from starlite.parsers import parse_cookie_string
from starlite.types.empty import Empty

if TYPE_CHECKING:
Expand Down Expand Up @@ -156,14 +147,14 @@ def headers(self) -> Headers:
return cast("Headers", self._headers)

@property
def query_params(self) -> Dict[str, List[str]]:
def query_params(self) -> QueryMultiDict:
"""
Returns:
A normalized dict of query parameters. Multiple values for the same key are returned as a list.
"""
if self._parsed_query is Empty:
self._parsed_query = self.scope["_parsed_query"] = parse_query_params(self.scope.get("query_string", b"")) # type: ignore[typeddict-item]
return cast("Dict[str, List[str]]", self._parsed_query)
self._parsed_query = self.scope["_parsed_query"] = QueryMultiDict.from_query_string(self.scope.get("query_string", b"").decode("utf-8")) # type: ignore[typeddict-item]
return cast("QueryMultiDict", self._parsed_query)

@property
def path_params(self) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion starlite/connection/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
empty_receive,
empty_send,
)
from starlite.datastructures.form_multi_dict import FormMultiDict
from starlite.datastructures.multi_dicts import FormMultiDict
from starlite.datastructures.upload_file import UploadFile
from starlite.enums import RequestEncodingType
from starlite.exceptions import InternalServerException
Expand Down
2 changes: 1 addition & 1 deletion starlite/datastructures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from starlite.datastructures.background_tasks import BackgroundTask, BackgroundTasks
from starlite.datastructures.cookie import Cookie
from starlite.datastructures.form_multi_dict import FormMultiDict
from starlite.datastructures.headers import (
CacheControlHeader,
ETag,
Headers,
MutableScopeHeaders,
)
from starlite.datastructures.multi_dicts import FormMultiDict
from starlite.datastructures.provide import Provide
from starlite.datastructures.response_containers import (
File,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qsl

from multidict import MultiDict, MultiDictProxy

Expand Down Expand Up @@ -41,3 +42,41 @@ async def close(self) -> None:
for _, value in self.multi_items():
if isinstance(value, UploadFile):
await value.close()


class QueryMultiDict(MultiDict[Any]):
Goldziher marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self, args: Optional[Union["QueryMultiDict", Mapping[str, Any], Iterable[Tuple[str, Any]]]] = None
) -> None:
super().__init__(MultiDict(args or {}))

@classmethod
def from_query_string(cls, query_string: str) -> "QueryMultiDict":
"""Creates a QueryMultiDict from a query string.

Args:
query_string: A query string.

Returns:
A QueryMultiDict instance
"""
_bools = {"true": True, "false": False, "True": True, "False": False}
return cls(
(k, v) if v not in _bools else (k, _bools[v]) for k, v in parse_qsl(query_string, keep_blank_values=True)
)

def dict(self) -> Dict[str, List[Any]]:
Goldziher marked this conversation as resolved.
Show resolved Hide resolved
"""

Returns:
A dict of lists
"""
out: Dict[str, List[Any]] = {}

for k, v in self.items():
if k in out:
out[k].append(v)
else:
out[k] = [v]

return out
4 changes: 3 additions & 1 deletion starlite/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,9 @@ def to_kwargs(self, connection: Union["WebSocket", "Request"]) -> Dict[str, Any]
Returns:
A string keyed dictionary of kwargs expected by the handler function and its dependencies.
"""
connection_query_params = {k: self._sequence_or_scalar_param(k, v) for k, v in connection.query_params.items()}
connection_query_params = {
k: self._sequence_or_scalar_param(k, v) for k, v in connection.query_params.dict().items()
}

path_params = self._collect_params(
params=connection.path_params, expected=self.expected_path_params, url=connection.url
Expand Down
41 changes: 3 additions & 38 deletions starlite/parsers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from contextlib import suppress
from functools import reduce
from http.cookies import _unquote as unquote_cookie
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from urllib.parse import parse_qsl, unquote
from typing import TYPE_CHECKING, Any, Dict
from urllib.parse import unquote

from orjson import JSONDecodeError, loads
from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON
Expand All @@ -15,46 +14,12 @@

from pydantic.fields import ModelField

from starlite.datastructures.form_multi_dict import FormMultiDict
from starlite.datastructures.multi_dicts import FormMultiDict

_true_values = {"True", "true"}
_false_values = {"False", "false"}


def _query_param_reducer(acc: Dict[str, List[Any]], cur: Tuple[str, str]) -> Dict[str, List[str]]:
"""
Reducer function - acc is a dictionary, cur is a tuple of key + value

We use reduce because python implements reduce in C, which makes it faster than a regular for loop in most cases.
"""
key, value = cur

if value in _true_values:
value = True # type: ignore
elif value in _false_values:
value = False # type: ignore

if key in acc:
acc[key].append(value)
else:
acc[key] = [value]
return acc


def parse_query_params(query_string: bytes) -> Dict[str, List[str]]:
"""Parses and normalize a given connection's query parameters into a
regular dictionary.

Args:
query_string: A byte-string containing a query

Returns:
A string keyed dictionary of values.
"""

return reduce(_query_param_reducer, parse_qsl(query_string.decode("utf-8"), keep_blank_values=True), {})


def parse_form_data(media_type: "RequestEncodingType", form_data: "FormMultiDict", field: "ModelField") -> Any:
"""Transforms the multidict into a regular dict, try to load json on all
non-file values.
Expand Down
2 changes: 1 addition & 1 deletion starlite/utils/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def extract_query(self, connection: "ASGIConnection[Any, Any, Any]") -> Any:
Returns:
Either a dictionary with the connection's parsed query string or the raw query byte-string.
"""
return connection.query_params if self.parse_query else connection.scope.get("query_string", b"")
return connection.query_params.dict() if self.parse_query else connection.scope.get("query_string", b"")

@staticmethod
def extract_path_params(connection: "ASGIConnection[Any, Any, Any]") -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion tests/connection/request/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:

client = TestClient(app)
response = client.get("/?a=123&b=456")
assert response.json() == {"params": {"a": ["123"], "b": ["456"]}}
assert response.json() == {"params": {"a": "123", "b": "456"}}


def test_request_headers() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/connection/websocket/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def app(scope: "Scope", receive: "Receive", send: "Send") -> None:

with TestClient(app).websocket_connect("/?a=abc&b=456") as websocket:
data = websocket.receive_json()
assert data == {"params": {"a": ["abc"], "b": ["456"]}}
assert data == {"params": {"a": "abc", "b": "456"}}


def test_websocket_headers() -> None:
Expand Down
28 changes: 28 additions & 0 deletions tests/datastructures/test_multi_dicts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from starlite.datastructures.multi_dicts import QueryMultiDict
from starlite.testing import RequestFactory


def test_query_multi_dict_parse_query_params() -> None:
query = {
"value": "10",
"veggies": ["tomato", "potato", "aubergine"],
"calories": "122.53",
"healthy": True,
"polluting": False,
}
request = RequestFactory().get(query_params=query) # type: ignore
result = QueryMultiDict.from_query_string(query_string=request.scope.get("query_string", b"").decode("utf-8"))

assert result.getall("value") == ["10"]
assert result.getall("veggies") == ["tomato", "potato", "aubergine"]
assert result.getall("calories") == ["122.53"]
assert result.getall("healthy") == [True]
assert result.getall("polluting") == [False]

assert result.dict() == {
"value": ["10"],
"veggies": ["tomato", "potato", "aubergine"],
"calories": ["122.53"],
"healthy": [True],
"polluting": [False],
}
22 changes: 1 addition & 21 deletions tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,7 @@

from starlite import Cookie, RequestEncodingType
from starlite.datastructures import FormMultiDict
from starlite.parsers import parse_cookie_string, parse_form_data, parse_query_params
from starlite.testing import RequestFactory


def test_parse_query_params() -> None:
query = {
"value": "10",
"veggies": ["tomato", "potato", "aubergine"],
"calories": "122.53",
"healthy": True,
"polluting": False,
}
request = RequestFactory().get(query_params=query) # type: ignore[arg-type]
result = parse_query_params(query_string=request.scope.get("query_string", b""))
assert result == {
"value": ["10"],
"veggies": ["tomato", "potato", "aubergine"],
"calories": ["122.53"],
"healthy": [True],
"polluting": [False],
}
from starlite.parsers import parse_cookie_string, parse_form_data


def test_parse_form_data() -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/testing/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
post,
)
from starlite.datastructures import Cookie
from starlite.datastructures.multi_dicts import QueryMultiDict
from starlite.enums import ParamType
from starlite.middleware.session import SessionCookieConfig
from starlite.testing import RequestFactory, TestClient
Expand Down Expand Up @@ -178,7 +179,7 @@ def handler() -> None:
assert request.base_url == f"{scheme}://{server}:{port}{root_path}/"
assert request.url == f"{scheme}://{server}:{port}{root_path}{path}"
assert request.method == HttpMethod.GET
assert request.query_params == {}
assert request.query_params == QueryMultiDict()
assert request.user == user
assert request.auth == auth
assert request.session == session
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_connection_data_extractor() -> None:
assert extracted_data.get("path") == request.scope["path"]
assert extracted_data.get("path") == request.scope["path"]
assert extracted_data.get("path_params") == request.scope["path_params"]
assert extracted_data.get("query") == request.query_params
assert extracted_data.get("query") == request.query_params.dict()
assert extracted_data.get("scheme") == request.scope["scheme"]


Expand All @@ -41,7 +41,7 @@ def test_parse_query() -> None:
)
parsed_extracted_data = ConnectionDataExtractor(parse_query=True)(request)
unparsed_extracted_data = ConnectionDataExtractor()(request)
assert parsed_extracted_data.get("query") == request.query_params
assert parsed_extracted_data.get("query") == request.query_params.dict()
assert unparsed_extracted_data.get("query") == request.scope["query_string"]
# Close to avoid warnings about un-awaited coroutines.
parsed_extracted_data.get("body").close() # type: ignore
Expand Down