From 207f3467c07c8c2de706109b2e880773d93e1dfe Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Mon, 20 Nov 2023 15:29:40 +0100 Subject: [PATCH] Add [`Body::into_data_stream`] --- axum-core/src/body.rs | 45 +++++++++++++++++++++++++++-- axum-extra/src/extract/multipart.rs | 2 +- axum-extra/src/json_lines.rs | 4 +-- axum/src/extract/multipart.rs | 2 +- examples/stream-to-file/src/main.rs | 2 +- 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/axum-core/src/body.rs b/axum-core/src/body.rs index 00a1e98e54..adc79e87e3 100644 --- a/axum-core/src/body.rs +++ b/axum-core/src/body.rs @@ -66,6 +66,16 @@ impl Body { stream: SyncWrapper::new(stream), }) } + + /// Convert the body into a [`Stream`] of data frames. + /// + /// Non-data frames (such as trailers) will be discarded. Use [`http_body_util::BodyStream`] if + /// you need a [`Stream`] of all frame types. + /// + /// [`http_body_util::BodyStream`]: https://docs.rs/http-body-util/latest/http_body_util/struct.BodyStream.html + pub fn into_data_stream(self) -> BodyDataStream { + BodyDataStream { inner: self } + } } impl Default for Body { @@ -117,13 +127,21 @@ impl http_body::Body for Body { } } -impl Stream for Body { +/// A stream of data frames. +/// +/// Created with [`Body::into_data_stream`]. +#[derive(Debug)] +pub struct BodyDataStream { + inner: Body, +} + +impl Stream for BodyDataStream { type Item = Result; #[inline] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { - match futures_util::ready!(self.as_mut().poll_frame(cx)?) { + match futures_util::ready!(Pin::new(&mut self.inner).poll_frame(cx)?) { Some(frame) => match frame.into_data() { Ok(data) => return Poll::Ready(Some(Ok(data))), Err(_frame) => {} @@ -134,6 +152,29 @@ impl Stream for Body { } } +impl http_body::Body for BodyDataStream { + type Data = Bytes; + type Error = Error; + + #[inline] + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Pin::new(&mut self.inner).poll_frame(cx) + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +} + pin_project! { struct StreamBody { #[pin] diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 2c86f6f085..8c78a77722 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -100,7 +100,7 @@ where async fn from_request(req: Request, _state: &S) -> Result { let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; let stream = req.with_limited_body().into_body(); - let multipart = multer::Multipart::new(stream, boundary); + let multipart = multer::Multipart::new(stream.into_data_stream(), boundary); Ok(Self { inner: multipart }) } } diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index f7bc506d1b..d72c23b6c6 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -111,8 +111,8 @@ where // `Stream::lines` isn't a thing so we have to convert it into an `AsyncRead` // so we can call `AsyncRead::lines` and then convert it back to a `Stream` let body = req.into_body(); - - let stream = TryStreamExt::map_err(body, |err| io::Error::new(io::ErrorKind::Other, err)); + let stream = body.into_data_stream(); + let stream = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err)); let read = StreamReader::new(stream); let lines_stream = LinesStream::new(read.lines()); diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 249c0aae2d..227e983a4b 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -72,7 +72,7 @@ where async fn from_request(req: Request, _state: &S) -> Result { let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; let stream = req.with_limited_body().into_body(); - let multipart = multer::Multipart::new(stream, boundary); + let multipart = multer::Multipart::new(stream.into_data_stream(), boundary); Ok(Self { inner: multipart }) } } diff --git a/examples/stream-to-file/src/main.rs b/examples/stream-to-file/src/main.rs index f3263c4deb..a595d0d834 100644 --- a/examples/stream-to-file/src/main.rs +++ b/examples/stream-to-file/src/main.rs @@ -53,7 +53,7 @@ async fn save_request_body( Path(file_name): Path, request: Request, ) -> Result<(), (StatusCode, String)> { - stream_to_file(&file_name, request.into_body()).await + stream_to_file(&file_name, request.into_body().into_data_stream()).await } // Handler that returns HTML for a multipart form.