Skip to content

Commit

Permalink
fix: Remove casting on decimals with a small precision to decimal256 (
Browse files Browse the repository at this point in the history
…apache#741)

## Which issue does this PR close?

Part of apache#670

## Rationale for this change

This PR improves native execution performance for decimals with a small precision.

## What changes are included in this PR?

This PR stops promoting decimal128 operands to decimal256 when their precisions are small enough that the wider intermediate type is not needed.

## How are these changes tested?

Existing tests
  • Loading branch information
kazuyukitanimura authored Aug 1, 2024
1 parent fbe86d0 commit 25957dd
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
17 changes: 12 additions & 5 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

//! Converts Spark physical plan to DataFusion physical plan
use std::{collections::HashMap, sync::Arc};

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
Expand Down Expand Up @@ -62,6 +60,8 @@ use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
use std::cmp::max;
use std::{collections::HashMap, sync::Arc};

use crate::{
errors::ExpressionError,
Expand Down Expand Up @@ -410,7 +410,7 @@ impl PhysicalPlanner {
// Spark Substring's start is 1-based when start > 0
let start = expr.start - i32::from(expr.start > 0);
// substring negative len is treated as 0 in Spark
let len = std::cmp::max(expr.len, 0);
let len = max(expr.len, 0);

Ok(Arc::new(SubstringExpr::new(
child,
Expand Down Expand Up @@ -664,7 +664,14 @@ impl PhysicalPlanner {
| DataFusionOperator::Modulo,
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
) => {
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
>= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Modulo
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
> DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
// Cast left and right to Decimal256 and cast the result back to Decimal128
Expand Down
52 changes: 33 additions & 19 deletions native/spark-expr/src/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use arrow::{
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::DataType;
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
Expand Down Expand Up @@ -460,27 +460,41 @@ pub fn spark_decimal_div(
};
let left = left.as_primitive::<Decimal128Type>();
let right = right.as_primitive::<Decimal128Type>();
let (_, s1) = get_precision_scale(left.data_type());
let (_, s2) = get_precision_scale(right.data_type());
let (p1, s1) = get_precision_scale(left.data_type());
let (p2, s2) = get_precision_scale(right.data_type());

let ten = BigInt::from(10);
let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32);
let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32);
let l_mul = ten.pow(l_exp);
let r_mul = ten.pow(r_exp);
let five = BigInt::from(5);
let zero = BigInt::from(0);
let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| {
let l = BigInt::from(l) * &l_mul;
let r = BigInt::from(r) * &r_mul;
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
let res = if div.is_negative() {
div - &five
} else {
div + &five
} / &ten;
res.to_i128().unwrap_or(i128::MAX)
})?;
let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32
|| p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32
{
let ten = BigInt::from(10);
let l_mul = ten.pow(l_exp);
let r_mul = ten.pow(r_exp);
let five = BigInt::from(5);
let zero = BigInt::from(0);
arrow::compute::kernels::arity::binary(left, right, |l, r| {
let l = BigInt::from(l) * &l_mul;
let r = BigInt::from(r) * &r_mul;
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
let res = if div.is_negative() {
div - &five
} else {
div + &five
} / &ten;
res.to_i128().unwrap_or(i128::MAX)
})?
} else {
let l_mul = 10_i128.pow(l_exp);
let r_mul = 10_i128.pow(r_exp);
arrow::compute::kernels::arity::binary(left, right, |l, r| {
let l = l * l_mul;
let r = r * r_mul;
let div = if r == 0 { 0 } else { l / r };
let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
res.to_i128().unwrap_or(i128::MAX)
})?
};
let result = result.with_data_type(DataType::Decimal128(p3, s3));
Ok(ColumnarValue::Array(Arc::new(result)))
}
Expand Down

0 comments on commit 25957dd

Please sign in to comment.