diff --git a/Cargo.lock b/Cargo.lock index 3f60831975a7..7e9dd2720c67 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7714,6 +7714,7 @@ name = "operator" version = "0.9.5" dependencies = [ "api", + "async-stream", "async-trait", "catalog", "chrono", diff --git a/src/common/recordbatch/Cargo.toml b/src/common/recordbatch/Cargo.toml index d82b445c8ee9..5c3d9fa55058 100644 --- a/src/common/recordbatch/Cargo.toml +++ b/src/common/recordbatch/Cargo.toml @@ -20,6 +20,7 @@ pin-project.workspace = true serde.workspace = true serde_json.workspace = true snafu.workspace = true +tokio.workspace = true [dev-dependencies] tokio.workspace = true diff --git a/src/common/recordbatch/src/error.rs b/src/common/recordbatch/src/error.rs index 3eb90b05e765..6e038d1b7e70 100644 --- a/src/common/recordbatch/src/error.rs +++ b/src/common/recordbatch/src/error.rs @@ -161,6 +161,13 @@ pub enum Error { #[snafu(implicit)] location: Location, }, + #[snafu(display("Stream timeout"))] + StreamTimeout { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: tokio::time::error::Elapsed, + }, } impl ErrorExt for Error { @@ -190,6 +197,8 @@ impl ErrorExt for Error { Error::SchemaConversion { source, .. } | Error::CastVector { source, .. } => { source.status_code() } + + Error::StreamTimeout { .. } => StatusCode::Cancelled, } } diff --git a/src/operator/Cargo.toml b/src/operator/Cargo.toml index 5d3d18f8aaa5..d20034155f1b 100644 --- a/src/operator/Cargo.toml +++ b/src/operator/Cargo.toml @@ -12,6 +12,7 @@ workspace = true [dependencies] api.workspace = true +async-stream.workspace = true async-trait = "0.1" catalog.workspace = true chrono.workspace = true diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index 15b4e4e15bee..48bc7a81c221 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -23,6 +23,7 @@ use datafusion::parquet; use datatypes::arrow::error::ArrowError; use snafu::{Location, Snafu}; use table::metadata::TableType; +use tokio::time::error::Elapsed; #[derive(Snafu)] #[snafu(visibility(pub))] @@ -777,6 +778,14 @@ pub enum Error { location: Location, json: String, }, + + #[snafu(display("Canceling statement due to statement timeout"))] + StatementTimeout { + #[snafu(implicit)] + location: Location, + #[snafu(source)] + error: Elapsed, + }, } pub type Result = std::result::Result; @@ -924,6 +933,7 @@ impl ErrorExt for Error { Error::BuildRecordBatch { source, .. } => source.status_code(), Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal, + Error::StatementTimeout { .. } => StatusCode::Cancelled, } } diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 53b1eaf6ea5b..271e1b75e0fb 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -24,11 +24,14 @@ mod show; mod tql; use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; +use async_stream::stream; use catalog::kvbackend::KvBackendCatalogManager; use catalog::CatalogManagerRef; -use client::RecordBatches; +use client::{OutputData, RecordBatches}; use common_error::ext::BoxedError; use common_meta::cache::TableRouteCacheRef; use common_meta::cache_invalidator::CacheInvalidatorRef; @@ -39,15 +42,19 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef}; use common_meta::key::{TableMetadataManager, TableMetadataManagerRef}; use common_meta::kv_backend::KvBackendRef; use common_query::Output; +use common_recordbatch::error::StreamTimeoutSnafu; +use common_recordbatch::RecordBatchStreamWrapper; use common_telemetry::tracing; use common_time::range::TimestampRange; use common_time::Timestamp; use datafusion_expr::LogicalPlan; +use futures::stream::{Stream, StreamExt}; use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef}; use query::parser::QueryStatement; use query::QueryEngineRef; use session::context::{Channel, QueryContextRef}; use session::table_name::table_idents_to_full_name; +use set::set_query_timeout; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument}; use sql::statements::set_variables::SetVariables; @@ -63,8 +70,8 @@ use table::TableRef; use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding}; use crate::error::{ self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu, - PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu, - UpgradeCatalogManagerRefSnafu, + PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu, + TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu, }; use crate::insert::InserterRef; use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY}; @@ -338,6 +345,28 @@ impl StatementExecutor { "DATESTYLE" => set_datestyle(set_var.value, query_ctx)?, "CLIENT_ENCODING" => validate_client_encoding(set_var)?, + "MAX_EXECUTION_TIME" => match query_ctx.channel() { + Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?, + Channel::Postgres => { + query_ctx.set_warning(format!("Unsupported set variable {}", var_name)) + } + _ => { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail() + } + }, + "STATEMENT_TIMEOUT" => { + if query_ctx.channel() == Channel::Postgres { + set_query_timeout(set_var.value, query_ctx)? + } else { + return NotSupportedSnafu { + feat: format!("Unsupported set variable {}", var_name), + } + .fail(); + } + } _ => { // for postgres, we give unknown SET statements a warning with // success, this is prevent the SET call becoming a blocker @@ -387,8 +416,19 @@ impl StatementExecutor { #[tracing::instrument(skip_all)] async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result { - let plan = self.plan(&stmt, query_ctx.clone()).await?; - self.exec_plan(plan, query_ctx).await + let timeout = derive_timeout(&stmt, &query_ctx); + match timeout { + Some(timeout) => { + let start = tokio::time::Instant::now(); + let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx)) + .await + .context(StatementTimeoutSnafu)?; + // compute remaining timeout + let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default(); + Ok(attach_timeout(output?, remaining_timeout)) + } + None => self.plan_exec_inner(stmt, query_ctx).await, + } } async fn get_table(&self, table_ref: &TableReference<'_>) -> Result { @@ -405,6 +445,49 @@ impl StatementExecutor { table_name: table_ref.to_string(), }) } + + async fn plan_exec_inner( + &self, + stmt: QueryStatement, + query_ctx: QueryContextRef, + ) -> Result { + let plan = self.plan(&stmt, query_ctx.clone()).await?; + self.exec_plan(plan, query_ctx).await + } +} + +fn attach_timeout(output: Output, mut timeout: Duration) -> Output { + match output.data { + OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output, + OutputData::Stream(mut stream) => { + let schema = stream.schema(); + let s = Box::pin(stream! { + let start = tokio::time::Instant::now(); + while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? { + yield item; + timeout = timeout.checked_sub(tokio::time::Instant::now() - start).unwrap_or(Duration::ZERO); + } + }) as Pin + Send>>; + let stream = RecordBatchStreamWrapper { + schema, + stream: s, + output_ordering: None, + metrics: Default::default(), + }; + Output::new(OutputData::Stream(Box::pin(stream)), output.meta) + } + } +} + +/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements. +/// For MySQL, it applies only to read-only statements. +fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option { + let query_timeout = query_ctx.query_timeout()?; + match (query_ctx.channel(), stmt) { + (Channel::Mysql, QueryStatement::Sql(Statement::Query(_))) + | (Channel::Postgres, QueryStatement::Sql(_)) => Some(query_timeout), + (_, _) => None, + } } fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result { diff --git a/src/operator/src/statement/set.rs b/src/operator/src/statement/set.rs index 6436f136d9c5..7b26b7f794d2 100644 --- a/src/operator/src/statement/set.rs +++ b/src/operator/src/statement/set.rs @@ -12,7 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::time::Duration; + use common_time::Timezone; +use lazy_static::lazy_static; +use regex::Regex; +use session::context::Channel::Postgres; use session::context::QueryContextRef; use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; use snafu::{ensure, OptionExt, ResultExt}; @@ -21,6 +26,15 @@ use sql::statements::set_variables::SetVariables; use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result}; +lazy_static! { + // Regex rules: + // The string must start with a number (one or more digits). + // The number must be followed by one of the valid time units (ms, s, min, h, d). + // The string must end immediately after the unit, meaning there can be no extra + // characters or spaces after the valid time specification. + static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap(); +} + pub fn set_timezone(exprs: Vec, ctx: QueryContextRef) -> Result<()> { let tz_expr = exprs.first().context(NotSupportedSnafu { feat: "No timezone find in set variable statement", @@ -177,3 +191,96 @@ pub fn set_datestyle(exprs: Vec, ctx: QueryContextRef) -> Result<()> { .set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order)); Ok(()) } + +pub fn set_query_timeout(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let timeout_expr = exprs.first().context(NotSupportedSnafu { + feat: "No timeout value find in set query timeout statement", + })?; + match timeout_expr { + Expr::Value(Value::Number(timeout, _)) => { + match timeout.parse::() { + Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)), + Err(_) => { + return NotSupportedSnafu { + feat: format!("Invalid timeout expr {} in set variable statement", timeout), + } + .fail() + } + } + Ok(()) + } + // postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms'; + Expr::Value(Value::SingleQuotedString(timeout)) + | Expr::Value(Value::DoubleQuotedString(timeout)) => { + if ctx.channel() != Postgres { + return NotSupportedSnafu { + feat: format!("Invalid timeout expr {} in set variable statement", timeout), + } + .fail(); + } + let timeout = parse_pg_query_timeout_input(timeout)?; + ctx.set_query_timeout(Duration::from_millis(timeout)); + Ok(()) + } + expr => NotSupportedSnafu { + feat: format!( + "Unsupported timeout expr {} in set variable statement", + expr + ), + } + .fail(), + } +} + +// support time units in ms, s, min, h, d for postgres protocol. +// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days). +fn parse_pg_query_timeout_input(input: &str) -> Result { + match input.parse::() { + Ok(timeout) => Ok(timeout), + Err(_) => { + if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) { + let value = captures[1].parse::().expect("regex failed"); + let unit = &captures[2]; + + match unit { + "ms" => Ok(value), + "s" => Ok(value * 1000), + "min" => Ok(value * 60 * 1000), + "h" => Ok(value * 60 * 60 * 1000), + "d" => Ok(value * 24 * 60 * 60 * 1000), + _ => unreachable!("regex failed"), + } + } else { + NotSupportedSnafu { + feat: format!( + "Unsupported timeout expr {} in set variable statement", + input + ), + } + .fail() + } + } + } +} + +#[cfg(test)] +mod test { + use crate::statement::set::parse_pg_query_timeout_input; + + #[test] + fn test_parse_pg_query_timeout_input() { + assert!(parse_pg_query_timeout_input("").is_err()); + assert!(parse_pg_query_timeout_input(" 50 ms").is_err()); + assert!(parse_pg_query_timeout_input("5s 1ms").is_err()); + assert!(parse_pg_query_timeout_input("3a").is_err()); + assert!(parse_pg_query_timeout_input("1.5min").is_err()); + assert!(parse_pg_query_timeout_input("ms").is_err()); + assert!(parse_pg_query_timeout_input("a").is_err()); + assert!(parse_pg_query_timeout_input("-1").is_err()); + + assert_eq!(50, parse_pg_query_timeout_input("50").unwrap()); + assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap()); + assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap()); + assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap()); + } +} diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index 5679cd5dc43d..172961d50a1f 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -48,7 +48,7 @@ use datatypes::vectors::StringVector; use object_store::ObjectStore; use once_cell::sync::Lazy; use regex::Regex; -use session::context::QueryContextRef; +use session::context::{Channel, QueryContextRef}; pub use show_create_table::create_table_stmt; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::Ident; @@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result< let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); format!("{}, {}", style, order) } + "MAX_EXECUTION_TIME" => { + if query_ctx.channel() == Channel::Mysql { + query_ctx.query_timeout_as_millis().to_string() + } else { + return UnsupportedVariableSnafu { name: variable }.fail(); + } + } + "STATEMENT_TIMEOUT" => { + // Add time units to postgres query timeout display. + if query_ctx.channel() == Channel::Postgres { + let mut timeout = query_ctx.query_timeout_as_millis().to_string(); + timeout.push_str("ms"); + timeout + } else { + return UnsupportedVariableSnafu { name: variable }.fail(); + } + } _ => return UnsupportedVariableSnafu { name: variable }.fail(), }; let schema = Arc::new(Schema::new(vec![ColumnSchema::new( diff --git a/src/session/src/context.rs b/src/session/src/context.rs index f85a8ceea313..cab351176b21 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; +use std::time::Duration; use api::v1::region::RegionRequestHeader; use arc_swap::ArcSwap; @@ -282,6 +283,22 @@ impl QueryContext { pub fn set_warning(&self, msg: String) { self.mutable_query_context_data.write().unwrap().warning = Some(msg); } + + pub fn query_timeout(&self) -> Option { + self.mutable_session_data.read().unwrap().query_timeout + } + + pub fn query_timeout_as_millis(&self) -> u128 { + let timeout = self.mutable_session_data.read().unwrap().query_timeout; + if let Some(t) = timeout { + return t.as_millis(); + } + 0 + } + + pub fn set_query_timeout(&self, timeout: Duration) { + self.mutable_session_data.write().unwrap().query_timeout = Some(timeout); + } } impl QueryContextBuilder { diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index 33bd140c7057..5ddaae7eb579 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -18,6 +18,7 @@ pub mod table_name; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; +use std::time::Duration; use auth::UserInfoRef; use common_catalog::build_db_string; @@ -45,6 +46,7 @@ pub(crate) struct MutableInner { schema: String, user_info: UserInfoRef, timezone: Timezone, + query_timeout: Option, } impl Default for MutableInner { @@ -53,6 +55,7 @@ impl Default for MutableInner { schema: DEFAULT_SCHEMA_NAME.into(), user_info: auth::userinfo_by_name(None), timezone: get_timezone(None).clone(), + query_timeout: None, } } } diff --git a/src/sql/src/parsers/set_var_parser.rs b/src/sql/src/parsers/set_var_parser.rs index e2a7db9d08a2..8a66269803cc 100644 --- a/src/sql/src/parsers/set_var_parser.rs +++ b/src/sql/src/parsers/set_var_parser.rs @@ -58,47 +58,83 @@ mod tests { use crate::dialect::GreptimeDbDialect; use crate::parser::ParseOptions; - fn assert_mysql_parse_result(sql: &str) { + fn assert_mysql_parse_result(sql: &str, indent_str: &str, expr: Expr) { let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), Statement::SetVariables(SetVariables { - variable: ObjectName(vec![Ident::new("time_zone")]), - value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))] + variable: ObjectName(vec![Ident::new(indent_str)]), + value: vec![expr] }) ); } - fn assert_pg_parse_result(sql: &str) { + fn assert_pg_parse_result(sql: &str, indent: &str, expr: Expr) { let result = ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); let mut stmts = result.unwrap(); assert_eq!( stmts.pop().unwrap(), Statement::SetVariables(SetVariables { - variable: ObjectName(vec![Ident::new("TIMEZONE")]), - value: vec![Expr::Value(Value::SingleQuotedString("UTC".to_string()))], + variable: ObjectName(vec![Ident::new(indent)]), + value: vec![expr], }) ); } #[test] pub fn test_set_timezone() { + let expected_utc_expr = Expr::Value(Value::SingleQuotedString("UTC".to_string())); // mysql style let sql = "SET time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); // session or local style let sql = "SET LOCAL time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); let sql = "SET SESSION time_zone = 'UTC'"; - assert_mysql_parse_result(sql); + assert_mysql_parse_result(sql, "time_zone", expected_utc_expr.clone()); // postgresql style let sql = "SET TIMEZONE TO 'UTC'"; - assert_pg_parse_result(sql); + assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr.clone()); let sql = "SET TIMEZONE 'UTC'"; - assert_pg_parse_result(sql); + assert_pg_parse_result(sql, "TIMEZONE", expected_utc_expr); + } + + #[test] + pub fn test_set_query_timeout() { + let expected_query_timeout_expr = Expr::Value(Value::Number("5000".to_string(), false)); + // mysql style + let sql = "SET MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + // session or local style + let sql = "SET LOCAL MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + let sql = "SET SESSION MAX_EXECUTION_TIME = 5000"; + assert_mysql_parse_result( + sql, + "MAX_EXECUTION_TIME", + expected_query_timeout_expr.clone(), + ); + + // postgresql style + let sql = "SET STATEMENT_TIMEOUT = 5000"; + assert_pg_parse_result( + sql, + "STATEMENT_TIMEOUT", + expected_query_timeout_expr.clone(), + ); + let sql = "SET STATEMENT_TIMEOUT TO 5000"; + assert_pg_parse_result(sql, "STATEMENT_TIMEOUT", expected_query_timeout_expr); } } diff --git a/tests/cases/standalone/common/basic.result b/tests/cases/standalone/common/basic.result index 2651bc733cac..a7a1dfb5c015 100644 --- a/tests/cases/standalone/common/basic.result +++ b/tests/cases/standalone/common/basic.result @@ -179,3 +179,22 @@ DROP TABLE foo; Affected Rows: 0 +-- SQLNESS PROTOCOL MYSQL +SET MAX_EXECUTION_TIME = 2000; + +affected_rows: 0 + +-- SQLNESS PROTOCOL MYSQL +SHOW VARIABLES MAX_EXECUTION_TIME; + ++---------------+-------+ +| Variable_name | Value | ++---------------+-------+ +| | | ++---------------+-------+ + +-- SQLNESS PROTOCOL MYSQL +SET MAX_EXECUTION_TIME = 0; + +affected_rows: 0 + diff --git a/tests/cases/standalone/common/basic.sql b/tests/cases/standalone/common/basic.sql index 4c4065874256..13a7d5a1c4c0 100644 --- a/tests/cases/standalone/common/basic.sql +++ b/tests/cases/standalone/common/basic.sql @@ -72,3 +72,12 @@ DROP TABLE phy; DROP TABLE system_metrics; DROP TABLE foo; + +-- SQLNESS PROTOCOL MYSQL +SET MAX_EXECUTION_TIME = 2000; + +-- SQLNESS PROTOCOL MYSQL +SHOW VARIABLES MAX_EXECUTION_TIME; + +-- SQLNESS PROTOCOL MYSQL +SET MAX_EXECUTION_TIME = 0;