Skip to content

Commit

Permalink
feat(core): use custom scheduler to avoid stack overflow
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Chi <[email protected]>
  • Loading branch information
skyzh committed Dec 22, 2024
1 parent d1e27a5 commit 3456640
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 20 deletions.
3 changes: 2 additions & 1 deletion optd-core/src/cascades.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
mod memo;
mod optimizer;
pub mod rule_match;
pub(crate) mod rule_match;
pub(crate) mod scheduler;
mod tasks2;

pub use memo::{Memo, NaiveMemo};
Expand Down
9 changes: 1 addition & 8 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

use std::collections::{BTreeSet, HashMap, HashSet};
use std::fmt::Display;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use anyhow::Result;
Expand Down Expand Up @@ -293,14 +291,9 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
}

pub fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
use pollster::FutureExt as _;
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
let mut task = TaskContext::new(self);
// 32MB stack for the optimization process, TODO: reduce memory footprint
stacker::grow(32 * 1024 * 1024, || {
let fut: Pin<Box<dyn Future<Output = ()>>> = Box::pin(task.fire_optimize(group_id));
fut.block_on();
});
task.fire_optimize(group_id);
Ok(())
}

Expand Down
101 changes: 101 additions & 0 deletions optd-core/src/cascades/scheduler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2023-2024 CMU Database Group
//
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! A single-thread scheduler for the cascades tasks. The tasks are queued in a stack of `Vec` so that
//! we won't overflow the system stack. The cascades task are compute-only and don't have I/O.
use std::{
cell::RefCell,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Wake},
};

struct Task {
// The task to be executed.
inner: Pin<Box<dyn Future<Output = ()> + 'static>>,
}

pub struct Executor {}

impl Wake for Task {
fn wake(self: Arc<Self>) {
unreachable!("cascades tasks shouldn't yield");
}
}

// This needs nightly feature and we use stable Rust, so we had to copy-paste it here. TODO: license

mod optd_futures_task {
use std::{
ptr,
task::{RawWaker, RawWakerVTable, Waker},
};
const NOOP: RawWaker = {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
// Cloning just returns a new no-op raw waker
|_| NOOP,
// `wake` does nothing
|_| {},
// `wake_by_ref` does nothing
|_| {},
// Dropping does nothing as we don't allocate anything
|_| {},
);
RawWaker::new(ptr::null(), &VTABLE)
};

#[inline]
#[must_use]
pub const fn noop() -> &'static Waker {
const WAKER: &Waker = &unsafe { Waker::from_raw(NOOP) };
WAKER
}
}

thread_local! {
pub static OPTD_SCHEDULER_QUEUE: RefCell<Vec<Task>> = RefCell::new(Vec::new());
}

pub fn spawn<F>(task: F)
where
F: Future<Output = ()> + 'static,
{
OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| {
tasks.push(
Task {
inner: Box::pin(task),
}
.into(),
)
});
}

impl Executor {
pub fn new() -> Self {
Executor {}
}

pub fn spawn<F>(&self, task: F)
where
F: Future<Output = ()> + 'static,
{
spawn(task);
}

/// SAFETY: The caller must ensure all futures running on this runtime does not have I/O. Otherwise it will deadloop
/// with all futures pending.
pub fn run(&self) {
let waker = optd_futures_task::noop();
let mut cx: Context<'_> = Context::from_waker(&waker);

while let Some(mut task) = OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.pop()) {
if task.inner.as_mut().poll(&mut cx).is_pending() {
OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.push(task))
}
}
}
}
72 changes: 61 additions & 11 deletions optd-core/src/cascades/tasks2.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
// Copyright (c) 2023-2024 CMU Database Group
//
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

//! The v2 implementation of cascades tasks. The code uses Rust async/await to generate the state machine,
//! so that the logic is much more clear and easier to follow.
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use itertools::Itertools;
use tracing::trace;

use super::memo::MemoPlanNode;
use super::rule_match::match_and_pick_expr;
use super::scheduler::{self, Executor};
use super::{optimizer::RuleId, CascadesOptimizer, ExprId, GroupId, Memo};
use crate::cascades::{
memo::{Winner, WinnerInfo},
Expand All @@ -31,6 +42,12 @@ pub enum TaskDesc {
OptimizeInput(ExprId, GroupId),
}

unsafe fn extend_to_static<'x>(
f: Pin<Box<dyn Future<Output = ()> + 'x>>,
) -> Pin<Box<dyn Future<Output = ()> + 'static>> {
unsafe { std::mem::transmute(f) }
}

impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
pub fn new(optimizer: &'a mut CascadesOptimizer<T, M>) -> Self {
Self {
Expand All @@ -39,24 +56,45 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
}
}

pub async fn fire_optimize(&mut self, group_id: GroupId) {
self.optimize_group(SearchContext {
group_id,
upper_bound: None,
})
.await;
pub fn fire_optimize(&mut self, group_id: GroupId) {
let executor = Executor::new();
executor.spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_group(SearchContext {
group_id,
upper_bound: None,
})) as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
executor.run();
}

async fn optimize_group(&mut self, ctx: SearchContext) {
Box::pin(self.optimize_group_inner(ctx)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_expr(&mut self, ctx: SearchContext, expr_id: ExprId, exploring: bool) {
Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn explore_group(&mut self, ctx: SearchContext) {
Box::pin(self.explore_group_inner(ctx)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.explore_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>).await
}))
});
}

async fn apply_rule(
Expand All @@ -66,11 +104,23 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
expr_id: ExprId,
exploring: bool,
) {
Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_input(&mut self, ctx: SearchContext, expr_id: ExprId) {
Box::pin(self.optimize_input_inner(ctx, expr_id)).await;
scheduler::spawn(unsafe {
extend_to_static(Box::pin(async {
(Box::pin(self.optimize_input_inner(ctx, expr_id))
as Pin<Box<dyn Future<Output = ()>>>)
.await
}))
});
}

async fn optimize_group_inner(&mut self, ctx: SearchContext) {
Expand Down

0 comments on commit 3456640

Please sign in to comment.