Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare serve for potentially supporting graceful shutdown #2357

Merged
merged 2 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion axum/benches/benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use axum::{
};
use serde::{Deserialize, Serialize};
use std::{
future::IntoFuture,
io::BufRead,
process::{Command, Stdio},
};
Expand Down Expand Up @@ -161,7 +162,8 @@ impl BenchmarkBuilder {
let addr = listener.local_addr().unwrap();

std::thread::spawn(move || {
rt.block_on(axum::serve(listener, app)).unwrap();
rt.block_on(axum::serve(listener, app).into_future())
.unwrap();
});

let mut cmd = Command::new("rewrk");
Expand Down
156 changes: 119 additions & 37 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

use std::{
convert::Infallible,
future::Future,
future::{Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -86,48 +87,129 @@ use tower_service::Service;
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub async fn serve<M, S>(tcp_listener: TcpListener, mut make_service: M) -> io::Result<()>
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
loop {
let (tcp_stream, remote_addr) = tcp_listener.accept().await?;
let tcp_stream = TokioIo::new(tcp_stream);

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});

let hyper_service = TowerToHyperService {
service: tower_service,
};

tokio::task::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
.await
{
Ok(()) => {}
Err(_err) => {
// This error only appears when the client doesn't send a request and
// terminate the connection.
//
// If client sends one request then terminate connection whenever, it doesn't
// appear.
}
Serve {
tcp_listener,
make_service,
_marker: PhantomData,
}
}

/// Future returned by [`serve`].
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct Serve<M, S> {
tcp_listener: TcpListener,
make_service: M,
_marker: PhantomData<S>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> std::fmt::Debug for Serve<M, S>
where
M: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
make_service,
_marker: _,
} = self;

f.debug_struct("Serve")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.finish()
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
private::ServeFuture(Box::pin(async move {
let Self {
tcp_listener,
mut make_service,
_marker: _,
} = self;

loop {
let (tcp_stream, remote_addr) = tcp_listener.accept().await?;
let tcp_stream = TokioIo::new(tcp_stream);

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});

let hyper_service = TowerToHyperService {
service: tower_service,
};

tokio::task::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
.await
{
Ok(()) => {}
Err(_err) => {
// This error only appears when the client doesn't send a request and
// terminate the connection.
//
// If client sends one request then terminate connection whenever, it doesn't
// appear.
}
}
});
}
});
}))
}
}

mod private {
use std::{
future::Future,
io,
pin::Pin,
task::{Context, Poll},
};

pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);

impl Future for ServeFuture {
type Output = io::Result<()>;

#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx)
}
}

impl std::fmt::Debug for ServeFuture {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ServeFuture").finish_non_exhaustive()
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions examples/testing-websockets/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ where
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, SocketAddr};
use std::{
future::IntoFuture,
net::{Ipv4Addr, SocketAddr},
};
use tokio_tungstenite::tungstenite;

// We can integration test one handler by running the server in a background task and
Expand All @@ -103,7 +106,7 @@ mod tests {
.await
.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(axum::serve(listener, app()));
tokio::spawn(axum::serve(listener, app()).into_future());

let (mut socket, _response) =
tokio_tungstenite::connect_async(format!("ws://{addr}/integration-testable"))
Expand Down
Loading