From 8750c2a4cde2c34bf7637b0a1715022ccda94dea Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 May 2024 01:38:35 -0600 Subject: [PATCH] feat: Implement Spark-compatible CAST from string to integral types (#307) --- core/Cargo.toml | 3 + core/benches/cast.rs | 85 +++++ .../execution/datafusion/expressions/cast.rs | 326 +++++++++++++++++- core/src/execution/datafusion/mod.rs | 2 +- .../org/apache/comet/CometCastSuite.scala | 81 ++++- 5 files changed, 475 insertions(+), 22 deletions(-) create mode 100644 core/benches/cast.rs diff --git a/core/Cargo.toml b/core/Cargo.toml index b09b0ea7f5..cbca7f629f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -118,3 +118,6 @@ harness = false name = "row_columnar" harness = false +[[bench]] +name = "cast" +harness = false diff --git a/core/benches/cast.rs b/core/benches/cast.rs new file mode 100644 index 0000000000..281fe82e23 --- /dev/null +++ b/core/benches/cast.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_string_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast"); + group.bench_function("cast_string_to_i8", |b| { + b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i16", |b| { + b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i32", |b| { + b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i64", |b| { + b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index 10079855dd..f5839fd4cd 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -28,11 +28,15 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; -use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait}; +use arrow_array::{ + types::{Int16Type, Int32Type, Int64Type, Int8Type}, + Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; use arrow_schema::{DataType, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; +use num::{traits::CheckedNeg, CheckedSub, Integer, Num}; use crate::execution::datafusion::expressions::utils::{ array_with_timezone, down_cast_any_ref, spark_cast, @@ -64,6 +68,24 @@ pub struct Cast { pub timezone: String, } +macro_rules! cast_utf8_to_int { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: CometResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} + impl Cast { pub fn new( child: Arc, @@ -103,10 +125,79 @@ impl Cast { (DataType::LargeUtf8, DataType::Boolean) => { Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode)? } - _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?, + ( + DataType::Utf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ( + DataType::LargeUtf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode)?, + ( + DataType::Dictionary(key_type, value_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + // TODO: we are unpacking a dictionary-encoded array and then performing + // the cast. We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code + match value_type.as_ref() { + DataType::Utf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + DataType::LargeUtf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; + Self::cast_string_to_int::(to_type, &unpacked_array, self.eval_mode)? + } + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), + } + } + _ => { + // when we have no Spark-specific casting we delegate to DataFusion + cast_with_options(&array, to_type, &CAST_OPTIONS)? + } + }; + Ok(spark_cast(cast_result, from_type, to_type)) + } + + fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => { + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? + } + DataType::Int64 => { + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? + } + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), }; - let result = spark_cast(cast_result, from_type, to_type); - Ok(result) + Ok(cast_array) } fn spark_cast_utf8_to_boolean( @@ -142,6 +233,202 @@ impl Cast { } } +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? + .map(|v| v as i16)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult> { + do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) +} + +fn cast_string_to_int_with_range_check( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> CometResult> { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? { + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[derive(PartialEq)] +enum State { + SkipLeadingWhiteSpace, + SkipTrailingWhiteSpace, + ParseSignAndDigits, + ParseFractionalDigits, +} + +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, +>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> CometResult> { + let len = str.len(); + if str.is_empty() { + return none_or_err(eval_mode, type_name, str); + } + + let mut result: T = T::zero(); + let mut negative = false; + let radix = T::from(10); + let stop_value = min_value / radix; + let mut state = State::SkipLeadingWhiteSpace; + let mut parsed_sign = false; + + for (i, ch) in str.char_indices() { + // skip leading whitespace + if state == State::SkipLeadingWhiteSpace { + if ch.is_whitespace() { + // consume this char + continue; + } + // change state and fall through to next section + state = State::ParseSignAndDigits; + } + + if state == State::ParseSignAndDigits { + if !parsed_sign { + negative = ch == '-'; + let positive = ch == '+'; + parsed_sign = true; + if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } + // consume this char + continue; + } + } + + if ch == '.' { + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + state = State::ParseFractionalDigits; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } + } + + if state == State::ParseFractionalDigits { + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. + if ch.is_whitespace() { + // finished parsing fractional digits, now need to skip trailing whitespace + state = State::SkipTrailingWhiteSpace; + // consume this char + continue; + } + if !ch.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + } + + // skip trailing whitespace + if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() { + return none_or_err(eval_mode, type_name, str); + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return none_or_err(eval_mode, type_name, str); + } + result = neg; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + Ok(Some(result)) +} + +/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +#[inline] +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[inline] +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { + CometError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -222,3 +509,34 @@ impl PhysicalExpr for Cast { self.hash(&mut s); } } + +#[cfg(test)] +mod test { + use super::{cast_string_to_i8, EvalMode}; + + #[test] + fn test_cast_string_as_i8() { + // basic + assert_eq!( + cast_string_to_i8("127", EvalMode::Legacy).unwrap(), + Some(127_i8) + ); + assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None); + assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err()); + // decimals + assert_eq!( + cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + assert_eq!( + cast_string_to_i8(".", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + // TRY should always return null for decimals + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None); + assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); + // ANSI mode should throw error on decimal + assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); + assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err()); + } +} diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs index c464eeed0b..76f0b1c760 100644 --- a/core/src/execution/datafusion/mod.rs +++ b/core/src/execution/datafusion/mod.rs @@ -17,7 +17,7 @@ //! Native execution through DataFusion -mod expressions; +pub mod expressions; mod operators; pub mod planner; pub(crate) mod shuffle_writer; diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index c6a7c72232..1bddedde92 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -40,7 +40,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // but this is likely a reasonable starting point for now private val whitespaceChars = " \t\r\n" - private val numericPattern = "0123456789e+-." + whitespaceChars + /** + * We use these characters to construct strings that potentially represent valid numbers such as + * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`. + */ + private val numericPattern = "0123456789deEf+-." + whitespaceChars + private val datePattern = "0123456789/" + whitespaceChars private val timestampPattern = "0123456789/:T" + whitespaceChars @@ -433,23 +438,64 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(testValues, DataTypes.BooleanType) } - ignore("cast StringType to ByteType") { - // https://github.com/apache/datafusion-comet/issues/15 - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType) - } - - ignore("cast StringType to ShortType") { - // https://github.com/apache/datafusion-comet/issues/15 - castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType) - } - - ignore("cast StringType to IntegerType") { - // https://github.com/apache/datafusion-comet/issues/15 + private val castStringToIntegralInputs: Seq[String] = Seq( + "", + ".", + "+", + "-", + "+.", + "-.", + "-0", + "+1", + "-1", + ".2", + "-.2", + "1e1", + "1.1d", + "1.1f", + Byte.MinValue.toString, + (Byte.MinValue.toShort - 1).toString, + Byte.MaxValue.toString, + (Byte.MaxValue.toShort + 1).toString, + Short.MinValue.toString, + (Short.MinValue.toInt - 1).toString, + Short.MaxValue.toString, + (Short.MaxValue.toInt + 1).toString, + Int.MinValue.toString, + (Int.MinValue.toLong - 1).toString, + Int.MaxValue.toString, + (Int.MaxValue.toLong + 1).toString, + Long.MinValue.toString, + Long.MaxValue.toString, + "-9223372036854775809", // Long.MinValue -1 + "9223372036854775808" // Long.MaxValue + 1 + ) + + test("cast StringType to ByteType") { + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType) + // fuzz test + castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType) + } + + test("cast StringType to ShortType") { + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType) + // fuzz test + castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType) + } + + test("cast StringType to IntegerType") { + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType) + // fuzz test castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType) } - ignore("cast StringType to LongType") { - // https://github.com/apache/datafusion-comet/issues/15 + test("cast StringType to LongType") { + // test with hand-picked values + castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType) + // fuzz test castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType) } @@ -724,11 +770,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) { // cast() should return null for invalid inputs when ansi mode is disabled - val df = data.withColumn("converted", col("a").cast(toType)) + val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a") checkSparkAnswer(df) // try_cast() should always return null for invalid inputs - val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t") + val df2 = + spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a") checkSparkAnswer(df2) }