Skip to content

Commit

Permalink
refactor: remove select macro of interruption handling (#291)
Browse files Browse the repository at this point in the history
* refactor: handle interrupt without `select!` macro -> clean up deps and features

* refactor features flag: `"__rt__", `"__rt_native__"` -> just `"__rt_native"` and def `__rt_native = ["__rt__", ...`
  • Loading branch information
kanarus authored Nov 28, 2024
1 parent 22d4ca0 commit dff904c
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 113 deletions.
51 changes: 37 additions & 14 deletions ohkami/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ features = ["rt_tokio", "nightly", "sse", "ws"]
ohkami_lib = { version = "=0.21.0", path = "../ohkami_lib" }
ohkami_macros = { version = "=0.21.0", path = "../ohkami_macros" }

tokio = { version = "1", optional = true, features = ["rt", "net", "time"] }
tokio = { version = "1", optional = true }
async-std = { version = "1", optional = true }
smol = { version = "2", optional = true }
nio = { version = "0.0", optional = true }
Expand All @@ -39,28 +39,51 @@ sha2 = { version = "0.10", default-features = false }

ctrlc = { version = "3.4", optional = true }
num_cpus = { version = "1.16", optional = true }
futures-util = { version = "0.3", optional = true, default-features = false, features = ["io", "async-await-macro"] }
futures-util = { version = "0.3", optional = true, default-features = false }
mews = { version = "0.2", optional = true }


[features]
rt_tokio = ["__rt__", "__rt_native__", "dep:tokio", "tokio/io-util", "tokio/macros", "mews?/rt_tokio" ]
rt_async-std = ["__rt__", "__rt_native__", "dep:async-std", "dep:futures-util", "mews?/rt_async-std"]
rt_smol = ["__rt__", "__rt_native__", "dep:smol", "dep:futures-util", "mews?/rt_smol" ]
rt_nio = ["__rt__", "__rt_native__", "dep:nio", "dep:tokio", "tokio/io-util", "mews?/rt_nio" ]
rt_glommio = ["__rt__", "__rt_native__", "dep:glommio", "dep:futures-util", "dep:num_cpus", "mews?/rt_glommio" ]
rt_worker = ["__rt__", "dep:worker", "ohkami_macros/worker"]

nightly = []
sse = ["ohkami_lib/stream"]
ws = ["ohkami_lib/stream", "dep:mews"]
rt_tokio = ["__rt_native__",
"dep:tokio","tokio/rt","tokio/net","tokio/time",
"tokio/io-util",
"mews?/rt_tokio",
]
rt_async-std = ["__rt_native__",
"dep:async-std",
"dep:futures-util","futures-util/io",
"mews?/rt_async-std",
]
rt_smol = ["__rt_native__",
"dep:smol",
"dep:futures-util","futures-util/io",
"mews?/rt_smol",
]
rt_nio = ["__rt_native__",
"dep:nio",
"dep:tokio","tokio/io-util",
"mews?/rt_nio"
]
rt_glommio = ["__rt_native__",
"dep:glommio",
"dep:futures-util","futures-util/io",
"mews?/rt_glommio",
"dep:num_cpus",
]
rt_worker = ["__rt__",
"dep:worker",
"ohkami_macros/worker"
]
nightly = []
sse = ["ohkami_lib/stream"]
ws = ["ohkami_lib/stream", "dep:mews"]

##### internal #####
__rt__ = []
__rt_native__ = ["dep:ctrlc"]
__rt_native__ = ["__rt__", "dep:ctrlc"]

##### DEBUG #####
DEBUG = ["tokio?/rt-multi-thread"]
DEBUG = ["tokio?/rt-multi-thread", "tokio?/macros"]
#default = [
# "nightly",
# "sse",
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/fang/builtin/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl<Inner: FangProc> FangProc for CORSProc<Inner> {


#[cfg(debug_assertions)]
#[cfg(feature="rt_tokio")]
#[cfg(all(feature="rt_tokio", feature="DEBUG"))]
#[cfg(test)]
mod test {
use crate::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/fang/builtin/jwt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ impl<Payload: for<'de> Deserialize<'de>> JWT<Payload> {


#[cfg(debug_assertions)]
#[cfg(feature="rt_tokio")]
#[cfg(all(feature="rt_tokio", feature="DEBUG"))]
#[cfg(test)] mod test {
use super::{JWT, JWTToken};
use crate::__rt__::test;
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/fang/builtin/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ const _: () = {
};


#[cfg(all(test, debug_assertions, feature="rt_tokio"))]
#[cfg(all(test, debug_assertions, feature="rt_tokio", feature="DEBUG"))]
#[crate::__rt__::test] async fn test_timeout() {
use crate::prelude::*;
use crate::testing::*;
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/fang/middleware/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub trait FangAction: Clone + Send + Sync + 'static {



#[cfg(all(test, debug_assertions, feature="rt_tokio"))]
#[cfg(all(test, debug_assertions, feature="rt_tokio", feature="DEBUG"))]
mod test {
use super::*;
use crate::prelude::*;
Expand Down
40 changes: 20 additions & 20 deletions ohkami/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ compile_error! {"
#[cfg(feature="__rt_native__")]
mod __rt__ {
#[cfg(test)]
#[cfg(feature="rt_tokio")]
#[cfg(all(feature="rt_tokio", feature="DEBUG"))]
pub(crate) use tokio::test;

#[cfg(feature="rt_tokio")]
Expand Down Expand Up @@ -107,25 +107,25 @@ mod __rt__ {
#[cfg(feature="rt_glommio")]
pub(crate) use futures_util::AsyncWriteExt as AsyncWrite;

#[cfg(feature="rt_tokio")]
pub(crate) use tokio::select;
#[cfg(feature="rt_async-std")]
pub(crate) use futures_util::select;
#[cfg(feature="rt_smol")]
pub(crate) use futures_util::select;
#[cfg(feature="rt_nio")]
pub(crate) use tokio::select;
#[cfg(feature="rt_glommio")]
pub(crate) use futures_util::select;

#[cfg(any(feature="rt_tokio", feature="rt_nio"))]
pub(crate) const fn selectable<F: std::future::Future>(future: F) -> F {
future
}
#[cfg(any(feature="rt_async-std", feature="rt_smol", feature="rt_glommio"))]
pub(crate) fn selectable<F: std::future::Future>(future: F) -> ::futures_util::future::Fuse<F> {
::futures_util::FutureExt::fuse(future)
}
// #[cfg(feature="rt_tokio")]
// pub(crate) use tokio::select;
// #[cfg(feature="rt_async-std")]
// pub(crate) use futures_util::select;
// #[cfg(feature="rt_smol")]
// pub(crate) use futures_util::select;
// #[cfg(feature="rt_nio")]
// pub(crate) use tokio::select;
// #[cfg(feature="rt_glommio")]
// pub(crate) use futures_util::select;
//
// #[cfg(any(feature="rt_tokio", feature="rt_nio"))]
// pub(crate) const fn selectable<F: std::future::Future>(future: F) -> F {
// future
// }
// #[cfg(any(feature="rt_async-std", feature="rt_smol", feature="rt_glommio"))]
// pub(crate) fn selectable<F: std::future::Future>(future: F) -> ::futures_util::future::Fuse<F> {
// ::futures_util::FutureExt::fuse(future)
// }

#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))]
mod task {
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/ohkami/_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(non_snake_case)]
#![cfg(debug_assertions)]
#![cfg(feature="rt_tokio")] // for `#[__rt__::test]`
#![cfg(all(feature="rt_tokio", feature="DEBUG"))] // for `#[__rt__::test]`

use crate::__rt__;
use crate::prelude::*;
Expand Down
138 changes: 68 additions & 70 deletions ohkami/src/ohkami/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,47 +305,40 @@ impl Ohkami {
/// ```
pub async fn howl(self, address: impl __rt__::ToSocketAddrs) {
let router = Arc::new(self.into_router().finalize());

let listener = __rt__::bind(address).await;

let (wg, inturrupt) = (sync::WaitGroup::new(), sync::CtrlC::new());
let (wg, ctrl_c) = (sync::WaitGroup::new(), sync::CtrlC::new());

loop {
__rt__::select! {
accept = __rt__::selectable(listener.accept()) => {
let (connection, addr) = {
#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))] {
let Ok((connection, addr)) = accept else {continue};
(connection, addr)
}
#[cfg(any(feature="rt_glommio"))] {
let Ok(connection) = accept else {continue};
let Ok(addr) = connection.peer_addr() else {continue};
(connection, addr)
}
};

let session = Session::new(
router.clone(),
connection,
addr.ip()
);

let wg = wg.add();
__rt__::spawn(async move {
session.manage().await;
wg.done();
});
while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await {
let (connection, addr) = {
#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))] {
let Ok((connection, addr)) = accept else {continue};
(connection, addr)
}
_ = __rt__::selectable(inturrupt.catch()) => {
crate::DEBUG!("Recieved Ctrl-C, trying graceful shutdown...");
drop(listener);
break
#[cfg(any(feature="rt_glommio"))] {
let Ok(connection) = accept else {continue};
let Ok(addr) = connection.peer_addr() else {continue};
(connection, addr)
}
}
};

let session = Session::new(
router.clone(),
connection,
addr.ip()
);

let wg = wg.add();
__rt__::spawn(async move {
session.manage().await;
wg.done();
});
}

crate::DEBUG!("Waiting {} session(s) to finish...", wg.count());
crate::DEBUG!("interrupted, trying graceful shutdown...");
drop(listener);

crate::DEBUG!("waiting {} session(s) to finish...", wg.count());
wg.await;
}

Expand Down Expand Up @@ -467,41 +460,6 @@ mod sync {
static CATCH: AtomicBool = AtomicBool::new(false);

impl CtrlC {
pub fn catch(&self) -> impl Future<Output = ()> + 'static {
return Inturrupt;

struct Inturrupt;
impl Future for Inturrupt {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if CATCH.load(Ordering::SeqCst) {
crate::DEBUG!("[CtrlC::catch] Ready");
Poll::Ready(())
} else {
#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))] {
let prev_waker = WAKER.swap(
Box::into_raw(Box::new(cx.waker().clone())),
Ordering::SeqCst
);
if !prev_waker.is_null() {
unsafe {prev_waker.drop_in_place()}
}
}
#[cfg(any(feature="rt_glommio"))] {
let current_id = glommio::executor().id();
let current_waker = cx.waker().clone();
let mut lock = WAKER.lock().unwrap();
match lock.iter_mut().find(|(id, _)| (*id == current_id)) {
Some(prev) => *prev = (current_id, current_waker),
None => lock.push((current_id, current_waker)),
}
}
Poll::Pending
}
}
}
}

pub fn new() -> Self {
#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))]
::ctrlc::set_handler(|| {
Expand All @@ -516,14 +474,54 @@ mod sync {
::ctrlc::try_set_handler(|| {
CATCH.store(true, Ordering::SeqCst);
let lock = &mut *WAKER.lock().unwrap();
crate::DEBUG!("Finally {} executors on {} CPU", lock.len(), num_cpus::get());
crate::DEBUG!("Finally {} executors on {} CPU(s)", lock.len(), num_cpus::get());
for (_, w) in std::mem::take(lock) {
w.wake();
}
}).ok();

Self
}

pub fn until_interrupt<T>(&self, task: impl Future<Output = T>) -> impl Future<Output = Option<T>> {
return UntilInterrupt(task);

struct UntilInterrupt<F: Future>(F);
impl<F: Future> Future for UntilInterrupt<F> {
type Output = Option<F::Output>;

#[inline]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match unsafe {Pin::new_unchecked(&mut self.get_unchecked_mut().0)}.poll(cx) {
Poll::Ready(t) => Poll::Ready(Some(t)),
Poll::Pending => if CATCH.load(Ordering::SeqCst) {
crate::DEBUG!("[CtrlC::catch] Ready");
Poll::Ready(None)
} else {
#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))] {
let prev_waker = WAKER.swap(
Box::into_raw(Box::new(cx.waker().clone())),
Ordering::SeqCst
);
if !prev_waker.is_null() {
unsafe {prev_waker.drop_in_place()}
}
}
#[cfg(any(feature="rt_glommio"))] {
let current_id = glommio::executor().id();
let current_waker = cx.waker().clone();
let mut lock = WAKER.lock().unwrap();
match lock.iter_mut().find(|(id, _)| (*id == current_id)) {
Some(prev) => *prev = (current_id, current_waker),
None => lock.push((current_id, current_waker)),
}
}
Poll::Pending
}
}
}
}
}
}
};
}
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/request/_test_extract.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![cfg(debug_assertions)]
#![cfg(feature="rt_tokio")]
#![cfg(all(feature="rt_tokio", feature="DEBUG"))]

use crate::prelude::*;
use crate::testing::*;
Expand Down
2 changes: 1 addition & 1 deletion ohkami/src/request/_test_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn parse_path() {
assert_eq!(&*path, "/");
}

#[cfg(feature="rt_tokio")]
#[cfg(all(feature="rt_tokio", feature="DEBUG"))]
#[crate::__rt__::test] async fn test_parse_request() {
use super::{RequestHeader, RequestHeaders};
use std::pin::Pin;
Expand Down
3 changes: 1 addition & 2 deletions ohkami/src/response/_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#![cfg(feature="rt_tokio")]
#![cfg(all(feature="rt_tokio", feature="DEBUG"))]

use crate::Response;


macro_rules! assert_bytes_eq {
($res:expr, $expected:expr) => {
{
Expand Down

0 comments on commit dff904c

Please sign in to comment.