
Commit

feat: native types in DistinctCountAccumulator for primitive types (#8721)

* DistinctCountGroupsAccumulator

* test coverage

* clippy warnings

* count distinct for primitive types

* revert hashset to std

* fixed accumulator size estimation
korowa authored Jan 5, 2024
1 parent e5036d0 commit 561d941
Showing 3 changed files with 311 additions and 29 deletions.
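A minimal sketch of how the new accumulators behave (a hypothetical in-module test; NativeDistinctCountAccumulator is crate-private and is defined in the diff below): native values such as i32 are stored directly in a hash set instead of being boxed as ScalarValue, duplicates are counted once, and NULLs are ignored.

// Hypothetical test sketch, assuming the imports already present in count_distinct.rs.
#[test]
fn native_count_distinct_sketch() -> Result<()> {
    let mut acc = NativeDistinctCountAccumulator::<Int32Type>::new();
    let input: ArrayRef =
        Arc::new(arrow_array::Int32Array::from(vec![Some(1), Some(1), None, Some(2)]));
    acc.update_batch(&[input])?;
    // NULLs are ignored and duplicates are counted once
    assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2)));
    Ok(())
}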
298 changes: 290 additions & 8 deletions datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -15,21 +15,32 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use std::any::Any;
use std::cmp::Eq;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;

use crate::aggregate::utils::down_cast_any_ref;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

type DistinctScalarValues = ScalarValue;
@@ -60,6 +71,18 @@ impl DistinctCount {
}
}

macro_rules! native_distinct_count_accumulator {
($TYPE:ident) => {{
Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new()))
}};
}

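// Floating-point native types (f16/f32/f64) do not implement Eq/Hash, so they
// get a dedicated accumulator that wraps each value in `Hashable` (shared via
// `aggregate::utils`), which hashes the raw bytes of the value.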
macro_rules! float_distinct_count_accumulator {
($TYPE:ident) => {{
Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new()))
}};
}

impl AggregateExpr for DistinctCount {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
@@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount {
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
}))
use DataType::*;
use TimeUnit::*;

match &self.state_data_type {
Int8 => native_distinct_count_accumulator!(Int8Type),
Int16 => native_distinct_count_accumulator!(Int16Type),
Int32 => native_distinct_count_accumulator!(Int32Type),
Int64 => native_distinct_count_accumulator!(Int64Type),
UInt8 => native_distinct_count_accumulator!(UInt8Type),
UInt16 => native_distinct_count_accumulator!(UInt16Type),
UInt32 => native_distinct_count_accumulator!(UInt32Type),
UInt64 => native_distinct_count_accumulator!(UInt64Type),
Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type),
Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type),

Date32 => native_distinct_count_accumulator!(Date32Type),
Date64 => native_distinct_count_accumulator!(Date64Type),
Time32(Millisecond) => {
native_distinct_count_accumulator!(Time32MillisecondType)
}
Time32(Second) => {
native_distinct_count_accumulator!(Time32SecondType)
}
Time64(Microsecond) => {
native_distinct_count_accumulator!(Time64MicrosecondType)
}
Time64(Nanosecond) => {
native_distinct_count_accumulator!(Time64NanosecondType)
}
Timestamp(Microsecond, _) => {
native_distinct_count_accumulator!(TimestampMicrosecondType)
}
Timestamp(Millisecond, _) => {
native_distinct_count_accumulator!(TimestampMillisecondType)
}
Timestamp(Nanosecond, _) => {
native_distinct_count_accumulator!(TimestampNanosecondType)
}
Timestamp(Second, _) => {
native_distinct_count_accumulator!(TimestampSecondType)
}

Float16 => float_distinct_count_accumulator!(Float16Type),
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
})),
}
}

fn name(&self) -> &str {
@@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator {
}
}

#[derive(Debug)]
struct NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
}

impl<T> NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
T::Native: Eq + Hash,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(value);
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values.extend(list.values())
};
Ok(())
})
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}
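The size() method above estimates memory from the hash set's layout: the set is assumed to be at most 7/8 full, the bucket count is rounded up to a power of two, and each bucket carries one control byte (the FloatDistinctCountAccumulator below reuses the same estimate). A rough worked example, for a hypothetical accumulator over i32 holding 100 distinct values:

// Illustrative arithmetic only, mirroring the estimate in size() above.
fn size_estimate_example() -> usize {
    let len: usize = 100;
    // 100 * 8 / 7 = 114, rounded up to the next power of two = 128 buckets
    let estimated_buckets = (len * 8 / 7).next_power_of_two();
    // 4 bytes per i32 slot plus 1 control byte per bucket; the fixed
    // size_of_val terms are omitted here
    std::mem::size_of::<i32>() * estimated_buckets + estimated_buckets // 512 + 128 = 640
}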

#[derive(Debug)]
struct FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
values: HashSet<Hashable<T::Native>, RandomState>,
}

impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(Hashable(value));
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values
.extend(list.values().iter().map(|v| Hashable(*v)));
};
Ok(())
})
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}

#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
@@ -206,6 +452,8 @@ mod tests {
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;
@@ -367,6 +615,35 @@ mod tests {
}};
}

macro_rules! test_count_distinct_update_batch_bigint {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
let values: Vec<Option<$PRIM_TYPE>> = vec![
Some(i256::from(1)),
Some(i256::from(1)),
None,
Some(i256::from(3)),
Some(i256::from(2)),
None,
Some(i256::from(2)),
Some(i256::from(3)),
Some(i256::from(1)),
];

let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

let (states, result) = run_update_batch(&arrays)?;

let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
state_vec.sort();

assert_eq!(states.len(), 1);
assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]);
assert_eq!(result, ScalarValue::Int64(Some(3)));

Ok(())
}};
}

#[test]
fn count_distinct_update_batch_i8() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
@@ -417,6 +694,11 @@ mod tests {
test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64)
}

#[test]
fn count_distinct_update_batch_i256() -> Result<()> {
test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256)
}

#[test]
fn count_distinct_update_batch_boolean() -> Result<()> {
let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
22 changes: 2 additions & 20 deletions datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType};
use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_buffer::ArrowNativeType;
use std::collections::HashSet;

use crate::aggregate::sum::downcast_sum;
use crate::aggregate::utils::down_cast_any_ref;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::sum_return_type;
@@ -119,24 +119,6 @@ impl PartialEq<dyn Any> for DistinctSum {
}
}

/// A wrapper around a type to provide hash for floats
#[derive(Copy, Clone)]
struct Hashable<T>(T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.to_byte_slice().hash(state)
}
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
fn eq(&self, other: &Self) -> bool {
self.0.is_eq(other.0)
}
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

struct DistinctSumAccumulator<T: ArrowPrimitiveType> {
values: HashSet<Hashable<T::Native>, RandomState>,
data_type: DataType,