Skip to content

Commit

Permalink
refactor(sse): DataStream: more intuitive colocation and interface (#…
Browse files Browse the repository at this point in the history
…298)

* refactor(sse): colocation, interface

* eliminate double-boxing
  • Loading branch information
kanarus authored Dec 14, 2024
1 parent 7e365ff commit b93fc60
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 373 deletions.
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,25 @@ Use some reverse proxy to do with HTTP/2,3.

```rust,no_run
use ohkami::prelude::*;
use ohkami::typed::DataStream;
use ohkami::util::stream;
use {tokio::time::sleep, std::time::Duration};
use ohkami::sse::DataStream;
use tokio::time::{sleep, Duration};
async fn sse() -> DataStream<String> {
DataStream::from_stream(stream::queue(|mut q| async move {
async fn handler() -> DataStream {
DataStream::new(|mut s| async move {
s.send("starting streaming...");
for i in 1..=5 {
sleep(Duration::from_secs(1)).await;
q.add(format!("Hi, I'm message #{i} !"))
s.send(format!("MESSAGE #{i}"));
}
}))
s.send("streaming finished!");
})
}
#[tokio::main]
async fn main() {
Ohkami::new((
"/sse".GET(sse),
)).howl("localhost:5050").await
"/sse".GET(handler),
)).howl("localhost:3020").await
}
```

Expand Down
12 changes: 6 additions & 6 deletions examples/openai/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use models::{ChatMessage, ChatCompletions, Role};

use ohkami::prelude::*;
use ohkami::format::Text;
use ohkami::typed::DataStream;
use ohkami::util::{StreamExt, stream};
use ohkami::sse::DataStream;
use ohkami::util::StreamExt;


#[tokio::main]
Expand All @@ -30,7 +30,7 @@ async fn main() {
pub async fn relay_chat_completion(
Memory(api_key): Memory<'_, &'static str>,
Text(message): Text<String>,
) -> Result<DataStream<String, Error>, Error> {
) -> Result<DataStream, Error> {
let mut gpt_response = reqwest::Client::new()
.post("https://api.openai.com/v1/chat/completions")
.bearer_auth(api_key)
Expand All @@ -47,7 +47,7 @@ pub async fn relay_chat_completion(
.send().await?
.bytes_stream();

Ok(DataStream::from_stream(stream::queue(|mut q| async move {
Ok(DataStream::new(|mut s| async move {
let mut push_line = |mut line: String| {
#[cfg(debug_assertions)] {
assert!(line.ends_with("\n\n"))
Expand All @@ -67,7 +67,7 @@ pub async fn relay_chat_completion(
}
}

q.push(Ok(line));
s.send(line);
};

let mut remaining = String::new();
Expand All @@ -89,5 +89,5 @@ pub async fn relay_chat_completion(
}
}
}
})))
}))
}
27 changes: 0 additions & 27 deletions examples/sse/src/bin/from_iter_async.rs

This file was deleted.

17 changes: 0 additions & 17 deletions examples/sse/src/bin/queue_stream.rs

This file was deleted.

27 changes: 14 additions & 13 deletions examples/sse/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use ohkami::prelude::*;
use ohkami::typed::DataStream;
use ohkami::sse::DataStream;
use tokio::time::{sleep, Duration};

async fn handler() -> DataStream {
DataStream::new(|mut s| async move {
s.send("starting streaming...");
for i in 1..=5 {
sleep(Duration::from_secs(1)).await;
s.send(format!("MESSAGE #{i}"));
}
s.send("streaming finished!");
})
}

#[tokio::main]
async fn main() {
Ohkami::new((
"/sse".GET(sse),
)).howl("localhost:5050").await
}

async fn sse() -> DataStream<String> {
DataStream::from_iter_async((1..=5).map(|i| async move {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;

Result::<_, std::convert::Infallible>::Ok(format!(
"I'm message #{i} !"
))
}))
"/sse".GET(handler),
)).howl("localhost:3020").await
}
23 changes: 3 additions & 20 deletions ohkami/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,6 @@ mod __rt__ {
#[cfg(feature="rt_glommio")]
pub(crate) use futures_util::AsyncWriteExt as AsyncWrite;

// #[cfg(feature="rt_tokio")]
// pub(crate) use tokio::select;
// #[cfg(feature="rt_async-std")]
// pub(crate) use futures_util::select;
// #[cfg(feature="rt_smol")]
// pub(crate) use futures_util::select;
// #[cfg(feature="rt_nio")]
// pub(crate) use tokio::select;
// #[cfg(feature="rt_glommio")]
// pub(crate) use futures_util::select;
//
// #[cfg(any(feature="rt_tokio", feature="rt_nio"))]
// pub(crate) const fn selectable<F: std::future::Future>(future: F) -> F {
// future
// }
// #[cfg(any(feature="rt_async-std", feature="rt_smol", feature="rt_glommio"))]
// pub(crate) fn selectable<F: std::future::Future>(future: F) -> ::futures_util::future::Fuse<F> {
// ::futures_util::FutureExt::fuse(future)
// }

#[cfg(any(feature="rt_tokio", feature="rt_async-std", feature="rt_smol", feature="rt_nio"))]
mod task {
pub trait Task: std::future::Future<Output: Send + 'static> + Send + 'static {}
Expand Down Expand Up @@ -219,6 +199,9 @@ pub mod header;

pub mod typed;

#[cfg(feature="sse")]
pub mod sse;

#[cfg(feature="ws")]
pub mod ws;

Expand Down
8 changes: 2 additions & 6 deletions ohkami/src/response/_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ async fn test_stream_response() {

let res = Response::OK()
.with_stream(
repeat_by(3, |i| Result::<_, std::convert::Infallible>::Ok(
format!("This is message#{i} !")
))
repeat_by(3, |i| format!("This is message#{i} !"))
)
.with_headers(|h| h
.Server("ohkami")
Expand Down Expand Up @@ -170,9 +168,7 @@ async fn test_stream_response() {

let res = Response::OK()
.with_stream(
repeat_by(3, |i| Result::<_, std::convert::Infallible>::Ok(
format!("This is message#{i}\nです")
))
repeat_by(3, |i| format!("This is message#{i}\nです"))
)
.with_headers(|h| h
.Server("ohkami")
Expand Down
7 changes: 5 additions & 2 deletions ohkami/src/response/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub enum Content {
Payload(CowSlice),

#[cfg(feature="sse")]
Stream(std::pin::Pin<Box<dyn Stream<Item = Result<String, String>> + Send>>),
Stream(std::pin::Pin<Box<dyn Stream<Item = String> + Send>>),

#[cfg(all(feature="ws", feature="__rt__"))]
WebSocket(Session),
Expand Down Expand Up @@ -88,7 +88,10 @@ impl Content {
Self::Payload(bytes) => ::worker::Response::from_bytes(bytes.into()),

#[cfg(feature="sse")]
Self::Stream(stream) => ::worker::Response::from_stream(stream),
Self::Stream(stream) => ::worker::Response::from_stream({
use {ohkami_lib::StreamExt, std::convert::Infallible};
stream.map(Result::<_, Infallible>::Ok)
}),

#[cfg(feature="ws")]
Self::WebSocket(ws) => ::worker::Response::from_websocket(ws)
Expand Down
82 changes: 37 additions & 45 deletions ohkami/src/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use ohkami_lib::{CowSlice, Slice};
#[cfg(feature="__rt_native__")]
use crate::__rt__::AsyncWrite;
#[cfg(feature="sse")]
use crate::util::StreamExt;
use crate::{sse, util::{Stream, StreamExt}};


/// # HTTP Response
Expand Down Expand Up @@ -266,27 +266,27 @@ impl Response {
#[cfg(feature="sse")]
impl Response {
#[inline]
pub fn with_stream<
T: Into<String>,
E: std::error::Error,
>(mut self,
stream: impl ohkami_lib::Stream<Item = Result<T, E>> + Unpin + Send + 'static
pub fn with_stream<T: sse::Data>(
mut self,
stream: impl Stream<Item = T> + Unpin + Send + 'static
) -> Self {
self.set_stream(stream);
self
}

#[inline]
pub fn set_stream<
T: Into<String>,
E: std::error::Error,
>(&mut self, stream: impl ohkami_lib::Stream<Item = Result<T, E>> + Unpin + Send + 'static) {
let stream = Box::pin(stream.map(|res|
res
.map(Into::into)
.map_err(|e| e.to_string())
));
pub fn set_stream<T: sse::Data>(
&mut self,
stream: impl Stream<Item = T> + Unpin + Send + 'static
) {
self.set_stream_raw(Box::pin(stream.map(sse::Data::encode)));
}

#[inline]
pub fn set_stream_raw(
&mut self,
stream: std::pin::Pin<Box<dyn Stream<Item = String> + Send>>
) {
self.headers.set()
.ContentType("text/event-stream")
.CacheControl("no-cache, must-revalidate")
Expand Down Expand Up @@ -375,37 +375,29 @@ impl Response {
conn.flush().await.expect("Failed to flush connection");

while let Some(chunk) = stream.next().await {
match chunk {
Err(msg) => {
crate::warning!("Error in stream: {msg}");
break
}
Ok(chunk) => {
let mut message = Vec::with_capacity(
/* capacity for a single line */
"data: ".len() + chunk.len() + "\n\n".len()
);
for line in chunk.split('\n') {
message.extend_from_slice(b"data: ");
message.extend_from_slice(line.as_bytes());
message.push(b'\n');
}
message.push(b'\n');

let size_hex_bytes = ohkami_lib::num::hexized_bytes(message.len());

let mut chunk = Vec::from(&size_hex_bytes[size_hex_bytes.iter().position(|b| *b!=b'0').unwrap()..]);
chunk.extend_from_slice(b"\r\n");
chunk.append(&mut message);
chunk.extend_from_slice(b"\r\n");

#[cfg(feature="DEBUG")]
println!("\n[sending chunk]\n{}", chunk.escape_ascii());

conn.write_all(&chunk).await.expect("Failed to send response");
conn.flush().await.expect("Failed to flush connection");
}
let mut message = Vec::with_capacity(
/* capacity for a single line */
"data: ".len() + chunk.len() + "\n\n".len()
);
for line in chunk.split('\n') {
message.extend_from_slice(b"data: ");
message.extend_from_slice(line.as_bytes());
message.push(b'\n');
}
message.push(b'\n');

let size_hex_bytes = ohkami_lib::num::hexized_bytes(message.len());

let mut chunk = Vec::from(&size_hex_bytes[size_hex_bytes.iter().position(|b| *b!=b'0').unwrap()..]);
chunk.extend_from_slice(b"\r\n");
chunk.append(&mut message);
chunk.extend_from_slice(b"\r\n");

#[cfg(feature="DEBUG")]
println!("\n[sending chunk]\n{}", chunk.escape_ascii());

conn.write_all(&chunk).await.expect("Failed to send response");
conn.flush().await.expect("Failed to flush connection");
}
conn.write_all(b"0\r\n\r\n").await.expect("Failed to send response");
conn.flush().await.expect("Failed to flush connection");
Expand Down
Loading

0 comments on commit b93fc60

Please sign in to comment.