diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 0f254004b..dc333e8be 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -55,6 +55,9 @@ use unicode_segmentation::UnicodeSegmentation; mod unhex; use unhex::spark_unhex; +mod hex; +use hex::spark_hex; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -108,6 +111,10 @@ pub fn create_comet_physical_fun( "make_decimal" => { make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } + "hex" => { + let func = Arc::new(spark_hex); + make_comet_scalar_udf!("hex", func, without data_type) + } "unhex" => { let func = Arc::new(spark_unhex); make_comet_scalar_udf!("unhex", func, without data_type) diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs new file mode 100644 index 000000000..ea572574a --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/hex.rs @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use arrow_array::StringArray; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, DataFusionError, +}; +use std::fmt::Write; + +fn hex_int64(num: i64) -> String { + format!("{:X}", num) +} + +fn hex_bytes>(bytes: T) -> Result { + let bytes = bytes.as_ref(); + let length = bytes.len(); + let mut hex_string = String::with_capacity(length * 2); + for &byte in bytes { + write!(&mut hex_string, "{:02X}", byte)?; + } + Ok(hex_string) +} + +pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Int64) => { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_int64_array(dict.values())?; + let values = hexed_values + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(); + + let keys = dict.keys().clone(); + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Utf8) => { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_string_array(dict.values()); + let values: Vec> = hexed_values + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + let keys = dict.keys().clone(); + + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Binary) => { + let dict = as_dictionary_array::(&array); + + let hexed_values = as_binary_array(dict.values())?; + let values: Vec> = hexed_values + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + let keys = dict.keys().clone(); + let mut new_keys = Vec::with_capacity(values.len()); + + for key in keys.iter() { + let key = key.map(|k| values[k as usize].clone()).unwrap_or(None); + new_keys.push(key); + } + + let string_array_values = StringArray::from(new_keys); + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + ), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, + StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use arrow_array::{Int64Array, StringArray}; + use datafusion::logical_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("1"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("31"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index fe60e9ba3..c1188e193 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1509,6 +1509,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) + case Hex(child) => + val childExpr = exprToProtoInternal(child, inputs) + val optExpr = + scalarExprToProtoWithReturnType("hex", StringType, childExpr) + + optExprWithInfo(optExpr, expr, child) + case e: Unhex if !isSpark32 => val unHex = unhexSerde(e) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index d418ec870..e15b09ca2 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1062,6 +1062,22 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("hex") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "hex.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + + withParquetTable(path.toString, "tbl") { + // _9 and _10 (uint8 and uint16) not supported + checkSparkAnswerAndOperator( + "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl") + } + } + } + } + test("unhex") { // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not