Skip to content

Commit

Permalink
Add OptionalQuery extractor (#2310)
Browse files Browse the repository at this point in the history
Co-authored-by: David Pedersen <[email protected]>
  • Loading branch information
mikhailantoshkin and davidpdrsn authored Nov 18, 2023
1 parent 6e984b7 commit 39cc596
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 2 deletions.
2 changes: 2 additions & 0 deletions axum-extra/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ and this project adheres to [Semantic Versioning].

# Unreleased

- **added:** `OptionalQuery` extractor ([#2310])
- **added:** `TypedHeader` which used to be in `axum` ([#1850])
- **added:** `Clone` implementation for `ErasedJson` ([#2142])
- **breaking:** Update to prost 0.12. Used for the `Protobuf` extractor
- **breaking:** Make `tokio` an optional dependency

[#1850]: https://github.com/tokio-rs/axum/pull/1850
[#2142]: https://github.com/tokio-rs/axum/pull/2142
[#2310]: https://github.com/tokio-rs/axum/pull/2310

# 0.7.4 (18. April, 2023)

Expand Down
2 changes: 1 addition & 1 deletion axum-extra/src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use self::cookie::SignedCookieJar;
pub use self::form::{Form, FormRejection};

#[cfg(feature = "query")]
pub use self::query::{Query, QueryRejection};
pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejection};

#[cfg(feature = "multipart")]
pub use self::multipart::Multipart;
Expand Down
203 changes: 202 additions & 1 deletion axum-extra/src/extract/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,124 @@ impl std::error::Error for QueryRejection {
}
}

/// Extractor that deserializes query strings into `None` if no query parameters are present.
/// Otherwise behaviour is identical to [`Query`]
///
/// `T` is expected to implement [`serde::Deserialize`].
///
/// # Example
///
/// ```rust,no_run
/// use axum::{routing::get, Router};
/// use axum_extra::extract::OptionalQuery;
/// use serde::Deserialize;
///
/// #[derive(Deserialize)]
/// struct Pagination {
/// page: usize,
/// per_page: usize,
/// }
///
/// // This will parse query strings like `?page=2&per_page=30` into `Some(Pagination)` and
/// // empty query string into `None`
/// async fn list_things(OptionalQuery(pagination): OptionalQuery<Pagination>) {
/// match pagination {
/// Some(Pagination{ page, per_page }) => { /* return specified page */ },
/// None => { /* return fist page */ }
/// }
/// // ...
/// }
///
/// let app = Router::new().route("/list_things", get(list_things));
/// # let _: Router = app;
/// ```
///
/// If the query string cannot be parsed it will reject the request with a `400
/// Bad Request` response.
///
/// For handling values being empty vs missing see the [query-params-with-empty-strings][example]
/// example.
///
/// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs
#[cfg_attr(docsrs, doc(cfg(feature = "query")))]
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalQuery<T>(pub Option<T>);

#[async_trait]
impl<T, S> FromRequestParts<S> for OptionalQuery<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = OptionalQueryRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if let Some(query) = parts.uri.query() {
let value = serde_html_form::from_str(query).map_err(|err| {
OptionalQueryRejection::FailedToDeserializeQueryString(Error::new(err))
})?;
Ok(OptionalQuery(Some(value)))
} else {
Ok(OptionalQuery(None))
}
}
}

impl<T> std::ops::Deref for OptionalQuery<T> {
type Target = Option<T>;

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<T> std::ops::DerefMut for OptionalQuery<T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

/// Rejection used for [`OptionalQuery`].
///
/// Contains one variant for each way the [`OptionalQuery`] extractor can fail.
#[derive(Debug)]
#[non_exhaustive]
#[cfg(feature = "query")]
pub enum OptionalQueryRejection {
#[allow(missing_docs)]
FailedToDeserializeQueryString(Error),
}

impl IntoResponse for OptionalQueryRejection {
fn into_response(self) -> Response {
match self {
Self::FailedToDeserializeQueryString(inner) => (
StatusCode::BAD_REQUEST,
format!("Failed to deserialize query string: {inner}"),
)
.into_response(),
}
}
}

impl fmt::Display for OptionalQueryRejection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::FailedToDeserializeQueryString(inner) => inner.fmt(f),
}
}
}

impl std::error::Error for OptionalQueryRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::FailedToDeserializeQueryString(inner) => Some(inner),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -121,7 +239,7 @@ mod tests {
use serde::Deserialize;

#[tokio::test]
async fn supports_multiple_values() {
async fn query_supports_multiple_values() {
#[derive(Deserialize)]
struct Data {
#[serde(rename = "value")]
Expand All @@ -145,4 +263,87 @@ mod tests {
assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}

#[tokio::test]
async fn optional_query_supports_multiple_values() {
#[derive(Deserialize)]
struct Data {
#[serde(rename = "value")]
values: Vec<String>,
}

let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
data.map(|Data { values }| values.join(","))
.unwrap_or("None".to_owned())
}),
);

let client = TestClient::new(app);

let res = client
.post("/?value=one&value=two")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body("")
.send()
.await;

assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "one,two");
}

#[tokio::test]
async fn optional_query_deserializes_no_parameters_into_none() {
#[derive(Deserialize)]
struct Data {
value: String,
}

let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
match data {
None => "None".into(),
Some(data) => data.value,
}
}),
);

let client = TestClient::new(app);

let res = client.post("/").body("").send().await;

assert_eq!(res.status(), StatusCode::OK);
assert_eq!(res.text().await, "None");
}

#[tokio::test]
async fn optional_query_preserves_parsing_errors() {
#[derive(Deserialize)]
struct Data {
value: String,
}

let app = Router::new().route(
"/",
post(|OptionalQuery(data): OptionalQuery<Data>| async move {
match data {
None => "None".into(),
Some(data) => data.value,
}
}),
);

let client = TestClient::new(app);

let res = client
.post("/?other=something")
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
.body("")
.send()
.await;

assert_eq!(res.status(), StatusCode::BAD_REQUEST);
}
}

0 comments on commit 39cc596

Please sign in to comment.