Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement statement/execution timeout session variable #4792

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/recordbatch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
tokio.workspace = true

[dev-dependencies]
tokio.workspace = true
9 changes: 9 additions & 0 deletions src/common/recordbatch/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Stream timeout"))]
StreamTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
}

impl ErrorExt for Error {
Expand Down Expand Up @@ -190,6 +197,8 @@ impl ErrorExt for Error {
Error::SchemaConversion { source, .. } | Error::CastVector { source, .. } => {
source.status_code()
}

Error::StreamTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
1 change: 1 addition & 0 deletions src/operator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ workspace = true

[dependencies]
api.workspace = true
async-stream.workspace = true
async-trait = "0.1"
catalog.workspace = true
chrono.workspace = true
Expand Down
10 changes: 10 additions & 0 deletions src/operator/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use snafu::{Location, Snafu};
use table::metadata::TableType;
use tokio::time::error::Elapsed;

#[derive(Snafu)]
#[snafu(visibility(pub))]
Expand Down Expand Up @@ -777,6 +778,14 @@ pub enum Error {
location: Location,
json: String,
},

#[snafu(display("Canceling statement due to statement timeout"))]
StatementTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: Elapsed,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -924,6 +933,7 @@ impl ErrorExt for Error {
Error::BuildRecordBatch { source, .. } => source.status_code(),

Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal,
Error::StatementTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
93 changes: 88 additions & 5 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ mod show;
mod tql;

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use async_stream::stream;
use catalog::kvbackend::KvBackendCatalogManager;
use catalog::CatalogManagerRef;
use client::RecordBatches;
use client::{OutputData, RecordBatches};
use common_error::ext::BoxedError;
use common_meta::cache::TableRouteCacheRef;
use common_meta::cache_invalidator::CacheInvalidatorRef;
Expand All @@ -39,15 +42,19 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef};
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
use common_meta::kv_backend::KvBackendRef;
use common_query::Output;
use common_recordbatch::error::StreamTimeoutSnafu;
use common_recordbatch::RecordBatchStreamWrapper;
use common_telemetry::tracing;
use common_time::range::TimestampRange;
use common_time::Timestamp;
use datafusion_expr::LogicalPlan;
use futures::stream::{Stream, StreamExt};
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use set::set_query_timeout;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
use sql::statements::set_variables::SetVariables;
Expand All @@ -63,8 +70,8 @@ use table::TableRef;
use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu,
UpgradeCatalogManagerRefSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu,
TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
Expand Down Expand Up @@ -338,6 +345,28 @@ impl StatementExecutor {
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
"MAX_EXECUTION_TIME" => match query_ctx.channel() {
Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?,
Channel::Postgres => {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name))
}
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail()
}
},
"STATEMENT_TIMEOUT" => {
MichaelScofield marked this conversation as resolved.
Show resolved Hide resolved
if query_ctx.channel() == Channel::Postgres {
set_query_timeout(set_var.value, query_ctx)?
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
}
_ => {
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
Expand Down Expand Up @@ -387,8 +416,19 @@ impl StatementExecutor {

#[tracing::instrument(skip_all)]
async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
let timeout = derive_timeout(&stmt, &query_ctx);
match timeout {
Some(timeout) => {
let start = tokio::time::Instant::now();
let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx))
.await
.context(StatementTimeoutSnafu)?;
// compute remaining timeout
let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
Ok(attach_timeout(output?, remaining_timeout))
}
None => self.plan_exec_inner(stmt, query_ctx).await,
}
}

async fn get_table(&self, table_ref: &TableReference<'_>) -> Result<TableRef> {
Expand All @@ -405,6 +445,49 @@ impl StatementExecutor {
table_name: table_ref.to_string(),
})
}

async fn plan_exec_inner(
&self,
stmt: QueryStatement,
query_ctx: QueryContextRef,
) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
}
}

fn attach_timeout(output: Output, mut timeout: Duration) -> Output {
match output.data {
OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
OutputData::Stream(mut stream) => {
let schema = stream.schema();
let s = Box::pin(stream! {
let start = tokio::time::Instant::now();
while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? {
yield item;
timeout = timeout.checked_sub(tokio::time::Instant::now() - start).unwrap_or(Duration::ZERO);
}
}) as Pin<Box<dyn Stream<Item = _> + Send>>;
let stream = RecordBatchStreamWrapper {
schema,
stream: s,
output_ordering: None,
metrics: Default::default(),
};
Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
}
}
}

/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
/// For MySQL, it applies only to read-only statements.
fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option<Duration> {
let query_timeout = query_ctx.query_timeout()?;
match (query_ctx.channel(), stmt) {
(Channel::Mysql, QueryStatement::Sql(Statement::Query(_)))
| (Channel::Postgres, QueryStatement::Sql(_)) => Some(query_timeout),
(_, _) => None,
}
}

fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {
Expand Down
107 changes: 107 additions & 0 deletions src/operator/src/statement/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use common_time::Timezone;
use lazy_static::lazy_static;
use regex::Regex;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
Expand All @@ -21,6 +26,15 @@ use sql::statements::set_variables::SetVariables;

use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};

lazy_static! {
// Regex rules:
// The string must start with a number (one or more digits).
// The number must be followed by one of the valid time units (ms, s, min, h, d).
// The string must end immediately after the unit, meaning there can be no extra
// characters or spaces after the valid time specification.
static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}

pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let tz_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timezone find in set variable statement",
Expand Down Expand Up @@ -177,3 +191,96 @@ pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
.set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
Ok(())
}

pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let timeout_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timeout value find in set query timeout statement",
})?;
match timeout_expr {
Expr::Value(Value::Number(timeout, _)) => {
match timeout.parse::<u64>() {
Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
Err(_) => {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail()
}
}
Ok(())
}
// postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
Expr::Value(Value::SingleQuotedString(timeout))
| Expr::Value(Value::DoubleQuotedString(timeout)) => {
if ctx.channel() != Postgres {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail();
}
evenyag marked this conversation as resolved.
Show resolved Hide resolved
let timeout = parse_pg_query_timeout_input(timeout)?;
ctx.set_query_timeout(Duration::from_millis(timeout));
Ok(())
}
expr => NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
expr
),
}
.fail(),
}
}

// support time units in ms, s, min, h, d for postgres protocol.
// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
evenyag marked this conversation as resolved.
Show resolved Hide resolved
match input.parse::<u64>() {
Ok(timeout) => Ok(timeout),
Err(_) => {
if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
let value = captures[1].parse::<u64>().expect("regex failed");
let unit = &captures[2];

match unit {
"ms" => Ok(value),
"s" => Ok(value * 1000),
"min" => Ok(value * 60 * 1000),
"h" => Ok(value * 60 * 60 * 1000),
"d" => Ok(value * 24 * 60 * 60 * 1000),
_ => unreachable!("regex failed"),
}
} else {
NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
input
),
}
.fail()
}
}
}
}

#[cfg(test)]
mod test {
use crate::statement::set::parse_pg_query_timeout_input;

#[test]
fn test_parse_pg_query_timeout_input() {
assert!(parse_pg_query_timeout_input("").is_err());
evenyag marked this conversation as resolved.
Show resolved Hide resolved
assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
assert!(parse_pg_query_timeout_input("3a").is_err());
assert!(parse_pg_query_timeout_input("1.5min").is_err());
assert!(parse_pg_query_timeout_input("ms").is_err());
assert!(parse_pg_query_timeout_input("a").is_err());
assert!(parse_pg_query_timeout_input("-1").is_err());

assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
}
}
19 changes: 18 additions & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use datatypes::vectors::StringVector;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use regex::Regex;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
pub use show_create_table::create_table_stmt;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Ident;
Expand Down Expand Up @@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
"STATEMENT_TIMEOUT" => {
// Add time units to postgres query timeout display.
if query_ctx.channel() == Channel::Postgres {
let mut timeout = query_ctx.query_timeout_as_millis().to_string();
timeout.push_str("ms");
evenyag marked this conversation as resolved.
Show resolved Hide resolved
timeout
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
_ => return UnsupportedVariableSnafu { name: variable }.fail(),
};
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
Expand Down
Loading