-
-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ab389dc
commit 0d23ac1
Showing
30 changed files
with
1,681 additions
and
2,499 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
[package] | ||
name = "scuffle-batching" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
tokio = { version = "1", default-features = false, features = ["time", "sync", "rt"] } | ||
tokio-util = "0.7" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,251 @@ | ||
use std::future::Future; | ||
use std::sync::atomic::AtomicU64; | ||
use std::sync::Arc; | ||
|
||
use tokio::sync::oneshot; | ||
|
||
/// A response to a batch request | ||
pub struct BatchResponse<Resp> { | ||
send: oneshot::Sender<Resp>, | ||
} | ||
|
||
impl<Resp> BatchResponse<Resp> { | ||
/// Create a new batch response | ||
#[must_use] | ||
pub fn new(send: oneshot::Sender<Resp>) -> Self { | ||
Self { send } | ||
} | ||
|
||
/// Send a response back to the requester | ||
#[inline(always)] | ||
pub fn send(self, response: Resp) { | ||
let _ = self.send.send(response); | ||
} | ||
|
||
/// Send a successful response back to the requester | ||
#[inline(always)] | ||
pub fn send_ok<O, E>(self, response: O) | ||
where | ||
Resp: From<Result<O, E>>, | ||
{ | ||
self.send(Ok(response).into()) | ||
} | ||
|
||
/// Send an error response back to the requestor | ||
#[inline(always)] | ||
pub fn send_err<O, E>(self, error: E) | ||
where | ||
Resp: From<Result<O, E>>, | ||
{ | ||
self.send(Err(error).into()) | ||
} | ||
} | ||
|
||
/// A trait for executing batches | ||
pub trait BatchExecutor { | ||
/// The incoming request type | ||
type Request: Send + 'static; | ||
/// The outgoing response type | ||
type Response: Send + Sync + 'static; | ||
|
||
/// Execute a batch of requests | ||
/// You must call `send` on the `BatchResponse` to send the response back to | ||
/// the client | ||
fn execute(&self, requests: Vec<(Self::Request, BatchResponse<Self::Response>)>) -> impl Future<Output = ()> + Send; | ||
} | ||
|
||
/// A builder for a [`Batcher`] | ||
#[derive(Clone, Copy, Debug)] | ||
#[must_use = "builders must be used to create a batcher"] | ||
pub struct BatcherBuilder { | ||
batch_size: usize, | ||
delay: std::time::Duration, | ||
} | ||
|
||
impl Default for BatcherBuilder { | ||
#[must_use] | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl BatcherBuilder { | ||
/// Create a new builder | ||
pub fn new() -> Self { | ||
Self { | ||
batch_size: 100, | ||
delay: std::time::Duration::from_millis(50), | ||
} | ||
} | ||
|
||
/// Set the batch size | ||
pub fn batch_size(mut self, batch_size: usize) -> Self { | ||
self.batch_size = batch_size; | ||
self | ||
} | ||
|
||
/// Set the delay | ||
pub fn delay(mut self, delay: std::time::Duration) -> Self { | ||
self.delay = delay; | ||
self | ||
} | ||
|
||
/// Set the batch size | ||
pub fn with_batch_size(&mut self, batch_size: usize) -> &mut Self { | ||
self.batch_size = batch_size; | ||
self | ||
} | ||
|
||
/// Set the delay | ||
pub fn with_delay(&mut self, delay: std::time::Duration) -> &mut Self { | ||
self.delay = delay; | ||
self | ||
} | ||
|
||
/// Build the batcher | ||
pub fn build<E>(self, executor: E) -> Batcher<E> | ||
where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
Batcher::new(executor, self.batch_size, self.delay) | ||
} | ||
} | ||
|
||
/// A batcher used to batch requests to a [`BatchExecutor`] | ||
#[must_use = "batchers must be used to execute batches"] | ||
pub struct Batcher<E> | ||
where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
_auto_spawn: tokio::task::JoinHandle<()>, | ||
executor: Arc<E>, | ||
notify: Arc<tokio::sync::Notify>, | ||
semaphore: Arc<tokio::sync::Semaphore>, | ||
current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>, | ||
batch_size: usize, | ||
batch_id: AtomicU64, | ||
} | ||
|
||
struct Batch<E> | ||
where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
id: u64, | ||
items: Vec<(E::Request, BatchResponse<E::Response>)>, | ||
_ticket: tokio::sync::OwnedSemaphorePermit, | ||
} | ||
|
||
impl<E> Batcher<E> | ||
where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
/// Create a new batcher | ||
pub fn new(executor: E, batch_size: usize, delay: std::time::Duration) -> Self { | ||
let semaphore = Arc::new(tokio::sync::Semaphore::new(batch_size.min(1))); | ||
let notify = Arc::new(tokio::sync::Notify::new()); | ||
let current_batch = Arc::new(tokio::sync::Mutex::new(None)); | ||
let executor = Arc::new(executor); | ||
|
||
let join_handle = tokio::spawn(batch_loop(executor.clone(), current_batch.clone(), notify.clone(), delay)); | ||
|
||
Self { | ||
executor, | ||
_auto_spawn: join_handle, | ||
notify, | ||
semaphore, | ||
current_batch, | ||
batch_size, | ||
batch_id: AtomicU64::new(0), | ||
} | ||
} | ||
|
||
/// Create a builder for a [`Batcher`] | ||
pub fn builder() -> BatcherBuilder { | ||
BatcherBuilder::new() | ||
} | ||
|
||
/// Execute a single request | ||
pub async fn execute(&self, items: E::Request) -> Option<E::Response> { | ||
self.execute_many(std::iter::once(items)).await.pop()? | ||
} | ||
|
||
/// Execute many requests | ||
pub async fn execute_many(&self, items: impl IntoIterator<Item = E::Request>) -> Vec<Option<E::Response>> { | ||
let mut batch = self.current_batch.lock().await; | ||
|
||
let mut responses = Vec::new(); | ||
|
||
for item in items { | ||
if batch.is_none() { | ||
batch.replace( | ||
Batch::new( | ||
self.batch_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed), | ||
self.semaphore.clone(), | ||
) | ||
.await, | ||
); | ||
} | ||
|
||
let batch_mut = batch.as_mut().unwrap(); | ||
let (tx, rx) = oneshot::channel(); | ||
batch_mut.items.push((item, BatchResponse::new(tx))); | ||
responses.push(rx); | ||
|
||
if batch_mut.items.len() >= self.batch_size { | ||
batch.take().unwrap().spawn(self.executor.clone()).await; | ||
self.notify.notify_one(); | ||
} | ||
} | ||
|
||
let mut results = Vec::new(); | ||
for response in responses { | ||
results.push(response.await.ok()); | ||
} | ||
|
||
results | ||
} | ||
} | ||
|
||
async fn batch_loop<E>( | ||
executor: Arc<E>, | ||
current_batch: Arc<tokio::sync::Mutex<Option<Batch<E>>>>, | ||
notify: Arc<tokio::sync::Notify>, | ||
delay: std::time::Duration, | ||
) where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
let mut pending_id = None; | ||
loop { | ||
tokio::time::timeout(delay, notify.notified()).await.ok(); | ||
|
||
let mut batch = current_batch.lock().await; | ||
let Some(batch_id) = batch.as_ref().map(|b| b.id) else { | ||
pending_id = None; | ||
continue; | ||
}; | ||
|
||
if pending_id != Some(batch_id) || batch.as_ref().unwrap().items.is_empty() { | ||
pending_id = Some(batch_id); | ||
continue; | ||
} | ||
|
||
tokio::spawn(batch.take().unwrap().spawn(executor.clone())); | ||
} | ||
} | ||
|
||
impl<E> Batch<E> | ||
where | ||
E: BatchExecutor + Send + Sync + 'static, | ||
{ | ||
async fn new(id: u64, semaphore: Arc<tokio::sync::Semaphore>) -> Self { | ||
Self { | ||
id, | ||
items: Vec::new(), | ||
_ticket: semaphore.acquire_owned().await.unwrap(), | ||
} | ||
} | ||
|
||
async fn spawn(self, executor: Arc<E>) { | ||
executor.execute(self.items).await; | ||
} | ||
} |
Oops, something went wrong.