Skip to content

Commit

Permalink
Minor optimization of the null handling in ArrayAggAccumulator (#263)
Browse files Browse the repository at this point in the history
* Minor optimization of the null handling in ArrayAggAccumulator

* Add `ignore_nulls` flag to `AggrFn`
  • Loading branch information
joroKr21 authored Aug 26, 2024
1 parent e3c300e commit 8a0ca9b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 19 deletions.
34 changes: 23 additions & 11 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ impl ArrayAgg {
ignore_nulls,
}
}

/// Returns the value of this `ArrayAgg`'s `ignore_nulls` field.
///
/// This is a plain read-only accessor; it performs no computation.
pub fn ignore_nulls(&self) -> bool {
self.ignore_nulls
}
}

impl AggregateExpr for ArrayAgg {
Expand Down Expand Up @@ -339,10 +343,25 @@ impl ArrayAggAccumulator {
impl Accumulator for ArrayAggAccumulator {
// Append value like Int64Array(1,2,3)
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
// Returns `true` only when `logical_nulls()` yields a null buffer whose
// null count is greater than zero; returns `false` when it yields `None`
// or a buffer with no nulls.
fn has_nulls(arr: &ArrayRef) -> bool {
arr.logical_nulls()
.is_some_and(|nulls| nulls.null_count() > 0)
}

if !values.is_empty() {
assert!(values.len() == 1, "array_agg can only take 1 param!");
let val = values[0].clone();
self.values.push(val);
let val = &values[0];
if !val.is_empty() {
if self.ignore_nulls && has_nulls(val) {
let not_null = arrow::compute::is_not_null(val)?;
let result = arrow::compute::filter(val, &not_null)?;
if !result.is_empty() {
self.values.push(result)
}
} else {
self.values.push(val.clone())
}
}
}

Ok(())
Expand All @@ -353,9 +372,7 @@ impl Accumulator for ArrayAggAccumulator {
if !states.is_empty() {
assert!(states.len() == 1, "array_agg states must be singleton!");
let list_arr = as_list_array(&states[0])?;
for arr in list_arr.iter().flatten() {
self.values.push(arr);
}
self.values.extend(list_arr.iter().flatten())
}

Ok(())
Expand All @@ -375,12 +392,7 @@ impl Accumulator for ArrayAggAccumulator {
return Ok(ScalarValue::List(arr));
}

let mut result = arrow::compute::concat(&element_arrays)?;
if self.ignore_nulls {
let not_null = arrow::compute::is_not_null(&result)?;
result = arrow::compute::filter(&result, &not_null)?
}

let result = arrow::compute::concat(&element_arrays)?;
let result = array_into_list_array(result);
Ok(ScalarValue::List(Arc::new(result)))
}
Expand Down
4 changes: 4 additions & 0 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ impl DistinctArrayAgg {
ignore_nulls,
}
}

/// Returns the value of this `DistinctArrayAgg`'s `ignore_nulls` field.
///
/// This is a plain read-only accessor; it performs no computation.
pub fn ignore_nulls(&self) -> bool {
self.ignore_nulls
}
}

impl AggregateExpr for DistinctArrayAgg {
Expand Down
27 changes: 19 additions & 8 deletions datafusion/proto/src/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,23 @@ pub fn serialize_physical_aggr_expr(
}

let AggrFn {
inner: aggr_function,
inner,
distinct,
ignore_nulls,
} = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?;

Ok(protobuf::PhysicalExprNode {
expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
protobuf::PhysicalAggregateExprNode {
aggregate_function: Some(
physical_aggregate_expr_node::AggregateFunction::AggrFunction(
aggr_function as i32,
inner.into(),
),
),
expr: expressions,
ordering_req,
distinct,
ignore_nulls: false,
ignore_nulls,
fun_definition: None,
},
)),
Expand Down Expand Up @@ -124,7 +125,9 @@ fn serialize_physical_window_aggr_expr(
(!buf.is_empty()).then_some(buf),
))
} else {
let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn(aggr_expr)?;
let AggrFn {
inner, distinct, ..
} = aggr_expr_to_aggr_fn(aggr_expr)?;
if distinct {
return not_impl_err!(
"Distinct aggregate functions not supported in window expressions"
Expand All @@ -138,7 +141,7 @@ fn serialize_physical_window_aggr_expr(
}

Ok((
physical_window_expr_node::WindowFunction::AggrFunction(inner as i32),
physical_window_expr_node::WindowFunction::AggrFunction(inner.into()),
None,
))
}
Expand Down Expand Up @@ -260,11 +263,13 @@ pub fn serialize_physical_window_expr(
/// Description of an aggregate expression in protobuf terms, as returned by
/// `aggr_expr_to_aggr_fn`.
struct AggrFn {
// The protobuf aggregate function variant the expression maps to.
inner: protobuf::AggregateFunction,
// Whether the expression is a DISTINCT variant (e.g. set to `true` for
// `DistinctArrayAgg`).
distinct: bool,
// Whether the expression ignores nulls. Read from `ignore_nulls()` for
// `ArrayAgg` and `DistinctArrayAgg`; starts out as `false`.
// NOTE(review): other aggregate kinds appear to leave this `false` — confirm
// against the full body of `aggr_expr_to_aggr_fn`.
ignore_nulls: bool,
}

fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
let aggr_expr = expr.as_any();
let mut distinct = false;
let mut ignore_nulls = false;

let inner = if aggr_expr.downcast_ref::<Count>().is_some() {
protobuf::AggregateFunction::Count
Expand Down Expand Up @@ -293,10 +298,12 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
protobuf::AggregateFunction::Sum
} else if aggr_expr.downcast_ref::<ApproxDistinct>().is_some() {
protobuf::AggregateFunction::ApproxDistinct
} else if aggr_expr.downcast_ref::<ArrayAgg>().is_some() {
} else if let Some(array_agg) = aggr_expr.downcast_ref::<ArrayAgg>() {
ignore_nulls = array_agg.ignore_nulls();
protobuf::AggregateFunction::ArrayAgg
} else if aggr_expr.downcast_ref::<DistinctArrayAgg>().is_some() {
} else if let Some(array_agg) = aggr_expr.downcast_ref::<DistinctArrayAgg>() {
distinct = true;
ignore_nulls = array_agg.ignore_nulls();
protobuf::AggregateFunction::ArrayAgg
} else if aggr_expr.downcast_ref::<OrderSensitiveArrayAgg>().is_some() {
protobuf::AggregateFunction::ArrayAgg
Expand Down Expand Up @@ -343,7 +350,11 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
return not_impl_err!("Aggregate function not supported: {expr:?}");
};

Ok(AggrFn { inner, distinct })
Ok(AggrFn {
inner,
distinct,
ignore_nulls,
})
}

pub fn serialize_physical_sort_exprs<I>(
Expand Down

0 comments on commit 8a0ca9b

Please sign in to comment.