Skip to content

Commit

Permalink
feat: add hex scalar function (apache#449)
Browse files Browse the repository at this point in the history
* feat: add hex scalar function

* test: change hex test to use makeParquetFileAllTypes, support more types

* test: add more columns to spark test

* refactor: remove extra rust code

* feat: support dictionary

* fix: simplify hex_int64

* refactor: combine functions for hex byte/string

* refactor: update vec collection

* refactor: refactor hex to support byte ref

* style: fix clippy

* refactor: remove scalar handling

* style: new lines in expression test file

* fix: handle large strings
  • Loading branch information
tshauck authored Jun 3, 2024
1 parent c79bd5c commit a71f68b
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 0 deletions.
7 changes: 7 additions & 0 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ use unicode_segmentation::UnicodeSegmentation;
mod unhex;
use unhex::spark_unhex;

mod hex;
use hex::spark_hex;

macro_rules! make_comet_scalar_udf {
($name:expr, $func:ident, $data_type:ident) => {{
let scalar_func = CometScalarFunction::new(
Expand Down Expand Up @@ -108,6 +111,10 @@ pub fn create_comet_physical_fun(
"make_decimal" => {
make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
}
"hex" => {
let func = Arc::new(spark_hex);
make_comet_scalar_udf!("hex", func, without data_type)
}
"unhex" => {
let func = Arc::new(spark_unhex);
make_comet_scalar_udf!("unhex", func, without data_type)
Expand Down
306 changes: 306 additions & 0 deletions core/src/execution/datafusion/expressions/scalar_funcs/hex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{
array::{as_dictionary_array, as_largestring_array, as_string_array},
datatypes::Int32Type,
};
use arrow_array::StringArray;
use arrow_schema::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{
cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array},
exec_err, DataFusionError,
};
use std::fmt::Write;

/// Formats an `i64` as its uppercase hexadecimal representation.
///
/// Negative values render as their 16-digit two's-complement bit pattern
/// (e.g. `-1` -> `"FFFFFFFFFFFFFFFF"`), matching Spark's `hex(BIGINT)`.
fn hex_int64(num: i64) -> String {
    format!("{num:X}")
}

/// Hex-encodes a byte sequence as an uppercase string, two digits per byte
/// (e.g. `b"hi"` -> `"6869"`).
///
/// Accepts anything viewable as `&[u8]` (`&str`, `&[u8]`, ...). The output
/// string is pre-sized to `2 * len` so no reallocation occurs while encoding.
fn hex_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<String, std::fmt::Error> {
    let data = bytes.as_ref();
    data.iter()
        .try_fold(String::with_capacity(data.len() * 2), |mut encoded, b| {
            write!(encoded, "{:02X}", b)?;
            Ok(encoded)
        })
}

pub(super) fn spark_hex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
if args.len() != 1 {
return Err(DataFusionError::Internal(
"hex expects exactly one argument".to_string(),
));
}

match &args[0] {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Int64 => {
let array = as_int64_array(array)?;

let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect();

Ok(ColumnarValue::Array(Arc::new(hexed_array)))
}
DataType::Utf8 => {
let array = as_string_array(array);

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::LargeUtf8 => {
let array = as_largestring_array(array);

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::Binary => {
let array = as_binary_array(array)?;

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::FixedSizeBinary(_) => {
let array = as_fixed_size_binary_array(array)?;

let hexed: StringArray = array
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

Ok(ColumnarValue::Array(Arc::new(hexed)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Int64) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_int64_array(dict.values())?;
let values = hexed_values
.iter()
.map(|v| v.map(hex_int64))
.collect::<Vec<_>>();

let keys = dict.keys().clone();
let mut new_keys = Vec::with_capacity(values.len());

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Utf8) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_string_array(dict.values());
let values: Vec<Option<String>> = hexed_values
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

let keys = dict.keys().clone();

let mut new_keys = Vec::with_capacity(values.len());

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
DataType::Dictionary(_, value_type) if matches!(**value_type, DataType::Binary) => {
let dict = as_dictionary_array::<Int32Type>(&array);

let hexed_values = as_binary_array(dict.values())?;
let values: Vec<Option<String>> = hexed_values
.iter()
.map(|v| v.map(hex_bytes).transpose())
.collect::<Result<_, _>>()?;

let keys = dict.keys().clone();
let mut new_keys = Vec::with_capacity(values.len());

for key in keys.iter() {
let key = key.map(|k| values[k as usize].clone()).unwrap_or(None);
new_keys.push(key);
}

let string_array_values = StringArray::from(new_keys);
Ok(ColumnarValue::Array(Arc::new(string_array_values)))
}
_ => exec_err!(
"hex got an unexpected argument type: {:?}",
array.data_type()
),
},
_ => exec_err!("native hex does not support scalar values at this time"),
}
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::{
        array::{
            as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder,
            StringDictionaryBuilder,
        },
        datatypes::{Int32Type, Int64Type},
    };
    use arrow_array::{Int64Array, StringArray};
    use datafusion::logical_expr::ColumnarValue;

    #[test]
    fn test_dictionary_hex_utf8() {
        // Dictionary-encoded strings should be hexed per value and expanded
        // back out through the keys, keeping the null slot.
        let mut dict_builder = StringDictionaryBuilder::<Int32Type>::new();
        dict_builder.append_value("hi");
        dict_builder.append_value("bye");
        dict_builder.append_null();
        dict_builder.append_value("rust");
        let dict_array = dict_builder.finish();

        let expected = StringArray::from(vec![
            Some("6869"),
            Some("627965"),
            None,
            Some("72757374"),
        ]);

        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(dict_array))]).unwrap();
        let output = match output {
            ColumnarValue::Array(inner) => inner,
            _ => panic!("Expected array"),
        };

        assert_eq!(as_string_array(&output), &expected);
    }

    #[test]
    fn test_dictionary_hex_int64() {
        // Dictionary-encoded int64 values are hexed like a flat Int64 array.
        let mut dict_builder = PrimitiveDictionaryBuilder::<Int32Type, Int64Type>::new();
        dict_builder.append_value(1);
        dict_builder.append_value(2);
        dict_builder.append_null();
        dict_builder.append_value(3);
        let dict_array = dict_builder.finish();

        let expected = StringArray::from(vec![Some("1"), Some("2"), None, Some("3")]);

        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(dict_array))]).unwrap();
        let output = match output {
            ColumnarValue::Array(inner) => inner,
            _ => panic!("Expected array"),
        };

        assert_eq!(as_string_array(&output), &expected);
    }

    #[test]
    fn test_dictionary_hex_binary() {
        // Binary dictionary values are hexed byte-by-byte ('1' == 0x31).
        let mut dict_builder = BinaryDictionaryBuilder::<Int32Type>::new();
        dict_builder.append_value("1");
        dict_builder.append_value("1");
        dict_builder.append_null();
        dict_builder.append_value("3");
        let dict_array = dict_builder.finish();

        let expected = StringArray::from(vec![Some("31"), Some("31"), None, Some("33")]);

        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(dict_array))]).unwrap();
        let output = match output {
            ColumnarValue::Array(inner) => inner,
            _ => panic!("Expected array"),
        };

        assert_eq!(as_string_array(&output), &expected);
    }

    #[test]
    fn test_hex_int64() {
        // Positive values use minimal digits; negatives render as the full
        // 64-bit two's-complement pattern.
        assert_eq!(super::hex_int64(1234), "4D2".to_string());
        assert_eq!(super::hex_int64(-1), "FFFFFFFFFFFFFFFF".to_string());
    }

    #[test]
    fn test_spark_hex_int64() {
        let input = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]);

        let output = super::spark_hex(&[ColumnarValue::Array(Arc::new(input))]).unwrap();
        let output = match output {
            ColumnarValue::Array(inner) => inner,
            _ => panic!("Expected array"),
        };

        let expected = StringArray::from(vec![
            Some("1".to_string()),
            Some("2".to_string()),
            None,
            Some("3".to_string()),
        ]);

        assert_eq!(as_string_array(&output), &expected);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)

case Hex(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val optExpr =
scalarExprToProtoWithReturnType("hex", StringType, childExpr)

optExprWithInfo(optExpr, expr, child)

case e: Unhex if !isSpark32 =>
val unHex = unhexSerde(e)

Expand Down
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,22 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("hex") {
  // Run hex() over every supported column type of the all-types Parquet
  // fixture, both with and without dictionary encoding, and compare the
  // Comet operator's answer against Spark's.
  for (dictionaryEnabled <- Seq(true, false)) {
    withTempDir { dir =>
      val path = new Path(dir.toURI.toString, "hex.parquet")
      makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)

      withParquetTable(path.toString, "tbl") {
        // _9 and _10 (uint8 and uint16) not supported
        checkSparkAnswerAndOperator(
          "SELECT hex(_1), hex(_2), hex(_3), hex(_4), hex(_5), hex(_6), hex(_7), hex(_8), hex(_11), hex(_12), hex(_13), hex(_14), hex(_15), hex(_16), hex(_17), hex(_18), hex(_19), hex(_20) FROM tbl")
      }
    }
  }
}

test("unhex") {
// When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that
// was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not
Expand Down

0 comments on commit a71f68b

Please sign in to comment.