Skip to content

Commit

Permalink
fix: Cast string to boolean not compatible with Spark (apache#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
erenavsarogullari authored Feb 26, 2024
1 parent 749731b commit 96dfccf
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
40 changes: 36 additions & 4 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::ArrayRef;
use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{Result as DataFusionResult, ScalarValue};
Expand Down Expand Up @@ -73,10 +73,42 @@ impl Cast {
}

/// Casts `array` to `self.data_type`.
///
/// The input is first passed through `array_with_timezone` together with `self.timezone`.
/// `Utf8` and `LargeUtf8` inputs whose target type is `Boolean` are handled by
/// `Self::spark_cast_utf8_to_boolean`; every other type pair goes through
/// `cast_with_options` with `CAST_OPTIONS`. The intermediate result is finally handed to
/// `spark_cast` and returned wrapped in `Ok`.
///
/// # Errors
///
/// Propagates (via `?`) any error returned by `cast_with_options`.
// NOTE(review): this listing comes from a diff view rendered without +/- markers, so
// superseded and replacement lines are interleaved below. The lines flagged as
// "pre-change" are presumably the removed ones — confirm against the commit.
fn cast_array(&self, array: ArrayRef) -> DataFusionResult<ArrayRef> {
// NOTE(review): appears to be the pre-change line, superseded by the two lines after it.
let array = array_with_timezone(array, self.timezone.clone(), Some(&self.data_type));
let to_type = &self.data_type;
let array = array_with_timezone(array, self.timezone.clone(), Some(to_type));
let from_type = array.data_type();
// NOTE(review): the next two lines appear to be the pre-change body, superseded by the
// `match` dispatch and the `Ok(result)` return below.
let cast_result = cast_with_options(&array, &self.data_type, &CAST_OPTIONS)?;
Ok(spark_cast(cast_result, from_type, &self.data_type))
// Dispatch on the (source, target) type pair: string-to-boolean casts use the dedicated
// helper, instantiated with the offset width of the string array (`i32` for `Utf8`,
// `i64` for `LargeUtf8`); everything else uses the generic cast kernel.
let cast_result = match (from_type, to_type) {
(DataType::Utf8, DataType::Boolean) => Self::spark_cast_utf8_to_boolean::<i32>(&array),
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array)
}
_ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
};
let result = spark_cast(cast_result, from_type, to_type);
Ok(result)
}

/// Casts a string array to a boolean array.
///
/// Each non-null value is trimmed of leading/trailing whitespace and compared
/// ASCII-case-insensitively against two fixed literal sets:
///
/// * `t`, `true`, `y`, `yes`, `1` map to `true`
/// * `f`, `false`, `n`, `no`, `0` map to `false`
///
/// Null inputs and values matching neither set produce a null output element.
///
/// `OffsetSize` selects the concrete string array type `from` is downcast to
/// (`GenericStringArray<OffsetSize>`).
///
/// # Panics
///
/// Panics if `from` is not a `GenericStringArray<OffsetSize>`.
fn spark_cast_utf8_to_boolean<OffsetSize>(from: &dyn Array) -> ArrayRef
where
    OffsetSize: OffsetSizeTrait,
{
    const TRUE_LITERALS: [&str; 5] = ["t", "true", "y", "yes", "1"];
    const FALSE_LITERALS: [&str; 5] = ["f", "false", "n", "no", "0"];

    let array = from
        .as_any()
        .downcast_ref::<GenericStringArray<OffsetSize>>()
        .expect("caller must pass a string array with the matching offset size");

    let output_array = array
        .iter()
        .map(|value| {
            // Null in, null out.
            let value = value?.trim();
            // Compare without building a lowercased copy of every row: ASCII case folding
            // never changes which characters are whitespace, so trimming first and then
            // comparing case-insensitively matches lowercasing first and then trimming.
            if TRUE_LITERALS.iter().any(|lit| value.eq_ignore_ascii_case(lit)) {
                Some(true)
            } else if FALSE_LITERALS.iter().any(|lit| value.eq_ignore_ascii_case(lit)) {
                Some(false)
            } else {
                None
            }
        })
        .collect::<BooleanArray>();

    Arc::new(output_array)
}
}

Expand Down
24 changes: 24 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1302,4 +1302,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("test cast utf8 to boolean as compatible with Spark") {
  val table = "test_table"

  // Loads `inputValues` into a freshly created parquet table (one row per value) and runs
  // the string-to-boolean cast query over it through `checkSparkAnswerAndOperator`.
  def verifyCastFor(inputValues: Seq[String]): Unit =
    withTable(table) {
      val values = inputValues.map(x => s"('$x')").mkString(",")
      sql(s"create table $table(base_column char(20)) using parquet")
      sql(s"insert into $table values $values")
      checkSparkAnswerAndOperator(
        s"select base_column, cast(base_column as boolean) as casted_column from $table")
    }

  val inputGroups: Seq[Seq[String]] = Seq(
    // Supported boolean values as true by both Arrow and Spark
    Seq("t", "true", "y", "yes", "1", "T", "TrUe", "Y", "YES"),
    // Supported boolean values as false by both Arrow and Spark
    Seq("f", "false", "n", "no", "0", "F", "FaLSe", "N", "No"),
    // Supported boolean values by Arrow but not Spark
    Seq("TR", "FA", "tr", "tru", "ye", "on", "fa", "fal", "fals", "of", "off"),
    // Invalid boolean casting values for Arrow and Spark
    Seq("car", "Truck"))

  // Each group gets its own table lifecycle, in the order listed above.
  inputGroups.foreach(group => verifyCastFor(inputValues = group))
}

}

0 comments on commit 96dfccf

Please sign in to comment.