diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f5839fd4c..7560e0c2d 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -25,6 +25,7 @@ use std::{ use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, + datatypes::TimestampMicrosecondType, record_batch::RecordBatch, util::display::FormatOptions, }; @@ -33,10 +34,12 @@ use arrow_array::{ Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow_schema::{DataType, Schema}; +use chrono::{TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use num::{traits::CheckedNeg, CheckedSub, Integer, Num}; +use regex::Regex; use crate::execution::datafusion::expressions::utils::{ array_with_timezone, down_cast_any_ref, spark_cast, @@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int { }}; } +macro_rules! cast_utf8_to_timestamp { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC"); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef; + result + }}; +} + impl Cast { pub fn new( child: Arc, @@ -125,6 +146,9 @@ impl Cast { (DataType::LargeUtf8, DataType::Boolean) => { Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } + (DataType::Utf8, DataType::Timestamp(_, _)) => { + Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? + } ( DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, @@ -200,6 +224,30 @@ impl Cast { Ok(cast_array) } + fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ) + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + fn spark_cast_utf8_to_boolean( from: &dyn Array, eval_mode: EvalMode, @@ -510,9 +558,273 @@ impl PhysicalExpr for Cast { } } +fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult> { + let value = value.trim(); + if value.is_empty() { + return Ok(None); + } + // Define regex patterns and corresponding parsing functions + let patterns = &[ + ( + Regex::new(r"^\d{4}$").unwrap(), + parse_str_to_year_timestamp as fn(&str) -> CometResult>, + ), + ( + Regex::new(r"^\d{4}-\d{2}$").unwrap(), + parse_str_to_month_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(), + parse_str_to_day_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(), + parse_str_to_hour_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), + parse_str_to_minute_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), + parse_str_to_second_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), + parse_str_to_microsecond_timestamp, + ), + ( + Regex::new(r"^T\d{1,2}$").unwrap(), + parse_str_to_time_only_timestamp, + ), + ]; + + let mut timestamp = None; + + // Iterate through patterns and try matching + for (pattern, parse_func) in patterns { + if pattern.is_match(value) { + timestamp = parse_func(value)?; + break; + } + } + + if timestamp.is_none() { + if eval_mode == EvalMode::Ansi { + return Err(CometError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }); + } else { + return Ok(None); + } + } + + match timestamp { + Some(ts) => Ok(Some(ts)), + None => Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )), + } +} + +fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt.with_timezone(&chrono::Utc), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(utc_datetime.timestamp_micros())) +} + +fn parse_hms_timestamp( + year: i32, + month: u32, + day: u32, + hour: u32, + minute: u32, + second: u32, + microsecond: u32, +) -> CometResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt + .with_timezone(&chrono::Utc) + .with_nanosecond(microsecond * 1000), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + let result = match utc_datetime { + Some(dt) => dt.timestamp_micros(), + None => { + return Err(CometError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(result)) +} + +fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult> { + let values: Vec<_> = value + .split(|c| c == 'T' || c == '-' || c == ':' || c == '.') + .collect(); + let year = values[0].parse::().unwrap_or_default(); + let month = values.get(1).map_or(1, |m| m.parse::().unwrap_or(1)); + let day = values.get(2).map_or(1, |d| d.parse::().unwrap_or(1)); + let hour = values.get(3).map_or(0, |h| h.parse::().unwrap_or(0)); + let minute = values.get(4).map_or(0, |m| m.parse::().unwrap_or(0)); + let second = values.get(5).map_or(0, |s| s.parse::().unwrap_or(0)); + let microsecond = values.get(6).map_or(0, |ms| ms.parse::().unwrap_or(0)); + + match timestamp_type { + "year" => parse_ymd_timestamp(year, 1, 1), + "month" => parse_ymd_timestamp(year, month, 1), + "day" => parse_ymd_timestamp(year, month, day), + "hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0), + "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0), + "second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0), + "microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond), + _ => Err(CometError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }), + } +} + +fn parse_str_to_year_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "year") +} + +fn parse_str_to_month_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "month") +} + +fn parse_str_to_day_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "day") +} + +fn parse_str_to_hour_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "hour") +} + +fn parse_str_to_minute_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "minute") +} + +fn parse_str_to_second_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "second") +} + +fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult> { + get_timestamp_values(value, "microsecond") +} + +fn parse_str_to_time_only_timestamp(value: &str) -> CometResult> { + let values: Vec<&str> = value.split('T').collect(); + let time_values: Vec = values[1] + .split(':') + .map(|v| v.parse::().unwrap_or(0)) + .collect(); + + let datetime = chrono::Utc::now(); + let timestamp = datetime + .with_hour(time_values.first().copied().unwrap_or_default()) + .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0))) + .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0))) + .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)) + .map(|dt| dt.to_utc().timestamp_micros()) + .unwrap_or_default(); + + Ok(Some(timestamp)) +} + #[cfg(test)] -mod test { - use super::{cast_string_to_i8, EvalMode}; +mod tests { + use super::*; + use arrow::datatypes::TimestampMicrosecondType; + use arrow_array::StringArray; + use arrow_schema::TimeUnit; + + #[test] + fn timestamp_parser_test() { + // write for all formats + assert_eq!( + timestamp_parser("2020", EvalMode::Legacy).unwrap(), + Some(1577836800000000) // this is in milliseconds + ); + assert_eq!( + timestamp_parser("2020-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(), + Some(1577880000000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(), + Some(1577882040000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(), + Some(1577882096000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(), + Some(1577882096123456) + ); + // assert_eq!( + // timestamp_parser("T2", EvalMode::Legacy).unwrap(), + // Some(1714356000000000) // this value needs to change everyday. + // ); + } + + #[test] + fn test_cast_string_to_timestamp() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let eval_mode = EvalMode::Legacy; + let result = cast_utf8_to_timestamp!( + &string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ); + + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!(result.len(), 2); + } #[test] fn test_cast_string_as_i8() { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6eda0547f..c07b2b3c5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -585,6 +585,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY evalMode.toString } + + val supportedTimezone = (child.dataType, dt) match { + case (DataTypes.StringType, DataTypes.TimestampType) + if !timeZoneId.contains("UTC") => + withInfo(expr, s"Unsupported timezone ${timeZoneId} for timestamp cast") + false + case _ => true + } + val supportedCast = (child.dataType, dt) match { case (DataTypes.StringType, DataTypes.TimestampType) if !CometConf.COMET_CAST_STRING_TO_TIMESTAMP.get() => @@ -593,7 +602,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { false case _ => true } - if (supportedCast) { + + if (supportedCast && supportedTimezone) { castToProto(timeZoneId, dt, childExpr, evalModeStr) } else { // no need to call withInfo here since it was called when determining diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1bddedde9..a31f4e682 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -528,14 +528,37 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { "spark.comet.cast.stringToTimestamp is disabled") } - ignore("cast StringType to TimestampType") { - // https://github.com/apache/datafusion-comet/issues/328 - withSQLConf((CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true")) { - val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ generateStrings(timestampPattern, 8) - castTest(values.toDF("a"), DataTypes.TimestampType) + test("cast StringType to TimestampType") { + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") { + val values = Seq( + "2020", + "2020-01", + "2020-01-01", + "2020-01-01T12", + "2020-01-01T12:34", + "2020-01-01T12:34:56", + "2020-01-01T12:34:56.123456", + "T2", + "-9?") + castTimestampTest(values.toDF("a"), DataTypes.TimestampType) + } + + // test for invalid inputs + withSQLConf( + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC", + CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key -> "true") { + val values = Seq("-9?", "1-", "0.5") + castTimestampTest(values.toDF("a"), DataTypes.TimestampType) } } + test("cast StringType to TimestampType with invalid timezone") { + val values = Seq("2020-01-01T12:34:56.123456", "T2") + castFallbackTestTimezone(values.toDF("a"), DataTypes.TimestampType, "Unsupported timezone") + } + // CAST from DateType ignore("cast DateType to BooleanType") { @@ -763,6 +786,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + private def castFallbackTestTimezone( + input: DataFrame, + toType: DataType, + expectedMessage: String): Unit = { + withTempPath { dir => + val data = roundtripParquet(input, dir).coalesce(1) + data.createOrReplaceTempView("t") + + withSQLConf( + (SQLConf.ANSI_ENABLED.key, "false"), + (CometConf.COMET_CAST_STRING_TO_TIMESTAMP.key, "true"), + (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) { + val df = data.withColumn("converted", col("a").cast(toType)) + df.collect() + val str = + new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan) + assert(str.contains(expectedMessage)) + } + } + } + + private def castTimestampTest(input: DataFrame, toType: DataType) = { + withTempPath { dir => + val data = roundtripParquet(input, dir).coalesce(1) + data.createOrReplaceTempView("t") + + withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { + // cast() should return null for invalid inputs when ansi mode is disabled + val df = data.withColumn("converted", col("a").cast(toType)) + checkSparkAnswer(df) + + // try_cast() should always return null for invalid inputs + val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + checkSparkAnswer(df2) + } + } + } + private def castTest(input: DataFrame, toType: DataType): Unit = { withTempPath { dir => val data = roundtripParquet(input, dir).coalesce(1)