diff --git a/optd-core/src/cascades.rs b/optd-core/src/cascades.rs
index e43b48ca..bde7762c 100644
--- a/optd-core/src/cascades.rs
+++ b/optd-core/src/cascades.rs
@@ -7,7 +7,8 @@
 mod memo;
 mod optimizer;
-pub mod rule_match;
+pub(crate) mod rule_match;
+pub(crate) mod scheduler;
 mod tasks2;
 
 pub use memo::{Memo, NaiveMemo};
 
diff --git a/optd-core/src/cascades/optimizer.rs b/optd-core/src/cascades/optimizer.rs
index ff0b8e8a..bd0f01d6 100644
--- a/optd-core/src/cascades/optimizer.rs
+++ b/optd-core/src/cascades/optimizer.rs
@@ -5,8 +5,6 @@
 use std::collections::{BTreeSet, HashMap, HashSet};
 use std::fmt::Display;
-use std::future::Future;
-use std::pin::Pin;
 use std::sync::Arc;
 
 use anyhow::Result;
@@ -293,14 +291,9 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
     }
 
     pub fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
-        use pollster::FutureExt as _;
         trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
         let mut task = TaskContext::new(self);
-        // 32MB stack for the optimization process, TODO: reduce memory footprint
-        stacker::grow(32 * 1024 * 1024, || {
-            let fut: Pin<Box<dyn Future<Output = ()>>> = Box::pin(task.fire_optimize(group_id));
-            fut.block_on();
-        });
+        task.fire_optimize(group_id);
         Ok(())
     }
 
diff --git a/optd-core/src/cascades/scheduler.rs b/optd-core/src/cascades/scheduler.rs
new file mode 100644
index 00000000..ea0b1cf4
--- /dev/null
+++ b/optd-core/src/cascades/scheduler.rs
@@ -0,0 +1,101 @@
+// Copyright (c) 2023-2024 CMU Database Group
+//
+// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
+// https://opensource.org/licenses/MIT.
+
+//! A single-thread scheduler for the cascades tasks. The tasks are queued in a stack of `Vec` so that
+//! we won't overflow the system stack. The cascades tasks are compute-only and don't have I/O.
+
+use std::{
+    cell::RefCell,
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Wake},
+};
+
+struct Task {
+    // The task to be executed.
+    inner: Pin<Box<dyn Future<Output = ()> + 'static>>,
+}
+
+pub struct Executor {}
+
+impl Wake for Task {
+    fn wake(self: Arc<Self>) {
+        unreachable!("cascades tasks shouldn't yield");
+    }
+}
+
+// This needs a nightly feature and we use stable Rust, so we had to copy-paste it here. TODO: license
+
+mod optd_futures_task {
+    use std::{
+        ptr,
+        task::{RawWaker, RawWakerVTable, Waker},
+    };
+
+    const NOOP: RawWaker = {
+        const VTABLE: RawWakerVTable = RawWakerVTable::new(
+            // Cloning just returns a new no-op raw waker
+            |_| NOOP,
+            // `wake` does nothing
+            |_| {},
+            // `wake_by_ref` does nothing
+            |_| {},
+            // Dropping does nothing as we don't allocate anything
+            |_| {},
+        );
+        RawWaker::new(ptr::null(), &VTABLE)
+    };
+
+    #[inline]
+    #[must_use]
+    pub const fn noop() -> &'static Waker {
+        const WAKER: &Waker = &unsafe { Waker::from_raw(NOOP) };
+        WAKER
+    }
+}
+
+thread_local! {
+    pub static OPTD_SCHEDULER_QUEUE: RefCell<Vec<Task>> = RefCell::new(Vec::new());
+}
+
+pub fn spawn<F>(task: F)
+where
+    F: Future<Output = ()> + 'static,
+{
+    OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| {
+        tasks.push(
+            Task {
+                inner: Box::pin(task),
+            }
+            .into(),
+        )
+    });
+}
+
+impl Executor {
+    pub fn new() -> Self {
+        Executor {}
+    }
+
+    pub fn spawn<F>(&self, task: F)
+    where
+        F: Future<Output = ()> + 'static,
+    {
+        spawn(task);
+    }
+
+    /// SAFETY: The caller must ensure all futures running on this runtime do not perform I/O. Otherwise it will
+    /// loop forever with all futures pending.
+    pub fn run(&self) {
+        let waker = optd_futures_task::noop();
+        let mut cx: Context<'_> = Context::from_waker(&waker);
+
+        while let Some(mut task) = OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.pop()) {
+            if task.inner.as_mut().poll(&mut cx).is_pending() {
+                OPTD_SCHEDULER_QUEUE.with_borrow_mut(|tasks| tasks.push(task))
+            }
+        }
+    }
+}
diff --git a/optd-core/src/cascades/tasks2.rs b/optd-core/src/cascades/tasks2.rs
index a1961804..41cf2e7f 100644
--- a/optd-core/src/cascades/tasks2.rs
+++ b/optd-core/src/cascades/tasks2.rs
@@ -1,3 +1,13 @@
+// Copyright (c) 2023-2024 CMU Database Group
+//
+// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file or at
+// https://opensource.org/licenses/MIT.
+
+//! The v2 implementation of cascades tasks. The code uses Rust async/await to generate the state machine,
+//! so that the logic is much clearer and easier to follow.
+
+use std::future::Future;
+use std::pin::Pin;
 use std::sync::Arc;
 
 use itertools::Itertools;
@@ -5,6 +15,7 @@ use tracing::trace;
 
 use super::memo::MemoPlanNode;
 use super::rule_match::match_and_pick_expr;
+use super::scheduler::{self, Executor};
 use super::{optimizer::RuleId, CascadesOptimizer, ExprId, GroupId, Memo};
 use crate::cascades::{
     memo::{Winner, WinnerInfo},
@@ -31,6 +42,12 @@ pub enum TaskDesc {
     OptimizeInput(ExprId, GroupId),
 }
 
+unsafe fn extend_to_static<'x>(
+    f: Pin<Box<dyn Future<Output = ()> + 'x>>,
+) -> Pin<Box<dyn Future<Output = ()> + 'static>> {
+    unsafe { std::mem::transmute(f) }
+}
+
 impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
     pub fn new(optimizer: &'a mut CascadesOptimizer<T, M>) -> Self {
         Self {
@@ -39,24 +56,45 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
         }
     }
 
-    pub async fn fire_optimize(&mut self, group_id: GroupId) {
-        self.optimize_group(SearchContext {
-            group_id,
-            upper_bound: None,
-        })
-        .await;
+    pub fn fire_optimize(&mut self, group_id: GroupId) {
+        let executor = Executor::new();
+        executor.spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.optimize_group(SearchContext {
+                    group_id,
+                    upper_bound: None,
+                })) as Pin<Box<dyn Future<Output = ()>>>)
+                .await
+            }))
+        });
+        executor.run();
     }
 
     async fn optimize_group(&mut self, ctx: SearchContext) {
-        Box::pin(self.optimize_group_inner(ctx)).await;
+        scheduler::spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.optimize_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>)
+                    .await
+            }))
+        });
     }
 
     async fn optimize_expr(&mut self, ctx: SearchContext, expr_id: ExprId, exploring: bool) {
-        Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring)).await;
+        scheduler::spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.optimize_expr_inner(ctx, expr_id, exploring))
+                    as Pin<Box<dyn Future<Output = ()>>>)
+                .await
+            }))
+        });
     }
 
     async fn explore_group(&mut self, ctx: SearchContext) {
-        Box::pin(self.explore_group_inner(ctx)).await;
+        scheduler::spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.explore_group_inner(ctx)) as Pin<Box<dyn Future<Output = ()>>>).await
+            }))
+        });
     }
 
     async fn apply_rule(
@@ -66,11 +104,23 @@ impl<'a, T: NodeType, M: Memo<T>> TaskContext<'a, T, M> {
         expr_id: ExprId,
         exploring: bool,
     ) {
-        Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring)).await;
+        scheduler::spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.apply_rule_inner(ctx, rule_id, expr_id, exploring))
+                    as Pin<Box<dyn Future<Output = ()>>>)
+                .await
+            }))
+        });
     }
 
     async fn optimize_input(&mut self, ctx: SearchContext, expr_id: ExprId) {
-        Box::pin(self.optimize_input_inner(ctx, expr_id)).await;
+        scheduler::spawn(unsafe {
+            extend_to_static(Box::pin(async {
+                (Box::pin(self.optimize_input_inner(ctx, expr_id))
+                    as Pin<Box<dyn Future<Output = ()>>>)
+                .await
+            }))
+        });
     }
 
     async fn optimize_group_inner(&mut self, ctx: SearchContext) {