From f9882e04778692d97b5ca48ab82eaccfb605e1cc Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 14:33:21 +0900 Subject: [PATCH 01/15] @2024-07-19 14:33+9:00 --- ohkami/Cargo.toml | 26 ++--- ohkami/src/lib.rs | 3 + ohkami/src/response/content.rs | 6 ++ ohkami/src/response/mod.rs | 63 +++++++++-- ohkami/src/session/mod.rs | 29 ++--- ohkami/src/websocket/context.rs | 142 ------------------------ ohkami/src/websocket/frame.rs | 2 +- ohkami/src/websocket/message.rs | 2 +- ohkami/src/websocket/mod.rs | 174 +++++++++++++++++++++++++++--- ohkami/src/websocket/session.rs | 74 +++++++++++++ ohkami/src/websocket/upgrade.rs | 131 ---------------------- ohkami/src/websocket/websocket.rs | 99 ----------------- 12 files changed, 328 insertions(+), 423 deletions(-) delete mode 100644 ohkami/src/websocket/context.rs create mode 100644 ohkami/src/websocket/session.rs delete mode 100644 ohkami/src/websocket/upgrade.rs delete mode 100644 ohkami/src/websocket/websocket.rs diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 826164db..4d734f32 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -33,11 +33,11 @@ rustc-hash = { version = "2.0" } hmac = { version = "0.12", default-features = false } sha2 = { version = "0.10", default-features = false } -#sha1 = { version = "0.10", optional = true, default-features = false } +sha1 = { version = "0.10", optional = true, default-features = false } [features] -default = ["testing"] +#default = ["testing"] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] @@ -45,8 +45,8 @@ rt_worker = ["dep:worker", "ohkami_macros/worker"] nightly = [] testing = [] -#websocket = ["dep:sha1"] sse = ["ohkami_lib/stream"] +ws = ["dep:sha1"] ##### DEBUG ##### DEBUG = [ @@ -54,13 +54,13 @@ DEBUG = [ "tokio?/rt-multi-thread", "async-std?/attributes", ] -#default = [ -# "nightly", -# "testing", -# "sse", -# #"websocket", -# "rt_tokio", -# #"rt_async-std", -# #"rt_worker", -# "DEBUG", -#] \ No newline at end of file +default = [ + "nightly", + "testing", + "sse", + "ws", + "rt_tokio", + #"rt_async-std", + #"rt_worker", + "DEBUG", +] \ No newline at end of file diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 29e422bc..6d3e27ab 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -105,6 +105,9 @@ pub mod builtin; pub mod typed; +#[cfg(feature="ws")] +pub mod websocket; + #[cfg(feature="testing")] #[cfg(any(feature="rt_tokio",feature="rt_async-std",feature="rt_worker"))] pub mod testing; diff --git a/ohkami/src/response/content.rs b/ohkami/src/response/content.rs index fc969cae..215d8611 100644 --- a/ohkami/src/response/content.rs +++ b/ohkami/src/response/content.rs @@ -11,6 +11,9 @@ pub enum Content { #[cfg(feature="sse")] Stream(std::pin::Pin> + Send>>), + + #[cfg(feature="ws")] + WebSocket(crate::websocket::Handler), } const _: () = { impl Default for Content { fn default() -> Self { @@ -42,6 +45,9 @@ pub enum Content { #[cfg(feature="sse")] Self::Stream(_) => f.write_str("{stream}"), + + #[cfg(feature="ws")] + Self::WebSocket(_) => f.write_str("{websocket}"), } } } diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index ac4e231f..2b804558 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -142,7 +142,11 @@ impl Response { #[cfg(feature="sse")] Content::Stream(_) => self.headers.set() - .ContentLength(None) + .ContentLength(None), + + #[cfg(feature="ws")] + Content::WebSocket(_) => self.headers.set() + .ContentLength(None), }; } } @@ -150,7 +154,7 @@ impl Response { #[cfg(any(feature="rt_tokio",feature="rt_async-std"))] impl Response { #[cfg_attr(not(feature="sse"), inline)] - pub(crate) async fn send(mut self, conn: &mut (impl AsyncWriter + Unpin)) { + pub(crate) async fn send(mut self, conn: &mut (impl AsyncWriter + Unpin + 'static)) { self.complete(); match self.content { @@ -162,8 +166,8 @@ impl Response { crate::push_unchecked!(buf <- self.status.line()); self.headers.write_unchecked_to(&mut buf); } - conn.write_all(&buf).await.expect("Failed to send response"); + conn.flush().await.expect("Failed to flush TCP connection"); } Content::Payload(bytes) => { @@ -176,8 +180,8 @@ impl Response { self.headers.write_unchecked_to(&mut buf); crate::push_unchecked!(buf <- bytes); } - conn.write_all(&buf).await.expect("Failed to send response"); + conn.flush().await.expect("Failed to flush TCP connection"); } #[cfg(feature="sse")] @@ -189,8 +193,9 @@ impl Response { crate::push_unchecked!(buf <- self.status.line()); self.headers.write_unchecked_to(&mut buf); } - conn.write_all(&buf).await.expect("Failed to send response"); + conn.flush().await.expect("Failed to flush TCP connection"); + while let Some(chunk) = stream.next().await { match chunk { Err(msg) => { @@ -225,10 +230,37 @@ impl Response { } } conn.write_all(b"0\r\n\r\n").await.expect("Failed to send response"); + conn.flush().await.expect("Failed to flush TCP connection"); } - } - conn.flush().await.expect("Failed to flush TCP connection"); + #[cfg(feature="ws")] + Content::WebSocket(handler) => { + let mut buf = Vec::::with_capacity( + self.status.line().len() + + self.headers.size + ); unsafe { + crate::push_unchecked!(buf <- self.status.line()); + self.headers.write_unchecked_to(&mut buf); + } + conn.write_all(&buf).await.expect("Failed to send response"); + conn.flush().await.expect("Failed to flush TCP connection"); + + /* this doesn't match in testing */ + if let Some(tcp_stream) = ::downcast_mut::(conn) { + use crate::websocket::{Session, Config}; + + /* FIXME: make Config configurable */ + let ws = Session::new(tcp_stream, Config::default()); + + crate::__rt__::task::spawn({ + let h = handler(ws); + async move { + h.await + } + }); + } + } + } } } @@ -368,6 +400,14 @@ impl Response { } } +#[cfg(feature="ws")] +impl Response { + pub(crate) fn with_websocket(mut self, handler: crate::websocket::Handler) -> Self { + self.content = Content::WebSocket(handler); + self + } +} + const _: () = { impl std::fmt::Debug for Response { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -380,7 +420,7 @@ const _: () = { Content::Payload(bytes) => Content::Payload(bytes.clone()), #[cfg(feature="sse")] - Content::Stream(_) => Content::Stream(Box::pin({ + Content::Stream(_) => Content::Stream(Box::pin({ struct DummyStream; impl ohkami_lib::Stream for DummyStream { type Item = Result; @@ -389,7 +429,12 @@ const _: () = { } } DummyStream - })) + })), + + #[cfg(feature="ws")] + Content::WebSocket(_) => Content::WebSocket(Box::new({ + |_| Box::pin(async {/* dummy handler */}) + })), } }; this.complete(); diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 7a1d822a..69e0220c 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -36,21 +36,24 @@ impl Session { crate::Response::InternalServerError() } - /* async-std doesn't provide split */ - #[cfg(feature="rt_tokio")] - let (mut r, mut w) = self.connection.split(); - #[cfg(feature="rt_async-std")] - let c = &mut self.connection; + // /* async-std doesn't provide split */ + // #[cfg(feature="rt_tokio")] + // let (mut r, mut w) = self.connection.split(); + // #[cfg(feature="rt_async-std")] + // let c = &mut self.connection; - #[cfg(feature="rt_tokio")] - macro_rules! read {($req:ident) => {$req.as_mut().read(&mut r)};} - #[cfg(feature="rt_async-std")] - macro_rules! read {($req:ident) => {$req.as_mut().read(c)};} + // #[cfg(feature="rt_tokio")] + // macro_rules! read {($req:ident) => {$req.as_mut().read(&mut r)};} + // #[cfg(feature="rt_async-std")] + // macro_rules! read {($req:ident) => {$req.as_mut().read(c)};} - #[cfg(feature="rt_tokio")] - macro_rules! send {($res:ident) => {$res.send(&mut w)};} - #[cfg(feature="rt_async-std")] - macro_rules! send {($res:ident) => {$res.send(c)};} + // #[cfg(feature="rt_tokio")] + // macro_rules! send {($res:ident) => {$res.send(&mut w)};} + // #[cfg(feature="rt_async-std")] + // macro_rules! send {($res:ident) => {$res.send(c)};} + + macro_rules! read {($req:ident) => {$req.as_mut().read(&mut self.connection)};} + macro_rules! send {($res:ident) => {$res.send(&mut self.connection)};} timeout_in(std::time::Duration::from_secs(crate::env::OHKAMI_KEEPALIVE_TIMEOUT()), async { loop { diff --git a/ohkami/src/websocket/context.rs b/ohkami/src/websocket/context.rs deleted file mode 100644 index 3371924d..00000000 --- a/ohkami/src/websocket/context.rs +++ /dev/null @@ -1,142 +0,0 @@ -use std::{future::Future, borrow::Cow}; -use super::{assume_upgradable, UpgradeID}; -use super::websocket::Config; -use super::{WebSocket}; -use crate::{Response, Request}; -use crate::__rt__::{task}; -use crate::http::{Method}; -use crate::layer0_lib::{base64}; - - -pub struct WebSocketContext { - id: Option, - config: Config, - - on_failed_upgrade: UFH, - - sec_websocket_key: Cow<'static, str>, - selected_protocol: Option>, - sec_websocket_protocol: Option>, -} - -pub trait UpgradeFailureHandler { - fn handle(self, error: UpgradeError); -} -pub enum UpgradeError { - NotRequestedUpgrade, -} -pub struct DefaultUpgradeFailureHandler; -impl UpgradeFailureHandler for DefaultUpgradeFailureHandler { - fn handle(self, _: UpgradeError) {/* discard error */} -} - -impl WebSocketContext { - pub(crate) fn new(req: &mut Request) -> Result { - if req.method != Method::GET { - return Err((|| Response::BadRequest().text("Method is not `GET`"))()) - } - if req.headers.Connection() != Some("upgrade") { - return Err((|| Response::BadRequest().text("Connection header is not `upgrade`"))()) - } - if req.headers.Upgrade() != Some("websocket") { - return Err((|| Response::BadRequest().text("Upgrade header is not `websocket`"))()) - } - if req.headers.SecWebSocketVersion() != Some("13") { - return Err((|| Response::BadRequest().text("Sec-WebSocket-Version header is not `13`"))()) - } - - let sec_websocket_key = Cow::Owned(req.headers.SecWebSocketKey() - .ok_or_else(|| Response::BadRequest().text("Sec-WebSocket-Key header is missing"))? - .to_string()); - - let sec_websocket_protocol = req.headers.SecWebSocketProtocol() - .map(|swp| Cow::Owned(swp.to_string())); - - Ok(Self { - id: req.upgrade_id, - config: Config::default(), - on_failed_upgrade: DefaultUpgradeFailureHandler, - selected_protocol: None, - sec_websocket_key, - sec_websocket_protocol, - }) - } -} - -impl WebSocketContext { - pub fn write_buffer_size(mut self, size: usize) -> Self { - self.config.write_buffer_size = size; - self - } - pub fn max_write_buffer_size(mut self, size: usize) -> Self { - self.config.max_write_buffer_size = size; - self - } - pub fn max_message_size(mut self, size: usize) -> Self { - self.config.max_message_size = Some(size); - self - } - pub fn max_frame_size(mut self, size: usize) -> Self { - self.config.max_frame_size = Some(size); - self - } - pub fn accept_unmasked_frames(mut self) -> Self { - self.config.accept_unmasked_frames = true; - self - } - - pub fn protocols>>(mut self, protocols: impl Iterator) -> Self { - if let Some(req_protocols) = &self.sec_websocket_protocol { - self.selected_protocol = protocols.map(Into::into) - .find(|p| req_protocols.split(',').any(|req_p| req_p.trim() == p)) - } - self - } -} - -impl WebSocketContext { - pub fn on_upgrade + Send + 'static>( - self, - handler: impl Fn(WebSocket) -> Fut + Send + Sync + 'static - ) -> Response { - #[inline] fn sign(sec_websocket_key: &str) -> String { - use ::sha1::{Sha1, Digest}; - - let mut sha1 = ::new(); - sha1.update(sec_websocket_key.as_bytes()); - sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - base64::encode(sha1.finalize()) - } - - let Self { - config, - on_failed_upgrade, - selected_protocol, - sec_websocket_key, - .. - } = self; - - task::spawn({ - async move { - let stream = match self.id { - None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), - Some(id) => assume_upgradable(id).await, - }; - - let ws = WebSocket::new(stream, config); - handler(ws).await - } - }); - - let mut handshake_res = Response::SwitchingProtocols(); - handshake_res.headers.set() - .Connection("Update") - .Upgrade("websocket") - .SecWebSocketAccept(sign(&sec_websocket_key)); - if let Some(protocol) = selected_protocol { - handshake_res.headers.set() - .SecWebSocketProtocol(protocol.to_string()); - } - handshake_res - } -} diff --git a/ohkami/src/websocket/frame.rs b/ohkami/src/websocket/frame.rs index 71221d22..3128611d 100644 --- a/ohkami/src/websocket/frame.rs +++ b/ohkami/src/websocket/frame.rs @@ -1,6 +1,6 @@ use std::io::{Error, ErrorKind}; use crate::__rt__::{AsyncReader, AsyncWriter}; -use super::websocket::Config; +use super::Config; #[derive(PartialEq)] diff --git a/ohkami/src/websocket/message.rs b/ohkami/src/websocket/message.rs index b119f147..4123fb0e 100644 --- a/ohkami/src/websocket/message.rs +++ b/ohkami/src/websocket/message.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, io::{Error, ErrorKind}}; use crate::{__rt__::{AsyncReader, AsyncWriter}}; -use super::{frame::{Frame, OpCode, CloseCode}, websocket::Config}; +use super::{frame::{Frame, OpCode, CloseCode}, Config}; const PING_PONG_PAYLOAD_LIMIT: usize = 125; diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index f8ed0a89..38a36935 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -1,20 +1,166 @@ -#[cfg(not(target_pointer_width = "64"))] -compile_error!{ "pointer width must be 64" } +#![cfg(all( + feature="ws", +))] -mod websocket; -mod context; +mod session; mod message; -mod upgrade; mod frame; -pub use { - message::{Message}, - websocket::{WebSocket}, - context::{WebSocketContext, UpgradeError}, -}; -pub(crate) use { - upgrade::{UpgradeID, request_upgrade_id, reserve_upgrade}, +pub use message::Message; +pub(crate) use session::WebSocket as Session; + +use ohkami_lib::base64; +use std::{future::Future, borrow::Cow, pin::Pin}; +use crate::{FromRequest, IntoResponse, Method, Request, Response}; +use crate::__rt__::{task, AsyncReader, AsyncWriter, TcpStream}; + + +// #[derive(Clone)] +pub struct WebSocketContext<'req> { + sec_websocket_key: &'req str, +} const _: () = { + impl<'req> FromRequest<'req> for WebSocketContext<'req> { + type Error = std::convert::Infallible; + + fn from_request(req: &'req Request) -> Option> { + req.headers.SecWebSocketKey().map(|swk| Ok(Self { + sec_websocket_key: swk, + })) + } + } + + impl<'ws> WebSocketContext<'ws> { + pub fn connect + Send + 'static>(self, + handler: impl Fn(Session<'ws, TcpStream>) -> Fut + Send + Sync + 'static + ) -> WebSocket { + #[inline] fn signed(sec_websocket_key: &str) -> String { + use ::sha1::{Sha1, Digest}; + let mut sha1 = ::new(); + sha1.update(sec_websocket_key.as_bytes()); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + base64::encode(sha1.finalize()) + } + + WebSocket { + sec_websocket_key: signed(self.sec_websocket_key), + handler: Box::new(move |ws| Box::pin({ + let h = handler(unsafe {std::mem::transmute::<_, Session<'ws, _>>(ws)}); + async {h.await} + })) + } + } + } }; -pub(crate) use upgrade::{ - assume_upgradable, + +pub(crate) type Handler = Box) -> Pin + Send + '_>> + + Send + Sync + 'static +>; + +// #[derive(Clone)] +pub struct WebSocket { + sec_websocket_key: String, + handler: Handler, +} impl IntoResponse for WebSocket { + fn into_response(self) -> Response { + Response::SwitchingProtocols().with_headers(|h|h + .Connection("Update") + .Upgrade("websocket") + .SecWebSocketAccept(self.sec_websocket_key) + ).with_websocket(self.handler) + } +} + +/// ## Note +/// +/// - Currently, subprotocols with `Sec-WebSocket-Protocol` is not supported +//#[derive(Clone)] +pub struct Config { + pub write_buffer_size: usize, + pub max_write_buffer_size: usize, + pub accept_unmasked_frames: bool, + pub max_message_size: Option, + pub max_frame_size: Option, +} const _: () = { + impl Default for Config { + fn default() -> Self { + Self { + write_buffer_size: 128 * 1024, // 128 KiB + max_write_buffer_size: usize::MAX, + accept_unmasked_frames: false, + max_message_size: Some(64 << 20), + max_frame_size: Some(16 << 20), + } + } + } }; + +// impl WebSocket { +// /// shortcut for `WebSocket::with(Config::default())` +// pub fn new + Send>( +// handler: impl Fn(Session<'_, TcpStream>) -> Fut + 'static +// ) -> Self { +// Self::with(Config::default(), handler) +// } +// +// pub fn with + Send>( +// config: Config, +// handler: impl Fn(Session<'_, TcpStream>) -> Fut + 'static +// ) -> Self { +// task::spawn(async move { +// todo!() +// }); +// +// Self { config, handler: } +// } +// } +// +// + +/* +impl WebSocket { + pub fn on_upgrade + Send + 'static>( + self, + handler: impl Fn(WebSocket) -> Fut + Send + Sync + 'static + ) -> Response { + #[inline] fn sign(sec_websocket_key: &str) -> String { + use ::sha1::{Sha1, Digest}; + + let mut sha1 = ::new(); + sha1.update(sec_websocket_key.as_bytes()); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + base64::encode(sha1.finalize()) + } + + let Self { + config, + selected_protocol, + sec_websocket_key, + .. + } = self; + + task::spawn({ + async move { + let stream = match self.id { + None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), + Some(id) => assume_upgradable(id).await, + }; + + let ws = WebSocket::new(stream, config); + handler(ws).await + } + }); + + let mut handshake_res = Response::SwitchingProtocols(); + handshake_res.headers.set() + .Connection("Update") + .Upgrade("websocket") + .SecWebSocketAccept(sign(&sec_websocket_key)); + if let Some(protocol) = selected_protocol { + handshake_res.headers.set() + .SecWebSocketProtocol(protocol.to_string()); + } + handshake_res + } +} +*/ diff --git a/ohkami/src/websocket/session.rs b/ohkami/src/websocket/session.rs new file mode 100644 index 00000000..ac1f9517 --- /dev/null +++ b/ohkami/src/websocket/session.rs @@ -0,0 +1,74 @@ +use std::io::Error; +use super::{Message, Config}; +use crate::__rt__::{AsyncWriter, AsyncReader}; + + +/* Used only in `ohkami::websocket::WebSocket::{new, with}` and NOT `use`able by user */ +pub struct WebSocket<'ws, Conn: AsyncWriter + AsyncReader + Unpin> { + conn: &'ws mut Conn, + config: Config, + n_buffered: usize, +} + +impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { + pub(crate) fn new(conn: &'ws mut Conn, config: Config) -> Self { + Self { conn, config, n_buffered:0 } + } +} + +impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(self.conn, &self.config).await + } +} + +// ============================================================================= +pub(super) async fn send( + message: Message, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + n_buffered: &mut usize, +) -> Result<(), Error> { + message.write(stream, config).await?; + flush(stream, n_buffered).await?; + Ok(()) +} +pub(super) async fn write( + message: Message, + stream: &mut (impl AsyncWriter + Unpin), + config: &Config, + n_buffered: &mut usize, +) -> Result { + let n = message.write(stream, config).await?; + + *n_buffered += n; + if *n_buffered > config.write_buffer_size { + if *n_buffered > config.max_write_buffer_size { + panic!("Buffered messages is larger than `max_write_buffer_size`"); + } else { + flush(stream, n_buffered).await? + } + } + + Ok(n) +} +pub(super) async fn flush( + stream: &mut (impl AsyncWriter + Unpin), + n_buffered: &mut usize, +) -> Result<(), Error> { + stream.flush().await + .map(|_| *n_buffered = 0) +} +// ============================================================================= + +impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { + pub async fn send(&mut self, message: Message) -> Result<(), Error> { + send(message, &mut self.conn, &self.config, &mut self.n_buffered).await + } + pub async fn write(&mut self, message: Message) -> Result { + write(message, &mut self.conn, &self.config, &mut self.n_buffered).await + } + pub async fn flush(&mut self) -> Result<(), Error> { + flush(&mut self.conn, &mut self.n_buffered).await + } +} diff --git a/ohkami/src/websocket/upgrade.rs b/ohkami/src/websocket/upgrade.rs deleted file mode 100644 index 3be7527f..00000000 --- a/ohkami/src/websocket/upgrade.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::{ - sync::{OnceLock, atomic::{AtomicBool, Ordering}}, - pin::Pin, cell::UnsafeCell, - future::Future, -}; -use crate::__rt__::{TcpStream}; - - -pub async fn request_upgrade_id() -> UpgradeID { - struct ReserveUpgrade; - impl Future for ReserveUpgrade { - type Output = UpgradeID; - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - let Some(mut streams) = UpgradeStreams().request_reservation() - else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; - - let id = UpgradeID(match streams.iter().position(|cell| cell.is_empty()) { - Some(i) => i, - None => {streams.push(StreamCell::new()); streams.len() - 1}, - }); - - streams[id.as_usize()].reserved = true; - - std::task::Poll::Ready(id) - } - } - - ReserveUpgrade.await -} - -/// SAFETY: This must be called after the corresponded `request_upgrade_id` -pub unsafe fn reserve_upgrade(id: UpgradeID, stream: TcpStream) { - #[cfg(debug_assertions)] assert!( - UpgradeStreams().get_mut().get(id.as_usize()).is_some_and( - |cell| cell.reserved && cell.stream.is_some()), - "Cell not reserved" - ); - - (UpgradeStreams().get_mut())[id.as_usize()].stream = Some(stream); -} - -pub async fn assume_upgradable(id: UpgradeID) -> TcpStream { - struct AssumeUpgradable{id: UpgradeID} - impl Future for AssumeUpgradable { - type Output = TcpStream; - fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - let Some(StreamCell { reserved, stream }) = (unsafe {UpgradeStreams().get_mut()}).get_mut(self.id.as_usize()) - else {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; - - if stream.is_none() - {cx.waker().wake_by_ref(); return std::task::Poll::Pending}; - - *reserved = false; - - std::task::Poll::Ready(unsafe {stream.take().unwrap_unchecked()}) - } - } - - AssumeUpgradable{id}.await -} - - -static UPGRADE_STREAMS: OnceLock = OnceLock::new(); - -#[allow(non_snake_case)] fn UpgradeStreams() -> &'static UpgradeStreams { - UPGRADE_STREAMS.get_or_init(UpgradeStreams::new) -} - -struct UpgradeStreams { - in_scanning: AtomicBool, - streams: UnsafeCell>>, -} const _: () = { - unsafe impl Sync for UpgradeStreams {} - - impl UpgradeStreams { - fn new() -> Self { - Self { - in_scanning: AtomicBool::new(false), - streams: UnsafeCell::new(Vec::new()), - } - } - unsafe fn get_mut(&self) -> &mut Vec> { - &mut *self.streams.get() - } - fn request_reservation(&self) -> Option> { - self.in_scanning.compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed) - .ok().and(Some(ReservationLock(unsafe {self.get_mut()}))) - } - } - - struct ReservationLock<'scan, Stream = TcpStream>(&'scan mut Vec>); - impl<'scan, Stream> Drop for ReservationLock<'scan, Stream> { - fn drop(&mut self) { - UpgradeStreams().in_scanning.store(false, Ordering::Release) - } - } - impl<'scan, Stream> std::ops::Deref for ReservationLock<'scan, Stream> { - type Target = Vec>; - fn deref(&self) -> &Self::Target {&*self.0} - } - impl<'scan, Stream> std::ops::DerefMut for ReservationLock<'scan, Stream> { - fn deref_mut(&mut self) -> &mut Self::Target {self.0} - } -}; - -struct StreamCell { - reserved: bool, - stream: Option, -} const _: () = { - impl StreamCell { - fn new() -> Self { - Self { - reserved: false, - stream: None, - } - } - fn is_empty(&self) -> bool { - (!self.reserved) && self.stream.is_none() - } - } -}; - -#[derive(Clone, Copy)] -pub struct UpgradeID(usize); -const _: () = { - impl UpgradeID { - fn as_usize(&self) -> usize { - self.0 - } - } -}; diff --git a/ohkami/src/websocket/websocket.rs b/ohkami/src/websocket/websocket.rs deleted file mode 100644 index a01e0cb6..00000000 --- a/ohkami/src/websocket/websocket.rs +++ /dev/null @@ -1,99 +0,0 @@ -use std::io::Error; -use super::{Message}; -use crate::__rt__::{AsyncWriter, AsyncReader, TcpStream}; - -//#[cfg(test)] use crate::layer6_testing::TestStream as Stream; -//#[cfg(not(test))] use crate::__rt__::TcpStream as Stream; - - -/// In current version, `split` to read / write halves is not supported -pub struct WebSocket { - stream: Stream, - config: Config, - - n_buffered: usize, -} - -// :fields may set through `WebSocketContext`'s methods -pub struct Config { - pub(crate) write_buffer_size: usize, - pub(crate) max_write_buffer_size: usize, - pub(crate) max_message_size: Option, - pub(crate) max_frame_size: Option, - pub(crate) accept_unmasked_frames: bool, -} const _: () = { - impl Default for Config { - fn default() -> Self { - Self { - write_buffer_size: 128 * 1024, // 128 KiB - max_write_buffer_size: usize::MAX, - max_message_size: Some(64 << 20), - max_frame_size: Some(16 << 20), - accept_unmasked_frames: false, - } - } - } -}; - -impl WebSocket { - pub(crate) fn new(stream: Stream, config: Config) -> Self { - Self { stream, config, n_buffered:0 } - } -} - -impl WebSocket { - pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(&mut self.stream, &self.config).await - } -} - -// ============================================================================= -pub(super) async fn send( - message: Message, - stream: &mut (impl AsyncWriter + Unpin), - config: &Config, - n_buffered: &mut usize, -) -> Result<(), Error> { - message.write(stream, config).await?; - flush(stream, n_buffered).await?; - Ok(()) -} -pub(super) async fn write( - message: Message, - stream: &mut (impl AsyncWriter + Unpin), - config: &Config, - n_buffered: &mut usize, -) -> Result { - let n = message.write(stream, config).await?; - - *n_buffered += n; - if *n_buffered > config.write_buffer_size { - if *n_buffered > config.max_write_buffer_size { - panic!("Buffered messages is larger than `max_write_buffer_size`"); - } else { - flush(stream, n_buffered).await? - } - } - - Ok(n) -} -pub(super) async fn flush( - stream: &mut (impl AsyncWriter + Unpin), - n_buffered: &mut usize, -) -> Result<(), Error> { - stream.flush().await - .map(|_| *n_buffered = 0) -} -// ============================================================================= - -impl WebSocket { - pub async fn send(&mut self, message: Message) -> Result<(), Error> { - send(message, &mut self.stream, &self.config, &mut self.n_buffered).await - } - pub async fn write(&mut self, message: Message) -> Result { - write(message, &mut self.stream, &self.config, &mut self.n_buffered).await - } - pub async fn flush(&mut self) -> Result<(), Error> { - flush(&mut self.stream, &mut self.n_buffered).await - } -} From f4c11026780b849e8d4e8cd2049d399063c5ff4a Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 14:45:44 +0900 Subject: [PATCH 02/15] @2024-07-19 14:45+9:00 --- ohkami/Cargo.toml | 22 +++++++++++----------- ohkami/src/response/mod.rs | 19 +++++++------------ ohkami/src/websocket/frame.rs | 15 ++++----------- ohkami/src/websocket/message.rs | 2 +- ohkami/src/websocket/mod.rs | 9 ++++----- 5 files changed, 27 insertions(+), 40 deletions(-) diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 4d734f32..0e0cde90 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -37,7 +37,7 @@ sha1 = { version = "0.10", optional = true, default-features = false } [features] -#default = ["testing"] +default = ["testing"] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] @@ -54,13 +54,13 @@ DEBUG = [ "tokio?/rt-multi-thread", "async-std?/attributes", ] -default = [ - "nightly", - "testing", - "sse", - "ws", - "rt_tokio", - #"rt_async-std", - #"rt_worker", - "DEBUG", -] \ No newline at end of file +#default = [ +# "nightly", +# "testing", +# "sse", +# "ws", +# "rt_tokio", +# #"rt_async-std", +# #"rt_worker", +# "DEBUG", +#] \ No newline at end of file diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index 2b804558..80693ba8 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -167,7 +167,7 @@ impl Response { self.headers.write_unchecked_to(&mut buf); } conn.write_all(&buf).await.expect("Failed to send response"); - conn.flush().await.expect("Failed to flush TCP connection"); + conn.flush().await.expect("Failed to flush connection"); } Content::Payload(bytes) => { @@ -181,7 +181,7 @@ impl Response { crate::push_unchecked!(buf <- bytes); } conn.write_all(&buf).await.expect("Failed to send response"); - conn.flush().await.expect("Failed to flush TCP connection"); + conn.flush().await.expect("Failed to flush connection"); } #[cfg(feature="sse")] @@ -194,7 +194,7 @@ impl Response { self.headers.write_unchecked_to(&mut buf); } conn.write_all(&buf).await.expect("Failed to send response"); - conn.flush().await.expect("Failed to flush TCP connection"); + conn.flush().await.expect("Failed to flush connection"); while let Some(chunk) = stream.next().await { match chunk { @@ -225,12 +225,12 @@ impl Response { println!("\n[sending chunk]\n{}", chunk.escape_ascii()); conn.write_all(&chunk).await.expect("Failed to send response"); - conn.flush().await.ok(); + conn.flush().await.expect("Failed to flush connection"); } } } conn.write_all(b"0\r\n\r\n").await.expect("Failed to send response"); - conn.flush().await.expect("Failed to flush TCP connection"); + conn.flush().await.expect("Failed to flush connection"); } #[cfg(feature="ws")] @@ -243,7 +243,7 @@ impl Response { self.headers.write_unchecked_to(&mut buf); } conn.write_all(&buf).await.expect("Failed to send response"); - conn.flush().await.expect("Failed to flush TCP connection"); + conn.flush().await.expect("Failed to flush connection"); /* this doesn't match in testing */ if let Some(tcp_stream) = ::downcast_mut::(conn) { @@ -252,12 +252,7 @@ impl Response { /* FIXME: make Config configurable */ let ws = Session::new(tcp_stream, Config::default()); - crate::__rt__::task::spawn({ - let h = handler(ws); - async move { - h.await - } - }); + handler(ws).await } } } diff --git a/ohkami/src/websocket/frame.rs b/ohkami/src/websocket/frame.rs index 3128611d..7f10cd51 100644 --- a/ohkami/src/websocket/frame.rs +++ b/ohkami/src/websocket/frame.rs @@ -63,7 +63,6 @@ pub enum CloseCode { pub struct Frame { pub is_final: bool, pub opcode: OpCode, - pub mask: Option<[u8; 4]>, pub payload: Vec, } impl Frame { pub(crate) async fn read_from( @@ -97,10 +96,7 @@ pub struct Frame { } }; if let Some(limit) = &config.max_frame_size { (&len <= limit).then_some(()) - .ok_or_else(|| Error::new( - ErrorKind::InvalidData, - "Incoming frame is too large" - ))?; + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Incoming frame is too large"))?; } len @@ -108,10 +104,7 @@ pub struct Frame { let mask = if second & 0x80 == 0 { (config.accept_unmasked_frames).then_some(None) - .ok_or_else(|| Error::new( - ErrorKind::InvalidData, - "Client frame is unmasked" - ))? + .ok_or_else(|| Error::new(ErrorKind::InvalidData, "Client frame is unmasked"))? } else { let mut mask_bytes = [0; 4]; if let Err(e) = stream.read_exact(&mut mask_bytes).await { @@ -142,7 +135,7 @@ pub struct Frame { payload }; - Ok(Some(Self { is_final, opcode, mask, payload })) + Ok(Some(Self { is_final, opcode, payload })) } pub(crate) async fn write_unmasked(self, @@ -150,7 +143,7 @@ pub struct Frame { _config: &Config, ) -> Result { fn into_bytes(frame: Frame) -> Vec { - let Frame { is_final, opcode, payload, mask:_ } = frame; + let Frame { is_final, opcode, payload } = frame; let (payload_len_byte, payload_len_bytes) = match payload.len() { ..=125 => (payload.len() as u8, None), diff --git a/ohkami/src/websocket/message.rs b/ohkami/src/websocket/message.rs index 4123fb0e..1ba43146 100644 --- a/ohkami/src/websocket/message.rs +++ b/ohkami/src/websocket/message.rs @@ -64,7 +64,7 @@ impl Message { } }; - Frame { is_final: false, mask: None, opcode, payload } + Frame { is_final: false, opcode, payload } } pub(crate) async fn write(self, diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index 38a36935..84765661 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -10,9 +10,8 @@ pub use message::Message; pub(crate) use session::WebSocket as Session; use ohkami_lib::base64; -use std::{future::Future, borrow::Cow, pin::Pin}; -use crate::{FromRequest, IntoResponse, Method, Request, Response}; -use crate::__rt__::{task, AsyncReader, AsyncWriter, TcpStream}; +use std::{future::Future, pin::Pin}; +use crate::{__rt__, FromRequest, IntoResponse, Request, Response}; // #[derive(Clone)] @@ -31,7 +30,7 @@ pub struct WebSocketContext<'req> { impl<'ws> WebSocketContext<'ws> { pub fn connect + Send + 'static>(self, - handler: impl Fn(Session<'ws, TcpStream>) -> Fut + Send + Sync + 'static + handler: impl Fn(Session<'ws, __rt__::TcpStream>) -> Fut + Send + Sync + 'static ) -> WebSocket { #[inline] fn signed(sec_websocket_key: &str) -> String { use ::sha1::{Sha1, Digest}; @@ -53,7 +52,7 @@ pub struct WebSocketContext<'req> { }; pub(crate) type Handler = Box) -> Pin + Send + '_>> + Fn(Session<'_, __rt__::TcpStream>) -> Pin + Send + '_>> + Send + Sync + 'static >; From 2af39378fbd67e7c03c6f3fcb75f30cbee0988af Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 14:50:00 +0900 Subject: [PATCH 03/15] @2024-07-19 14:49+9:00 --- ohkami/src/lib.rs | 2 +- ohkami/src/response/content.rs | 2 +- ohkami/src/response/mod.rs | 25 ++++++++++++++----------- ohkami/src/websocket/mod.rs | 7 +------ 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 6d3e27ab..a65ae06a 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -105,7 +105,7 @@ pub mod builtin; pub mod typed; -#[cfg(feature="ws")] +#[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] pub mod websocket; #[cfg(feature="testing")] diff --git a/ohkami/src/response/content.rs b/ohkami/src/response/content.rs index 215d8611..8b45e609 100644 --- a/ohkami/src/response/content.rs +++ b/ohkami/src/response/content.rs @@ -12,7 +12,7 @@ pub enum Content { #[cfg(feature="sse")] Stream(std::pin::Pin> + Send>>), - #[cfg(feature="ws")] + #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] WebSocket(crate::websocket::Handler), } const _: () = { impl Default for Content { diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index 80693ba8..fb98c553 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -134,19 +134,22 @@ impl Response { .ContentLength(None), _ => self.headers.set() .ContentLength("0") - } + }; } - Content::Payload(bytes) => self.headers.set() - .ContentLength(ohkami_lib::num::itoa(bytes.len())), + Content::Payload(bytes) => { + self.headers.set() + .ContentLength(ohkami_lib::num::itoa(bytes.len())); + } #[cfg(feature="sse")] - Content::Stream(_) => self.headers.set() - .ContentLength(None), + Content::Stream(_) => { + self.headers.set() + .ContentLength(None); + } - #[cfg(feature="ws")] - Content::WebSocket(_) => self.headers.set() - .ContentLength(None), + #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] + Content::WebSocket(_) => (), }; } } @@ -233,7 +236,7 @@ impl Response { conn.flush().await.expect("Failed to flush connection"); } - #[cfg(feature="ws")] + #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] Content::WebSocket(handler) => { let mut buf = Vec::::with_capacity( self.status.line().len() + @@ -395,7 +398,7 @@ impl Response { } } -#[cfg(feature="ws")] +#[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] impl Response { pub(crate) fn with_websocket(mut self, handler: crate::websocket::Handler) -> Self { self.content = Content::WebSocket(handler); @@ -426,7 +429,7 @@ const _: () = { DummyStream })), - #[cfg(feature="ws")] + #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] Content::WebSocket(_) => Content::WebSocket(Box::new({ |_| Box::pin(async {/* dummy handler */}) })), diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index 84765661..bf272eb3 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -1,6 +1,4 @@ -#![cfg(all( - feature="ws", -))] +#![cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] mod session; mod message; @@ -14,7 +12,6 @@ use std::{future::Future, pin::Pin}; use crate::{__rt__, FromRequest, IntoResponse, Request, Response}; -// #[derive(Clone)] pub struct WebSocketContext<'req> { sec_websocket_key: &'req str, } const _: () = { @@ -56,7 +53,6 @@ pub(crate) type Handler = Box; -// #[derive(Clone)] pub struct WebSocket { sec_websocket_key: String, handler: Handler, @@ -73,7 +69,6 @@ pub struct WebSocket { /// ## Note /// /// - Currently, subprotocols with `Sec-WebSocket-Protocol` is not supported -//#[derive(Clone)] pub struct Config { pub write_buffer_size: usize, pub max_write_buffer_size: usize, From fbc2cf079b0f51f313a97c1cf13ee517ac918ef4 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 14:54:32 +0900 Subject: [PATCH 04/15] fix Taskfile --- Taskfile.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Taskfile.yaml b/Taskfile.yaml index cf24b078..9ef4ce42 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -58,7 +58,7 @@ tasks: dir: ohkami cmds: - cargo test --lib --features rt_tokio,DEBUG,{{.MAYBE_NIGHTLY}} - - cargo test --lib --features rt_tokio,DEBUG,sse,{{.MAYBE_NIGHTLY}} + - cargo test --lib --features rt_tokio,DEBUG,sse,ws,{{.MAYBE_NIGHTLY}} test_rt_async-std: vars: @@ -67,7 +67,7 @@ tasks: dir: ohkami cmds: - cargo test --lib --features rt_async-std,DEBUG,{{.MAYBE_NIGHTLY}} - - cargo test --lib --features rt_async-std,DEBUG,sse,{{.MAYBE_NIGHTLY}} + - cargo test --lib --features rt_async-std,DEBUG,sse,ws,{{.MAYBE_NIGHTLY}} test_rt_worker: vars: @@ -100,6 +100,7 @@ tasks: cmds: - cargo check --lib --features rt_tokio,{{.MAYBE_NIGHTLY}} - cargo check --lib --features rt_tokio,sse,{{.MAYBE_NIGHTLY}} + - cargo check --lib --features rt_tokio,sse,ws,{{.MAYBE_NIGHTLY}} check_rt_async-std: vars: @@ -109,6 +110,7 @@ tasks: cmds: - cargo check --lib --features rt_async-std,{{.MAYBE_NIGHTLY}} - cargo check --lib --features rt_async-std,sse,{{.MAYBE_NIGHTLY}} + - cargo check --lib --features rt_async-std,sse,ws,{{.MAYBE_NIGHTLY}} check_rt_worker: vars: From e3efbcbee24f42858c48b0104873235400ad1a1b Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 15:20:32 +0900 Subject: [PATCH 05/15] make Config configurable & fix test_doc --- README.md | 21 ++++++++++++++++++++- Taskfile.yaml | 2 +- ohkami/Cargo.toml | 2 +- ohkami/src/response/content.rs | 5 ++++- ohkami/src/response/mod.rs | 22 +++++++++++----------- ohkami/src/websocket/mod.rs | 11 ++++++++++- ohkami/src/websocket/session.rs | 2 ++ 7 files changed, 49 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 382631bd..75d115cd 100644 --- a/README.md +++ b/README.md @@ -296,6 +296,25 @@ async fn main() {
+### WebSocket with `"ws"` feature + +Currently, WebSocket on `rt_worker` is not supported. + +```rust,no_run +use ohkami::prelude::*; +use ohkami::websocket::{WebSocketContext, WebSocket, Message}; + +async fn echo_text(c: WebSocketContext) -> WebSocket { + c.connect(|ws| async move { + while let Ok(Some(Message::Text(text)) = ws.recv().await) { + ws.send(Message::Text(text)).await.expect("Failed to send text"); + } + }) +} +``` + +
+ ### Pack of Ohkamis ```rust,no_run @@ -384,7 +403,7 @@ async fn test_my_ohkami() { - [ ] HTTP/3 - [ ] HTTPS - [x] Server-Sent Events -- [ ] WebSocket +- [x] WebSocket ## MSRV (Minimum Supported Rust Version) diff --git a/Taskfile.yaml b/Taskfile.yaml index 9ef4ce42..f5926e8e 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -32,7 +32,7 @@ tasks: test_doc: dir: ohkami cmds: - - cargo test --doc --features DEBUG,rt_tokio,sse + - cargo test --doc --features DEBUG,rt_tokio,sse,ws test_examples: dir: examples diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 0e0cde90..0145ff65 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -14,7 +14,7 @@ license = "MIT" [package.metadata.docs.rs] -features = ["rt_tokio", "nightly", "sse"] +features = ["rt_tokio", "nightly", "sse", "ws"] [dependencies] diff --git a/ohkami/src/response/content.rs b/ohkami/src/response/content.rs index 8b45e609..c19707f1 100644 --- a/ohkami/src/response/content.rs +++ b/ohkami/src/response/content.rs @@ -3,6 +3,9 @@ use ohkami_lib::CowSlice; #[cfg(feature="sse")] use ohkami_lib::Stream; +#[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] +use crate::websocket::{Config, Handler}; + pub enum Content { None, @@ -13,7 +16,7 @@ pub enum Content { Stream(std::pin::Pin> + Send>>), #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] - WebSocket(crate::websocket::Handler), + WebSocket((Config, Handler)), } const _: () = { impl Default for Content { fn default() -> Self { diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index fb98c553..3620d070 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -237,7 +237,7 @@ impl Response { } #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] - Content::WebSocket(handler) => { + Content::WebSocket((config, handler)) => { let mut buf = Vec::::with_capacity( self.status.line().len() + self.headers.size @@ -250,11 +250,7 @@ impl Response { /* this doesn't match in testing */ if let Some(tcp_stream) = ::downcast_mut::(conn) { - use crate::websocket::{Session, Config}; - - /* FIXME: make Config configurable */ - let ws = Session::new(tcp_stream, Config::default()); - + let ws = crate::websocket::Session::new(tcp_stream, config); handler(ws).await } } @@ -400,8 +396,11 @@ impl Response { #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] impl Response { - pub(crate) fn with_websocket(mut self, handler: crate::websocket::Handler) -> Self { - self.content = Content::WebSocket(handler); + pub(crate) fn with_websocket(mut self, + config: crate::websocket::Config, + handler: crate::websocket::Handler + ) -> Self { + self.content = Content::WebSocket((config, handler)); self } } @@ -430,9 +429,10 @@ const _: () = { })), #[cfg(all(feature="ws", any(feature="rt_tokio",feature="rt_async-std")))] - Content::WebSocket(_) => Content::WebSocket(Box::new({ - |_| Box::pin(async {/* dummy handler */}) - })), + Content::WebSocket(_) => Content::WebSocket(( + crate::websocket::Config::default(), + Box::new(|_| Box::pin(async {/* dummy handler */})) + )), } }; this.complete(); diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index bf272eb3..7b6a7ac7 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -28,6 +28,13 @@ pub struct WebSocketContext<'req> { impl<'ws> WebSocketContext<'ws> { pub fn connect + Send + 'static>(self, handler: impl Fn(Session<'ws, __rt__::TcpStream>) -> Fut + Send + Sync + 'static + ) -> WebSocket { + self.connect_with(Config::default(), handler) + } + + pub fn connect_with + Send + 'static>(self, + config: Config, + handler: impl Fn(Session<'ws, __rt__::TcpStream>) -> Fut + Send + Sync + 'static ) -> WebSocket { #[inline] fn signed(sec_websocket_key: &str) -> String { use ::sha1::{Sha1, Digest}; @@ -38,6 +45,7 @@ pub struct WebSocketContext<'req> { } WebSocket { + config, sec_websocket_key: signed(self.sec_websocket_key), handler: Box::new(move |ws| Box::pin({ let h = handler(unsafe {std::mem::transmute::<_, Session<'ws, _>>(ws)}); @@ -54,6 +62,7 @@ pub(crate) type Handler = Box; pub struct WebSocket { + config: Config, sec_websocket_key: String, handler: Handler, } impl IntoResponse for WebSocket { @@ -62,7 +71,7 @@ pub struct WebSocket { .Connection("Update") .Upgrade("websocket") .SecWebSocketAccept(self.sec_websocket_key) - ).with_websocket(self.handler) + ).with_websocket(self.config, self.handler) } } diff --git a/ohkami/src/websocket/session.rs b/ohkami/src/websocket/session.rs index ac1f9517..8f7133e8 100644 --- a/ohkami/src/websocket/session.rs +++ b/ohkami/src/websocket/session.rs @@ -65,9 +65,11 @@ impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { pub async fn send(&mut self, message: Message) -> Result<(), Error> { send(message, &mut self.conn, &self.config, &mut self.n_buffered).await } + pub async fn write(&mut self, message: Message) -> Result { write(message, &mut self.conn, &self.config, &mut self.n_buffered).await } + pub async fn flush(&mut self) -> Result<(), Error> { flush(&mut self.conn, &mut self.n_buffered).await } From 59e0c25a382cd0627655f0ccc21a4087e7307271 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 15:55:32 +0900 Subject: [PATCH 06/15] fix around lifetimes & add example to README --- README.md | 6 +++--- ohkami/src/response/mod.rs | 2 +- ohkami/src/websocket/mod.rs | 14 ++++++------ ohkami/src/websocket/session.rs | 38 ++++++++++++++++++--------------- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 75d115cd..71a76785 100644 --- a/README.md +++ b/README.md @@ -304,9 +304,9 @@ Currently, WebSocket on `rt_worker` is not supported. use ohkami::prelude::*; use ohkami::websocket::{WebSocketContext, WebSocket, Message}; -async fn echo_text(c: WebSocketContext) -> WebSocket { - c.connect(|ws| async move { - while let Ok(Some(Message::Text(text)) = ws.recv().await) { +async fn echo_text(c: WebSocketContext<'_>) -> WebSocket { + c.connect(|mut ws| async move { + while let Ok(Some(Message::Text(text))) = ws.recv().await { ws.send(Message::Text(text)).await.expect("Failed to send text"); } }) diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index 3620d070..32ab9a1a 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -250,7 +250,7 @@ impl Response { /* this doesn't match in testing */ if let Some(tcp_stream) = ::downcast_mut::(conn) { - let ws = crate::websocket::Session::new(tcp_stream, config); + let ws = unsafe {crate::websocket::Session::new(tcp_stream, config)}; handler(ws).await } } diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index 7b6a7ac7..85cea9ed 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -25,16 +25,16 @@ pub struct WebSocketContext<'req> { } } - impl<'ws> WebSocketContext<'ws> { + impl<'ctx> WebSocketContext<'ctx> { pub fn connect + Send + 'static>(self, - handler: impl Fn(Session<'ws, __rt__::TcpStream>) -> Fut + Send + Sync + 'static + handler: impl Fn(Session<__rt__::TcpStream>) -> Fut + Send + Sync + 'static ) -> WebSocket { self.connect_with(Config::default(), handler) } pub fn connect_with + Send + 'static>(self, config: Config, - handler: impl Fn(Session<'ws, __rt__::TcpStream>) -> Fut + Send + Sync + 'static + handler: impl Fn(Session<__rt__::TcpStream>) -> Fut + Send + Sync + 'static ) -> WebSocket { #[inline] fn signed(sec_websocket_key: &str) -> String { use ::sha1::{Sha1, Digest}; @@ -48,8 +48,8 @@ pub struct WebSocketContext<'req> { config, sec_websocket_key: signed(self.sec_websocket_key), handler: Box::new(move |ws| Box::pin({ - let h = handler(unsafe {std::mem::transmute::<_, Session<'ws, _>>(ws)}); - async {h.await} + let session = handler(ws); + async {session.await} })) } } @@ -57,8 +57,8 @@ pub struct WebSocketContext<'req> { }; pub(crate) type Handler = Box) -> Pin + Send + '_>> - + Send + Sync + 'static + Fn(Session<__rt__::TcpStream>) -> Pin + Send + 'static>> + + Send + Sync >; pub struct WebSocket { diff --git a/ohkami/src/websocket/session.rs b/ohkami/src/websocket/session.rs index 8f7133e8..43d8ddfc 100644 --- a/ohkami/src/websocket/session.rs +++ b/ohkami/src/websocket/session.rs @@ -4,23 +4,23 @@ use crate::__rt__::{AsyncWriter, AsyncReader}; /* Used only in `ohkami::websocket::WebSocket::{new, with}` and NOT `use`able by user */ -pub struct WebSocket<'ws, Conn: AsyncWriter + AsyncReader + Unpin> { - conn: &'ws mut Conn, +pub struct WebSocket { + conn: *mut Conn, config: Config, n_buffered: usize, } -impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { - pub(crate) fn new(conn: &'ws mut Conn, config: Config) -> Self { - Self { conn, config, n_buffered:0 } - } -} - -impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { - pub async fn recv(&mut self) -> Result, Error> { - Message::read_from(self.conn, &self.config).await +const _: () = { + unsafe impl Send for WebSocket {} + unsafe impl Sync for WebSocket {} + + impl WebSocket { + /// SAFETY: `conn` is valid while entire the conversation + pub(crate) unsafe fn new(conn: &mut Conn, config: Config) -> Self { + Self { conn, config, n_buffered:0 } + } } -} +}; // ============================================================================= pub(super) async fn send( @@ -61,16 +61,20 @@ pub(super) async fn flush( } // ============================================================================= -impl<'ws, Conn: AsyncWriter + AsyncReader + Unpin> WebSocket<'ws, Conn> { +impl WebSocket { + pub async fn recv(&mut self) -> Result, Error> { + Message::read_from(unsafe {&mut *self.conn}, &self.config).await + } + pub async fn send(&mut self, message: Message) -> Result<(), Error> { - send(message, &mut self.conn, &self.config, &mut self.n_buffered).await + send(message, unsafe {&mut *self.conn}, &self.config, &mut self.n_buffered).await } pub async fn write(&mut self, message: Message) -> Result { - write(message, &mut self.conn, &self.config, &mut self.n_buffered).await + write(message, unsafe {&mut *self.conn}, &self.config, &mut self.n_buffered).await } - + pub async fn flush(&mut self) -> Result<(), Error> { - flush(&mut self.conn, &mut self.n_buffered).await + flush(unsafe {&mut *self.conn}, &mut self.n_buffered).await } } From 3a8c5a536f9083a6c33eaeaa931d24d959657a05 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 15:57:15 +0900 Subject: [PATCH 07/15] update README example --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 71a76785..76fa7698 100644 --- a/README.md +++ b/README.md @@ -311,6 +311,13 @@ async fn echo_text(c: WebSocketContext<'_>) -> WebSocket { } }) } + +#[tokio::main] +async fn main() { + Ohkami::new(( + "/ws".GET(echo_text), + )).howl("localhost:3030").await +} ```
From 3d47a1de48e2173bb99b108662a03fc666a188a3 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 16:11:44 +0900 Subject: [PATCH 08/15] checked performance difference & delete unused comments --- examples/Cargo.toml | 5 ++--- ohkami/src/session/mod.rs | 25 +++---------------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ce8f8258..c68c1f88 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -14,9 +14,8 @@ members = [ ] [workspace.dependencies] -# To assure "DEBUG" feature be off even if DEBUGing `../ohkami`, -# explicitly set `default-features = false` -ohkami = { path = "../ohkami", default-features = false, features = ["rt_tokio", "testing", "sse"] } +# set `default-features = false` to assure "DEBUG" feature be off even when DEBUGing `../ohkami` +ohkami = { path = "../ohkami", default-features = false, features = ["rt_tokio", "testing", "sse", "ws"] } tokio = { version = "1", features = ["full"] } sqlx = { version = "0.7.3", features = ["runtime-tokio-native-tls", "postgres", "macros", "chrono", "uuid"] } tracing = "0.1" diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 69e0220c..a8aca2e0 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -36,41 +36,22 @@ impl Session { crate::Response::InternalServerError() } - // /* async-std doesn't provide split */ - // #[cfg(feature="rt_tokio")] - // let (mut r, mut w) = self.connection.split(); - // #[cfg(feature="rt_async-std")] - // let c = &mut self.connection; - - // #[cfg(feature="rt_tokio")] - // macro_rules! read {($req:ident) => {$req.as_mut().read(&mut r)};} - // #[cfg(feature="rt_async-std")] - // macro_rules! read {($req:ident) => {$req.as_mut().read(c)};} - - // #[cfg(feature="rt_tokio")] - // macro_rules! send {($res:ident) => {$res.send(&mut w)};} - // #[cfg(feature="rt_async-std")] - // macro_rules! send {($res:ident) => {$res.send(c)};} - - macro_rules! read {($req:ident) => {$req.as_mut().read(&mut self.connection)};} - macro_rules! send {($res:ident) => {$res.send(&mut self.connection)};} - timeout_in(std::time::Duration::from_secs(crate::env::OHKAMI_KEEPALIVE_TIMEOUT()), async { loop { let mut req = Request::init(); let mut req = unsafe {Pin::new_unchecked(&mut req)}; - match read!(req).await { + match req.as_mut().read(&mut self.connection).await { Ok(Some(())) => { let close = matches!(req.headers.Connection(), Some("close" | "Close")); let res = match catch_unwind(AssertUnwindSafe(|| self.router.handle(req.get_mut()))) { Ok(future) => future.await, Err(panic) => panicking(panic), }; - send!(res).await; + res.send(&mut self.connection).await; if close {break} } Ok(None) => break, - Err(res) => send!(res).await, + Err(res) => res.send(&mut self.connection).await, }; } }).await; From dc595bf6d1b6ab2345be6cb160a35bb61ab76ce4 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Fri, 19 Jul 2024 16:15:56 +0900 Subject: [PATCH 09/15] next: examples/websocket --- examples/Cargo.toml | 1 + examples/websocket/Cargo.toml | 11 +++++++++++ examples/websocket/src/main.rs | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 examples/websocket/Cargo.toml create mode 100644 examples/websocket/src/main.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c68c1f88..91db6b1f 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -5,6 +5,7 @@ members = [ "form", "hello", "openai", + "websocket", "realworld", "basic_auth", "quick_start", diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml new file mode 100644 index 00000000..69a8f99a --- /dev/null +++ b/examples/websocket/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "websocket" +version = "0.1.0" +edition = "2021" + +[dependencies] +ohkami = { workspace = true } +tokio = { workspace = true } + +[features] +DEBUG = ["ohkami/DEBUG"] \ No newline at end of file diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs new file mode 100644 index 00000000..4e1770f5 --- /dev/null +++ b/examples/websocket/src/main.rs @@ -0,0 +1,18 @@ +use ohkami::prelude::*; +use ohkami::websocket::{WebSocketContext, WebSocket, Message}; + + +async fn echo_text(c: WebSocketContext<'_>) -> WebSocket { + c.connect(|mut ws| async move { + while let Ok(Some(Message::Text(text))) = ws.recv().await { + ws.send(Message::Text(text)).await.expect("Failed to send text"); + } + }) +} + +#[tokio::main] +async fn main() { + Ohkami::new(( + "/ws".GET(echo_text), + )).howl("localhost:3030").await +} From 7eb81902122a5d280cc22f5b6d5d9654d95d345e Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 20 Jul 2024 14:49:29 +0900 Subject: [PATCH 10/15] next: bugs around frame --- examples/websocket/src/main.rs | 39 ++++++++- examples/websocket/template/index.html | 35 ++++++++ ohkami/Cargo.toml | 22 ++--- ohkami/src/ohkami/build.rs | 2 +- ohkami/src/response/mod.rs | 19 ++++- ohkami/src/session/mod.rs | 9 +++ ohkami/src/websocket/frame.rs | 1 + ohkami/src/websocket/message.rs | 12 +-- ohkami/src/websocket/mod.rs | 108 +++++++------------------ ohkami/src/websocket/session.rs | 21 ++++- 10 files changed, 158 insertions(+), 110 deletions(-) create mode 100644 examples/websocket/template/index.html diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index 4e1770f5..2a2c108e 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -2,17 +2,48 @@ use ohkami::prelude::*; use ohkami::websocket::{WebSocketContext, WebSocket, Message}; +#[derive(Clone)] +struct Logger; +impl FangAction for Logger { + async fn fore<'a>(&'a self, req: &'a mut Request) -> Result<(), Response> { + Ok(println!("\n{req:#?}")) + } + + async fn back<'a>(&'a self, res: &'a mut Response) { + println!("\n{res:#?}") + } +} + async fn echo_text(c: WebSocketContext<'_>) -> WebSocket { c.connect(|mut ws| async move { - while let Ok(Some(Message::Text(text))) = ws.recv().await { - ws.send(Message::Text(text)).await.expect("Failed to send text"); + #[cfg(feature="DEBUG")] { + println!("WebSocket handler is called"); + } + + #[cfg(feature="DEBUG")] { + loop { + let r = dbg!(ws.recv().await); + // println!("alive: {:?}", ws.is_alive()); + // tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let Ok(Some(Message::Text(text))) = r else { + break + }; + println!("recv: {text}"); + ws.send(Message::Text(text)).await.expect("Failed to send text"); + } + } + #[cfg(not(feature="DEBUG"))] { + while let Ok(Some(Message::Text(text))) = ws.recv().await { + ws.send(Message::Text(text)).await.expect("Failed to send text"); + } } }) } #[tokio::main] async fn main() { - Ohkami::new(( - "/ws".GET(echo_text), + Ohkami::with(Logger, ( + "/".Dir("./template").omit_extensions([".html"]), + "/echo".GET(echo_text), )).howl("localhost:3030").await } diff --git a/examples/websocket/template/index.html b/examples/websocket/template/index.html new file mode 100644 index 00000000..dc2ca08c --- /dev/null +++ b/examples/websocket/template/index.html @@ -0,0 +1,35 @@ + + + + + + Ohkami WebSocket Example + + +

Echo Text

+ +
+ + +
+ + + + + \ No newline at end of file diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 0145ff65..a2adf0eb 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -37,7 +37,7 @@ sha1 = { version = "0.10", optional = true, default-features = false } [features] -default = ["testing"] +#default = ["testing"] rt_tokio = ["dep:tokio"] rt_async-std = ["dep:async-std"] @@ -54,13 +54,13 @@ DEBUG = [ "tokio?/rt-multi-thread", "async-std?/attributes", ] -#default = [ -# "nightly", -# "testing", -# "sse", -# "ws", -# "rt_tokio", -# #"rt_async-std", -# #"rt_worker", -# "DEBUG", -#] \ No newline at end of file +default = [ + "nightly", + "testing", + "sse", + "ws", + "rt_tokio", + #"rt_async-std", + #"rt_worker", + "DEBUG", +] \ No newline at end of file diff --git a/ohkami/src/ohkami/build.rs b/ohkami/src/ohkami/build.rs index ac508c7c..93bc739c 100644 --- a/ohkami/src/ohkami/build.rs +++ b/ohkami/src/ohkami/build.rs @@ -312,7 +312,7 @@ trait RoutingItem { Handlers::new(Box::leak({ let base_path = self.route.trim_end_matches('/').to_string(); match &*path.join("/") { - "" => base_path, + "" => if !base_path.is_empty() {base_path} else {"/".into()}, some => base_path + "/" + some, } }.into_boxed_str())).GET(handler) diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index 32ab9a1a..c7e1b5de 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -250,8 +250,25 @@ impl Response { /* this doesn't match in testing */ if let Some(tcp_stream) = ::downcast_mut::(conn) { + #[cfg(feature="DEBUG")] { + println!("Entered websocket session with TcpStream"); + } + let ws = unsafe {crate::websocket::Session::new(tcp_stream, config)}; - handler(ws).await + + #[cfg(feature="DEBUG")] { + if !ws.is_alive() { + println!("websocket is already disconnected before handler is called"); + } + } + + if ws.is_alive() { + handler(ws).await + } + } + + #[cfg(feature="DEBUG")] { + println!("websocket session finished"); } } } diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index a8aca2e0..4b1d1fc2 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -48,6 +48,11 @@ impl Session { Err(panic) => panicking(panic), }; res.send(&mut self.connection).await; + + #[cfg(feature="DEBUG")] { + println!("sended response"); + } + if close {break} } Ok(None) => break, @@ -56,6 +61,10 @@ impl Session { } }).await; + #[cfg(feature="DEBUG")] { + println!("about to shutdown connection"); + } + if let Some(err) = { #[cfg(feature="rt_tokio")] {use crate::__rt__::AsyncWriter; self.connection.shutdown().await diff --git a/ohkami/src/websocket/frame.rs b/ohkami/src/websocket/frame.rs index 7f10cd51..ac272a80 100644 --- a/ohkami/src/websocket/frame.rs +++ b/ohkami/src/websocket/frame.rs @@ -34,6 +34,7 @@ pub enum OpCode { } } +#[derive(Debug)] pub enum CloseCode { Normal, Away, Protocol, Unsupported, Status, Abnormal, Invalid, Policy, Size, Extension, Error, Restart, Again, Tls, Reserved, diff --git a/ohkami/src/websocket/message.rs b/ohkami/src/websocket/message.rs index 1ba43146..283f9f70 100644 --- a/ohkami/src/websocket/message.rs +++ b/ohkami/src/websocket/message.rs @@ -5,6 +5,7 @@ use super::{frame::{Frame, OpCode, CloseCode}, Config}; const PING_PONG_PAYLOAD_LIMIT: usize = 125; +#[derive(Debug)] pub enum Message { Text (String), Binary(Vec), @@ -12,6 +13,7 @@ pub enum Message { Pong (Vec), Close (Option), } +#[derive(Debug)] pub struct CloseFrame { pub code: CloseCode, pub reason: Option>, @@ -73,16 +75,6 @@ impl Message { ) -> Result { self.into_frame().write_unmasked(stream, config).await } -// /// for test -// pub(crate) async fn masking_write(self, -// stream: &mut (impl AsyncWriter + Unpin), -// config: &Config, -// mask: [u8; 4], -// ) -> Result { -// let mut frame = self.into_frame(); -// frame.mask = Some(mask); -// frame.write_masked(stream, config).await -// } } impl Message { diff --git a/ohkami/src/websocket/mod.rs b/ohkami/src/websocket/mod.rs index 85cea9ed..0863feb7 100644 --- a/ohkami/src/websocket/mod.rs +++ b/ohkami/src/websocket/mod.rs @@ -16,12 +16,22 @@ pub struct WebSocketContext<'req> { sec_websocket_key: &'req str, } const _: () = { impl<'req> FromRequest<'req> for WebSocketContext<'req> { - type Error = std::convert::Infallible; + type Error = Response; fn from_request(req: &'req Request) -> Option> { - req.headers.SecWebSocketKey().map(|swk| Ok(Self { - sec_websocket_key: swk, - })) + if !req.headers.Connection()?.contains("Upgrade") { + return Some(Err((|| Response::BadRequest().with_text("upgrade request must have `Connection: Upgrade`"))())) + } + if req.headers.Upgrade()? != "websocket" { + return Some(Err((|| Response::BadRequest().with_text("upgrade request must have `Upgrade: websocket`"))())) + } + if req.headers.SecWebSocketVersion()? != "13" { + return Some(Err((|| Response::BadRequest().with_text("upgrade request must have `Sec-WebSocket-Version: 13`"))())) + } + + req.headers.SecWebSocketKey().map(|sec_websocket_key| + Ok(Self { sec_websocket_key }) + ) } } @@ -36,17 +46,9 @@ pub struct WebSocketContext<'req> { config: Config, handler: impl Fn(Session<__rt__::TcpStream>) -> Fut + Send + Sync + 'static ) -> WebSocket { - #[inline] fn signed(sec_websocket_key: &str) -> String { - use ::sha1::{Sha1, Digest}; - let mut sha1 = ::new(); - sha1.update(sec_websocket_key.as_bytes()); - sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - base64::encode(sha1.finalize()) - } - WebSocket { config, - sec_websocket_key: signed(self.sec_websocket_key), + sec_websocket_key: sign(self.sec_websocket_key), handler: Box::new(move |ws| Box::pin({ let session = handler(ws); async {session.await} @@ -68,7 +70,7 @@ pub struct WebSocket { } impl IntoResponse for WebSocket { fn into_response(self) -> Response { Response::SwitchingProtocols().with_headers(|h|h - .Connection("Update") + .Connection("Upgrade") .Upgrade("websocket") .SecWebSocketAccept(self.sec_websocket_key) ).with_websocket(self.config, self.handler) @@ -98,72 +100,16 @@ pub struct Config { } }; -// impl WebSocket { -// /// shortcut for `WebSocket::with(Config::default())` -// pub fn new + Send>( -// handler: impl Fn(Session<'_, TcpStream>) -> Fut + 'static -// ) -> Self { -// Self::with(Config::default(), handler) -// } -// -// pub fn with + Send>( -// config: Config, -// handler: impl Fn(Session<'_, TcpStream>) -> Fut + 'static -// ) -> Self { -// task::spawn(async move { -// todo!() -// }); -// -// Self { config, handler: } -// } -// } -// -// - -/* -impl WebSocket { - pub fn on_upgrade + Send + 'static>( - self, - handler: impl Fn(WebSocket) -> Fut + Send + Sync + 'static - ) -> Response { - #[inline] fn sign(sec_websocket_key: &str) -> String { - use ::sha1::{Sha1, Digest}; - - let mut sha1 = ::new(); - sha1.update(sec_websocket_key.as_bytes()); - sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - base64::encode(sha1.finalize()) - } - - let Self { - config, - selected_protocol, - sec_websocket_key, - .. - } = self; - - task::spawn({ - async move { - let stream = match self.id { - None => return on_failed_upgrade.handle(UpgradeError::NotRequestedUpgrade), - Some(id) => assume_upgradable(id).await, - }; - - let ws = WebSocket::new(stream, config); - handler(ws).await - } - }); +#[inline] fn sign(sec_websocket_key: &str) -> String { + use ::sha1::{Sha1, Digest}; + let mut sha1 = ::new(); + sha1.update(sec_websocket_key.as_bytes()); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + base64::encode(sha1.finalize()) +} - let mut handshake_res = Response::SwitchingProtocols(); - handshake_res.headers.set() - .Connection("Update") - .Upgrade("websocket") - .SecWebSocketAccept(sign(&sec_websocket_key)); - if let Some(protocol) = selected_protocol { - handshake_res.headers.set() - .SecWebSocketProtocol(protocol.to_string()); - } - handshake_res - } +#[cfg(test)] +#[test] fn test_sign() { + // example in https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#server_handshake_response + assert_eq!(sign("dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); } -*/ diff --git a/ohkami/src/websocket/session.rs b/ohkami/src/websocket/session.rs index 43d8ddfc..bb6b840b 100644 --- a/ohkami/src/websocket/session.rs +++ b/ohkami/src/websocket/session.rs @@ -5,8 +5,8 @@ use crate::__rt__::{AsyncWriter, AsyncReader}; /* Used only in `ohkami::websocket::WebSocket::{new, with}` and NOT `use`able by user */ pub struct WebSocket { - conn: *mut Conn, - config: Config, + conn: *mut Conn, + config: Config, n_buffered: usize, } @@ -17,8 +17,25 @@ const _: () = { impl WebSocket { /// SAFETY: `conn` is valid while entire the conversation pub(crate) unsafe fn new(conn: &mut Conn, config: Config) -> Self { + let conn: *mut Conn = conn; + if conn.is_null() { + panic!("Invalid connection") + } + + #[cfg(feature="DEBUG")] { + println!("`websocket::session::WebSocket::new` finished successfully: conn @ {conn:?}") + } + Self { conn, config, n_buffered:0 } } + + pub fn is_alive(&self) -> bool { + #[cfg(feature="DEBUG")] { + println!("conn @ {:?}", self.conn); + } + + !self.conn.is_null() + } } }; From dd1ae6b513fe9044f0e371fc79afbe7297df8576 Mon Sep 17 00:00:00 2001 From: kana-rus Date: Sat, 20 Jul 2024 15:38:31 +0900 Subject: [PATCH 11/15] seems works --- examples/websocket/src/main.rs | 2 +- examples/websocket/template/index.html | 5 +++-- ohkami/Cargo.toml | 22 +++++++++++----------- ohkami/src/websocket/frame.rs | 17 +++++++++++------ ohkami/src/websocket/message.rs | 2 +- ohkami/src/websocket/session.rs | 6 +----- 6 files changed, 28 insertions(+), 26 deletions(-) diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index 2a2c108e..36aaeda4 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -28,7 +28,7 @@ async fn echo_text(c: WebSocketContext<'_>) -> WebSocket { let Ok(Some(Message::Text(text))) = r else { break }; - println!("recv: {text}"); + println!("recv: `{text}`"); ws.send(Message::Text(text)).await.expect("Failed to send text"); } } diff --git a/examples/websocket/template/index.html b/examples/websocket/template/index.html index dc2ca08c..69880418 100644 --- a/examples/websocket/template/index.html +++ b/examples/websocket/template/index.html @@ -13,9 +13,11 @@

Echo Text

-