feat: Support Emit::First for SumDecimalGroupsAccumulator (#47)
viirya authored Feb 20, 2024
1 parent 7772d4c commit ca54845
Showing 2 changed files with 48 additions and 15 deletions.
42 changes: 27 additions & 15 deletions core/src/execution/datafusion/expressions/sum_decimal.rs
@@ -15,7 +15,10 @@
// specific language governing permissions and limitations
// under the License.

-use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer};
+use arrow::{
+    array::BooleanBufferBuilder,
+    buffer::{BooleanBuffer, NullBuffer},
+};
use arrow_array::{
    cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array,
};
@@ -314,6 +317,25 @@ fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) {
    }
}

+/// Build a boolean buffer from the state and reset the state, based on the emit_to
+/// strategy.
+fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer {
+    let bool_state: BooleanBuffer = state.finish();
+
+    match emit_to {
+        EmitTo::All => bool_state,
+        EmitTo::First(n) => {
+            // split off the first N values in bool_state
+            let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect();
+            // reset the existing seen buffer
+            for seen in bool_state.iter().skip(*n) {
+                state.append(seen);
+            }
+            first_n_bools
+        }
+    }
+}
+
impl GroupsAccumulator for SumDecimalGroupsAccumulator {
    fn update_batch(
        &mut self,
@@ -350,18 +372,13 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
-        // TODO: we do not support group-by ordering yet, but should fix here once it is supported
-        assert!(
-            matches!(emit_to, EmitTo::All),
-            "EmitTo::First is not supported"
-        );
        // For each group:
        // 1. if `is_empty` is true, it means either there is no value or all values for the group
        //    are null, in this case we'll return null
        // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In
        //    non-ANSI mode Spark returns null.
-        let nulls = self.is_not_null.finish();
-        let is_empty = self.is_empty.finish();
+        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
+        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
        let x = (!&is_empty).bitand(&nulls);

        let result = emit_to.take_needed(&mut self.sum);
@@ -372,19 +389,14 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
    }

    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
-        // TODO: we do not support group-by ordering yet, but should fix here once it is supported
-        assert!(
-            matches!(emit_to, EmitTo::All),
-            "EmitTo::First is not supported"
-        );
-        let nulls = self.is_not_null.finish();
+        let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
        let nulls = Some(NullBuffer::new(nulls));

        let sum = emit_to.take_needed(&mut self.sum);
        let sum = Decimal128Array::new(sum.into(), nulls.clone())
            .with_data_type(self.result_type.clone());

-        let is_empty = self.is_empty.finish();
+        let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
        let is_empty = BooleanArray::new(is_empty, None);

        Ok(vec![
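As a side note for readers unfamiliar with the arrow-rs builder API, the following minimal sketch (not part of the commit; it assumes the same `arrow` crate imported above and uses invented values and names) illustrates the split-and-reset pattern that `build_bool_state` applies for `EmitTo::First(n)`: `finish()` drains the builder, the first n values are emitted, and the remainder is appended back as the new state.

use arrow::array::BooleanBufferBuilder;
use arrow::buffer::BooleanBuffer;

fn main() {
    // Boolean state for four groups, e.g. the accumulator's `is_not_null` flags.
    let mut state = BooleanBufferBuilder::new(4);
    state.append_slice(&[true, false, true, true]);

    // Emulate EmitTo::First(2).
    let n = 2;
    let all: BooleanBuffer = state.finish(); // drains and resets the builder
    let first_n: BooleanBuffer = all.iter().take(n).collect();
    for seen in all.iter().skip(n) {
        state.append(seen); // trailing groups become the new state
    }

    assert_eq!(first_n.iter().collect::<Vec<_>>(), vec![true, false]);
    assert_eq!(state.len(), 2); // two groups remain in the accumulator
}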
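Similarly, a small sketch (also not part of the commit, same assumptions) of the validity mask that `evaluate` derives from the two emitted buffers via `(!&is_empty).bitand(&nulls)`: a group's sum is non-null only when the group saw at least one value and its sum did not overflow.

use std::ops::BitAnd;

use arrow::buffer::BooleanBuffer;

fn main() {
    // Emitted flags for two groups: group 0 saw values and did not overflow,
    // group 1 saw values but overflowed, so its `is_not_null` flag is false.
    let is_not_null: BooleanBuffer = [true, false].into_iter().collect();
    let is_empty: BooleanBuffer = [false, false].into_iter().collect();

    // Mirrors `let x = (!&is_empty).bitand(&nulls);` in `evaluate`.
    let validity = (!&is_empty).bitand(&is_not_null);

    // Only group 0 yields a non-null sum; group 1 is null (overflow in
    // non-ANSI mode), and an empty group would be null as well.
    assert_eq!(validity.iter().collect::<Vec<_>>(), vec![true, false]);
}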
21 changes: 21 additions & 0 deletions CometAggregateSuite.scala
@@ -25,8 +25,11 @@ import org.apache.hadoop.fs.Path
import org.apache.parquet.example.data.simple.SimpleGroup
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
+import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.sum
+import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
@@ -36,6 +39,24 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
 */
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {

+  test("SUM decimal supports emit.first") {
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> EliminateSorts.ruleName,
+      CometConf.COMET_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test")
+          makeParquetFile(path, 10000, 10, dictionaryEnabled)
+          withParquetTable(path.toUri.toString, "tbl") {
+            checkSparkAnswer(sql("SELECT * FROM tbl").sort("_g1").groupBy("_g1").agg(sum("_8")))
+          }
+        }
+      }
+    }
+  }
+
  test("Fix NPE in partial decimal sum") {
    val table = "tbl"
    withTable(table) {