From 2e7892b8bda178621f34021af248a513d8b8bef2 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 20 Jun 2026 20:28:12 +0800 Subject: [PATCH 1/3] refactor: use EmitTo for aggregate state output --- .../src/aggregates/hash_aggregate.rs | 4 +- .../src/aggregates/hash_table.rs | 206 +++++++++--------- 2 files changed, 111 insertions(+), 99 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 59ee09912f621..7372fabf4108c 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -131,7 +131,7 @@ pub(crate) struct PartialHashAggregateStream { group_values_soft_limit: Option, /// Tracks the high-level stream lifecycle. The hash table owns the lower-level - /// state for materializing and slicing output batches. + /// state for emitting output batches. state: Option, } @@ -203,7 +203,7 @@ pub(crate) struct FinalHashAggregateStream { group_values_soft_limit: Option, /// Tracks the high-level stream lifecycle. The hash table owns the lower-level - /// state for materializing and slicing output batches. + /// state for emitting output batches. state: Option, } diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs index e6b2fa22c137f..ba5c423ce698e 100644 --- a/datafusion/physical-plan/src/aggregates/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -46,7 +46,8 @@ pub(super) struct Final; /// Grouped hash table shared by the partial and final paths. /// /// While building, it consumes input batches and updates group / accumulator -/// state. While outputting, it incrementally output the materialized batches. +/// state. While outputting, it incrementally drains that state into output +/// batches. /// /// # Marker Type /// `AggrMode` selects the aggregate semantics. @@ -68,7 +69,7 @@ pub(super) struct AggregateHashTable { /// Output schema: group columns followed by aggregate state or final values. output_schema: SchemaRef, - /// Maximum rows per emitted output batch. + /// Maximum rows per emitted output batch, from config `batch_size`. batch_size: usize, /// Lifecycle-specific state: building stage / outputting stage @@ -144,10 +145,7 @@ struct BuildingHashTableState { enum AggregateHashTableState { Building(BuildingHashTableState), - Outputting { - output_batch: Option, - output_batch_offset: usize, - }, + Outputting(BuildingHashTableState), Done, } @@ -281,6 +279,10 @@ impl AggregateHashTable { batch_size: usize, filters: Vec>>, ) -> Result { + if batch_size == 0 { + return internal_err!("AggregateHashTable requires config batch_size >= 1"); + } + let input_schema = agg.input().schema(); let aggregate_arguments = aggregate_expressions( &agg.aggr_expr, @@ -349,7 +351,8 @@ impl AggregateHashTable { pub(super) fn memory_size(&self) -> usize { match &self.state { - AggregateHashTableState::Building(state) => { + AggregateHashTableState::Building(state) + | AggregateHashTableState::Outputting(state) => { let acc = state .accumulators .iter() @@ -359,9 +362,6 @@ impl AggregateHashTable { acc + state.group_values.size() + state.batch_group_indices.allocated_size() } - AggregateHashTableState::Outputting { output_batch, .. } => { - output_batch_memory_size(output_batch) - } AggregateHashTableState::Done => 0, } } @@ -379,52 +379,24 @@ impl AggregateHashTable { matches!(self.state, AggregateHashTableState::Done) } - fn set_output_batch(&mut self, output_batch: Option) { - self.state = AggregateHashTableState::Outputting { - output_batch, - output_batch_offset: 0, + fn start_outputting(&mut self) { + let AggregateHashTableState::Building(mut state) = + std::mem::replace(&mut self.state, AggregateHashTableState::Done) + else { + unreachable!("hash aggregate table is not building") }; - } - pub(super) fn next_output_batch(&mut self) -> Result> { - match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { - AggregateHashTableState::Outputting { - output_batch, - mut output_batch_offset, - } => { - let Some(batch) = output_batch.as_ref() else { - return Ok(None); - }; - - let num_rows = batch.num_rows(); - if output_batch_offset >= num_rows { - return Ok(None); - } - - debug_assert!(self.batch_size > 0); - let output_len = - self.batch_size.max(1).min(num_rows - output_batch_offset); - let output = batch.slice(output_batch_offset, output_len); - output_batch_offset += output_len; - - if output_batch_offset == num_rows { - self.state = AggregateHashTableState::Done; - } else { - self.state = AggregateHashTableState::Outputting { - output_batch, - output_batch_offset, - }; - } + state.batch_group_indices = Vec::new(); + self.state = AggregateHashTableState::Outputting(state); + } +} - debug_assert!(output.num_rows() > 0); - debug_assert!(output.num_rows() <= self.batch_size.max(1)); - Ok(Some(output)) - } - _ => { - self.state = AggregateHashTableState::Done; - internal_err!("next_output_batch must be called in the outputting state") - } - } +fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo { + debug_assert!(batch_size > 0); + if group_count <= batch_size { + EmitTo::All + } else { + EmitTo::First(batch_size) } } @@ -444,6 +416,47 @@ impl AggregateHashTable { ) } + /// Emits the next batch of aggregated group keys and aggregate states. + /// + /// The output batch size is determined by `self.batch_size`. + /// + /// Returns `Some(batch)` for each emitted batch, `None` when output is + /// exhausted, and an internal error if polled in the `Building` state. + pub(super) fn next_output_batch(&mut self) -> Result> { + let output_schema = Arc::clone(&self.output_schema); + let batch_size = self.batch_size; + match &mut self.state { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + self.state = AggregateHashTableState::Done; + return Ok(None); + } + + let emit_to = + emit_to_for_batch_size(batch_size, state.group_values.len()); + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(emit_to)?); + } + let done = state.group_values.is_empty(); + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + if done { + self.state = AggregateHashTableState::Done; + } + Ok(Some(batch)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } + pub(super) fn can_skip_aggregation(&self) -> bool { self.state .building() @@ -507,25 +520,7 @@ impl AggregateHashTable { pub(super) fn start_output(&mut self) -> Result<()> { self.init_empty_grouping_sets()?; - let state = self.state.building_mut(); - - let output_batch = if state.group_values.is_empty() { - None - } else { - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(EmitTo::All)?; - - for acc in state.accumulators.iter_mut() { - output.extend(acc.state(EmitTo::All)?); - } - - let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; - debug_assert!(batch.num_rows() > 0); - drop(timer); - Some(batch) - }; - - self.set_output_batch(output_batch); + self.start_outputting(); Ok(()) } @@ -648,6 +643,47 @@ impl AggregateHashTable { ) } + /// Emits the next batch of aggregated group keys and aggregate states. + /// + /// The output batch size is determined by `self.batch_size`. + /// + /// Returns `Some(batch)` for each emitted batch, `None` when output is + /// exhausted, and an internal error if polled in the `Building` state. + pub(super) fn next_output_batch(&mut self) -> Result> { + let output_schema = Arc::clone(&self.output_schema); + let batch_size = self.batch_size; + match &mut self.state { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + self.state = AggregateHashTableState::Done; + return Ok(None); + } + + let emit_to = + emit_to_for_batch_size(batch_size, state.group_values.len()); + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.push(acc.evaluate_final(emit_to)?); + } + let done = state.group_values.is_empty(); + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + if done { + self.state = AggregateHashTableState::Done; + } + Ok(Some(batch)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } + pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { let evaluated_batch = self.evaluate_batch(batch)?; let state = self.state.building_mut(); @@ -674,31 +710,7 @@ impl AggregateHashTable { } pub(super) fn start_output(&mut self) -> Result<()> { - let state = self.state.building_mut(); - let output_batch = if state.group_values.is_empty() { - None - } else { - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(EmitTo::All)?; - - for acc in state.accumulators.iter_mut() { - output.push(acc.evaluate_final(EmitTo::All)?); - } - - let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; - debug_assert!(batch.num_rows() > 0); - drop(timer); - Some(batch) - }; - - self.set_output_batch(output_batch); + self.start_outputting(); Ok(()) } } - -fn output_batch_memory_size(output_batch: &Option) -> usize { - output_batch - .as_ref() - .map(RecordBatch::get_array_memory_size) - .unwrap_or_default() -} From d96b68c5d50ea231b59b26e10db5e57af75c6578 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 20 Jun 2026 20:53:22 +0800 Subject: [PATCH 2/3] split hash_table.rs into small files --- .../aggregates/aggregate_hash_table/common.rs | 405 ++++++++++ .../aggregate_hash_table/final_table.rs | 121 +++ .../aggregates/aggregate_hash_table/mod.rs | 22 + .../aggregate_hash_table/partial_table.rs | 269 +++++++ .../src/aggregates/hash_aggregate.rs | 2 +- .../src/aggregates/hash_table.rs | 716 ------------------ .../physical-plan/src/aggregates/mod.rs | 2 +- 7 files changed, 819 insertions(+), 718 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs create mode 100644 datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs delete mode 100644 datafusion/physical-plan/src/aggregates/hash_table.rs diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs new file mode 100644 index 0000000000000..0b1cb422d345f --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -0,0 +1,405 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, new_null_array}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; + +use crate::PhysicalExpr; +use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use crate::aggregates::order::GroupOrdering; +use crate::aggregates::row_hash::create_group_accumulator; +use crate::aggregates::{ + AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, +}; + +/// Marker for raw rows -> partial state aggregation. +pub(in crate::aggregates) struct Partial; +/// Marker for raw rows -> partial state conversion without aggregation. +pub(in crate::aggregates) struct PartialSkip; +/// Marker for partial state -> final value aggregation. +pub(in crate::aggregates) struct Final; + +/// Grouped hash table shared by the partial and final paths. +/// +/// While building, it consumes input batches and updates group / accumulator +/// state. While outputting, it incrementally drains that state into output +/// batches. +/// +/// # Marker Type +/// `AggrMode` selects the aggregate semantics. +/// +/// e.g. `AggregateHashTable::::new(...)` creates an aggregate hash table +/// for the partial hash aggregate stage, the input schema is raw rows and output +/// schema is intermediate states. +/// +/// It is a zero-sized compile-time marker, so each stage keeps its update logic +/// in a separate impl block, to make the behavior difference explicit. +pub(in crate::aggregates) struct AggregateHashTable { + /// Grouping and accumulator-specific timing metrics. + pub(super) group_by_metrics: GroupByMetrics, + + /// Raw input schema, used to evaluate expressions and synthesize empty + /// grouping-set rows. + pub(super) input_schema: SchemaRef, + + /// Output schema: group columns followed by aggregate state or final values. + pub(super) output_schema: SchemaRef, + + /// Maximum rows per emitted output batch, from config `batch_size`. + pub(super) batch_size: usize, + + /// Lifecycle-specific state: building stage / outputting stage. + pub(super) state: AggregateHashTableState, + + pub(super) _mode: PhantomData, +} + +pub(super) struct HashAggregateAccumulator { + /// Aggregate expression used to create a fresh accumulator for related + /// hash tables, such as the partial-skip table. + aggregate_expr: Arc, + + /// Arguments to pass to this accumulator. + /// + /// Example: `CORR(x, y)` stores two expressions here, while `SUM(x)` stores one. + arguments: Vec>, + + /// Optional `FILTER` expression for this accumulator. + /// + /// Example: `SUM(x) FILTER (WHERE x > 10)` stores the `x > 10` predicate. + filter: Option>, + + /// Accumulator state for all groups for one aggregate expression. + accumulator: Box, +} + +pub(super) struct EvaluatedHashAggregateAccumulator { + pub(super) arguments: Vec, + pub(super) filter: Option, +} + +/// Evaluated all group by keys and accumulator args. +/// +/// e.g., `select k+1, sum(v*v) from t group by (k+1)`, this function evaluates +/// `k+1`, `v*v` +pub(super) struct EvaluatedAggregateBatch { + /// One entry per grouping set; each entry contains all evaluated group key + /// arrays for the current input batch. + pub(super) grouping_set_args: Vec>, + + /// Evaluated arguments and filters, one entry per aggregate expression. + pub(super) accumulator_args: Vec, +} + +/// Hash table state while grouped aggregation is consuming input. +/// +/// This owns the coupled state for: +/// - evaluating group keys, +/// - interning each distinct group, +/// - mapping each input row to its group index, +/// - evaluating aggregate inputs, +/// - updating per-group accumulator state. +pub(super) struct BuildingHashTableState { + /// GROUP BY expressions evaluated for each input batch. + pub(super) group_by: Arc, + + /// Interned group keys. Accumulator state is stored separately by group index. + pub(super) group_values: Box, + + /// Group index for each row in the current input batch. + /// + /// Each value indexes into `group_values`, and the same index is used by every + /// accumulator to update that group's aggregate state. + pub(super) batch_group_indices: Vec, + + /// One item per aggregate expression. + /// + /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input + /// expressions, optional filter, and accumulator state for all groups. + pub(super) accumulators: Vec, +} + +pub(super) enum AggregateHashTableState { + Building(BuildingHashTableState), + Outputting(BuildingHashTableState), + Done, +} + +impl HashAggregateAccumulator { + fn new( + aggregate_expr: Arc, + arguments: Vec>, + filter: Option>, + accumulator: Box, + ) -> Self { + Self { + aggregate_expr, + arguments, + filter, + accumulator, + } + } + + pub(super) fn empty_like(&self) -> Result { + let accumulator = create_group_accumulator(&self.aggregate_expr)?; + Ok(Self::new( + Arc::clone(&self.aggregate_expr), + self.arguments.clone(), + self.filter.clone(), + accumulator, + )) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arguments = self + .arguments + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .collect::>()?; + + let filter = self + .filter + .as_ref() + .map(|filter| { + filter + .evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .transpose()?; + + Ok(EvaluatedHashAggregateAccumulator { arguments, filter }) + } + + pub(super) fn update_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + let filter = values.filter.as_ref().map(|filter| filter.as_boolean()); + self.accumulator.update_batch( + &values.arguments, + group_indices, + filter, + total_num_groups, + ) + } + + pub(super) fn merge_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + debug_assert!(values.filter.is_none()); + self.accumulator + .merge_batch(&values.arguments, group_indices, total_num_groups) + } + + pub(super) fn evaluate_final(&mut self, emit_to: EmitTo) -> Result { + self.accumulator.evaluate(emit_to) + } + + pub(super) fn state(&mut self, emit_to: EmitTo) -> Result> { + self.accumulator.state(emit_to) + } + + pub(super) fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state() + } + + pub(super) fn convert_to_state( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + ) -> Result> { + let opt_filter = values.filter.as_ref().map(|filter| filter.as_boolean()); + self.accumulator + .convert_to_state(&values.arguments, opt_filter) + } + + pub(super) fn null_arguments( + &self, + input_schema: &SchemaRef, + ) -> Result> { + self.arguments + .iter() + .map(|expr| { + let data_type = expr.data_type(input_schema)?; + Ok(new_null_array(&data_type, 1)) + }) + .collect() + } +} + +impl AggregateHashTableState { + pub(super) fn building(&self) -> &BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } + + pub(super) fn building_mut(&mut self) -> &mut BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } +} + +impl AggregateHashTable { + pub(super) fn new_with_filters( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + filters: Vec>>, + ) -> Result { + if batch_size == 0 { + return internal_err!("AggregateHashTable requires config batch_size >= 1"); + } + + let input_schema = agg.input().schema(); + let aggregate_arguments = aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg.group_by.num_group_exprs(), + )?; + let accumulators: Vec<_> = agg + .aggr_expr + .iter() + .zip(aggregate_arguments) + .zip(filters) + .map(|((agg_expr, arguments), filter)| { + let accumulator = create_group_accumulator(agg_expr)?; + Ok(HashAggregateAccumulator::new( + Arc::clone(agg_expr), + arguments, + filter, + accumulator, + )) + }) + .collect::>()?; + + let group_schema = agg.group_by.group_schema(&input_schema)?; + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + + Ok(Self { + group_by_metrics: GroupByMetrics::new(&agg.metrics, partition), + input_schema, + output_schema, + batch_size, + state: AggregateHashTableState::Building(BuildingHashTableState { + group_by: Arc::clone(&agg.group_by), + group_values, + batch_group_indices: Default::default(), + accumulators, + }), + _mode: PhantomData, + }) + } + + /// See comments in [`EvaluatedAggregateBatch`] + pub(super) fn evaluate_batch( + &self, + batch: &RecordBatch, + ) -> Result { + let state = self.state.building(); + let timer = self.group_by_metrics.time_calculating_group_ids.timer(); + // outer vec: one per each grouping set + // inner vec: all group by exprs for the current grouping set + let grouping_set_args = evaluate_group_by(&state.group_by, batch)?; + drop(timer); + + let timer = self.group_by_metrics.aggregate_arguments_time.timer(); + // The evaluated args for each accumulator + let accumulator_args = self + .state + .building() + .accumulators + .iter() + .map(|acc| acc.evaluate(batch)) + .collect::>>()?; + drop(timer); + + Ok(EvaluatedAggregateBatch { + grouping_set_args, + accumulator_args, + }) + } + + pub(in crate::aggregates) fn memory_size(&self) -> usize { + match &self.state { + AggregateHashTableState::Building(state) + | AggregateHashTableState::Outputting(state) => { + let acc = state + .accumulators + .iter() + .map(|acc| acc.accumulator.size()) + .sum::(); + + acc + state.group_values.size() + + state.batch_group_indices.allocated_size() + } + AggregateHashTableState::Done => 0, + } + } + + /// Returns the number of distinct groups accumulated so far. + pub(in crate::aggregates) fn building_group_count(&self) -> usize { + self.state.building().group_values.len() + } + + pub(in crate::aggregates) fn is_building(&self) -> bool { + matches!(self.state, AggregateHashTableState::Building(_)) + } + + pub(in crate::aggregates) fn is_done(&self) -> bool { + matches!(self.state, AggregateHashTableState::Done) + } + + pub(super) fn start_outputting(&mut self) { + let AggregateHashTableState::Building(mut state) = + std::mem::replace(&mut self.state, AggregateHashTableState::Done) + else { + unreachable!("hash aggregate table is not building") + }; + + state.batch_group_indices = Vec::new(); + self.state = AggregateHashTableState::Outputting(state); + } +} + +pub(super) fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo { + debug_assert!(batch_size > 0); + if group_count <= batch_size { + EmitTo::All + } else { + EmitTo::First(batch_size) + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs new file mode 100644 index 0000000000000..e318d22c1dc5e --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; + +use crate::aggregates::AggregateExec; + +use super::common::{ + AggregateHashTable, AggregateHashTableState, Final, emit_to_for_batch_size, +}; + +impl AggregateHashTable { + pub(in crate::aggregates) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + vec![None; agg.aggr_expr.len()], + ) + } + + /// Emits the next batch of aggregated group keys and final aggregate values. + /// + /// The output batch size is determined by `self.batch_size`. + /// + /// Returns `Some(batch)` for each emitted batch, `None` when output is + /// exhausted, and an internal error if polled in the `Building` state. + pub(in crate::aggregates) fn next_output_batch( + &mut self, + ) -> Result> { + let output_schema = Arc::clone(&self.output_schema); + let batch_size = self.batch_size; + match &mut self.state { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + self.state = AggregateHashTableState::Done; + return Ok(None); + } + + let emit_to = + emit_to_for_batch_size(batch_size, state.group_values.len()); + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.push(acc.evaluate_final(emit_to)?); + } + let done = state.group_values.is_empty(); + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + if done { + self.state = AggregateHashTableState::Done; + } + Ok(Some(batch)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } + + pub(in crate::aggregates) fn aggregate_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.merge_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> { + self.start_outputting(); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs new file mode 100644 index 0000000000000..9879ed82951c5 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; +mod final_table; +mod partial_table; + +pub(super) use common::{AggregateHashTable, Final, Partial, PartialSkip}; diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs new file mode 100644 index 0000000000000..0afea3a33b43c --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, new_null_array}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; + +use crate::aggregates::group_values::new_group_values; +use crate::aggregates::order::GroupOrdering; +use crate::aggregates::{AggregateExec, group_id_array, max_duplicate_ordinal}; + +use super::common::{ + AggregateHashTable, AggregateHashTableState, BuildingHashTableState, + EvaluatedHashAggregateAccumulator, HashAggregateAccumulator, Partial, PartialSkip, + emit_to_for_batch_size, +}; + +impl AggregateHashTable { + pub(in crate::aggregates) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + agg.filter_expr.iter().cloned().collect(), + ) + } + + /// Emits the next batch of aggregated group keys and aggregate states. + /// + /// The output batch size is determined by `self.batch_size`. + /// + /// Returns `Some(batch)` for each emitted batch, `None` when output is + /// exhausted, and an internal error if polled in the `Building` state. + pub(in crate::aggregates) fn next_output_batch( + &mut self, + ) -> Result> { + let output_schema = Arc::clone(&self.output_schema); + let batch_size = self.batch_size; + match &mut self.state { + AggregateHashTableState::Outputting(state) => { + if state.group_values.is_empty() { + self.state = AggregateHashTableState::Done; + return Ok(None); + } + + let emit_to = + emit_to_for_batch_size(batch_size, state.group_values.len()); + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(emit_to)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(emit_to)?); + } + let done = state.group_values.is_empty(); + drop(timer); + + let batch = RecordBatch::try_new(output_schema, output)?; + debug_assert!(batch.num_rows() > 0); + if done { + self.state = AggregateHashTableState::Done; + } + Ok(Some(batch)) + } + AggregateHashTableState::Done => Ok(None), + AggregateHashTableState::Building(_) => { + internal_err!("next_output_batch must be called in the outputting state") + } + } + } + + pub(in crate::aggregates) fn can_skip_aggregation(&self) -> bool { + self.state + .building() + .accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) + } + + /// In skip-partial-aggregation optimization, when a decision has made to skip + /// partial stage, build a typed hash table only for aggregation state conversion + /// row-by-row. + pub(in crate::aggregates) fn partial_skip_table( + &self, + ) -> Result> { + let state = self.state.building(); + let group_schema = state.group_by.group_schema(&self.input_schema)?; + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + let accumulators = state + .accumulators + .iter() + .map(HashAggregateAccumulator::empty_like) + .collect::>>()?; + + Ok(AggregateHashTable { + group_by_metrics: self.group_by_metrics.clone(), + input_schema: Arc::clone(&self.input_schema), + output_schema: Arc::clone(&self.output_schema), + batch_size: self.batch_size, + state: AggregateHashTableState::Building(BuildingHashTableState { + group_by: Arc::clone(&state.group_by), + group_values, + batch_group_indices: Default::default(), + accumulators, + }), + _mode: PhantomData, + }) + } + + pub(in crate::aggregates) fn aggregate_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.update_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(in crate::aggregates) fn start_output(&mut self) -> Result<()> { + self.init_empty_grouping_sets()?; + self.start_outputting(); + Ok(()) + } + + /// Creates the required empty grouping-set rows when the input is empty. + /// + /// For example, this query must still produce one grand-total group even if + /// `t` has no rows: + /// + /// ```sql + /// SELECT COUNT(v) + /// FROM t + /// GROUP BY GROUPING SETS (()); + /// ``` + /// + /// The synthetic row is filtered out before accumulator update so aggregates + /// see the same state they would see for an empty input, rather than a real + /// null-valued row. + fn init_empty_grouping_sets(&mut self) -> Result<()> { + let state = self.state.building_mut(); + if !state.group_by.has_grouping_set() || !state.group_values.is_empty() { + return Ok(()); + } + + let max_ordinal = max_duplicate_ordinal(state.group_by.groups()); + let mut ordinals: HashMap<&[bool], usize> = HashMap::new(); + let group_schema = state.group_by.group_schema(&self.input_schema)?; + let n_expr = state.group_by.expr().len(); + let mut any_interned = false; + + for group in state.group_by.groups() { + let ordinal = { + let entry = ordinals.entry(group.as_slice()).or_insert(0); + let ordinal = *entry; + *entry += 1; + ordinal + }; + + if !group.iter().all(|&is_null| is_null) { + continue; + } + + let mut cols: Vec = group_schema + .fields() + .iter() + .take(n_expr) + .map(|field| new_null_array(field.data_type(), 1)) + .collect(); + cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); + + state + .group_values + .intern(&cols, &mut state.batch_group_indices)?; + any_interned = true; + } + + if any_interned { + let total_groups = state.group_values.len(); + let false_filter = BooleanArray::from(vec![false]); + for acc in state.accumulators.iter_mut() { + let null_args = acc.null_arguments(&self.input_schema)?; + let values = EvaluatedHashAggregateAccumulator { + arguments: null_args, + filter: Some(Arc::new(false_filter.clone())), + }; + acc.update_batch(&values, &[0], total_groups)?; + } + } + + Ok(()) + } +} + +impl AggregateHashTable { + pub(in crate::aggregates) fn convert_batch_to_state( + &mut self, + batch: &RecordBatch, + ) -> Result { + let evaluated_batch = self.evaluate_batch(batch)?; + + assert_eq_or_internal_err!( + evaluated_batch.grouping_set_args.len(), + 1, + "group_values expected to have single element" + ); + let mut output = evaluated_batch + .grouping_set_args + .into_iter() + .next() + .unwrap_or_default(); + + let state = self.state.building_mut(); + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + output.extend(acc.convert_to_state(values)?); + } + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + output, + )?) + } +} diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 7372fabf4108c..076d56a3ca87d 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -37,7 +37,7 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; use super::AggregateExec; -use super::hash_table::{AggregateHashTable, Final, Partial, PartialSkip}; +use super::aggregate_hash_table::{AggregateHashTable, Final, Partial, PartialSkip}; use super::skip_partial::SkipAggregationProbe; use crate::metrics::{ BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput, SpillMetrics, diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs deleted file mode 100644 index ba5c423ce698e..0000000000000 --- a/datafusion/physical-plan/src/aggregates/hash_table.rs +++ /dev/null @@ -1,716 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashMap; -use std::marker::PhantomData; -use std::sync::Arc; - -use arrow::array::{ArrayRef, AsArray, BooleanArray, new_null_array}; -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; -use datafusion_execution::memory_pool::proxy::VecAllocExt; -use datafusion_expr::{EmitTo, GroupsAccumulator}; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; - -use super::group_values::{GroupByMetrics, GroupValues, new_group_values}; -use super::order::GroupOrdering; -use super::row_hash::create_group_accumulator; -use super::{ - AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, - group_id_array, max_duplicate_ordinal, -}; -use crate::PhysicalExpr; - -/// Marker for raw rows -> partial state aggregation. -pub(super) struct Partial; -/// Marker for raw rows -> partial state conversion without aggregation. -pub(super) struct PartialSkip; -/// Marker for partial state -> final value aggregation. -pub(super) struct Final; - -/// Grouped hash table shared by the partial and final paths. -/// -/// While building, it consumes input batches and updates group / accumulator -/// state. While outputting, it incrementally drains that state into output -/// batches. -/// -/// # Marker Type -/// `AggrMode` selects the aggregate semantics. -/// -/// e.g. `AggregateHashTable::::new(...)` creates an aggregate hash table -/// for the partial hash aggregate stage, the input schema is raw rows and output -/// schema is intermediate states. -/// -/// It is a zero-sized compile-time marker, so each stage keeps its update logic -/// in a separate impl block, to make the behavior difference explicit. -pub(super) struct AggregateHashTable { - /// Grouping and accumulator-specific timing metrics. - group_by_metrics: GroupByMetrics, - - /// Raw input schema, used to evaluate expressions and synthesize empty - /// grouping-set rows. - input_schema: SchemaRef, - - /// Output schema: group columns followed by aggregate state or final values. - output_schema: SchemaRef, - - /// Maximum rows per emitted output batch, from config `batch_size`. - batch_size: usize, - - /// Lifecycle-specific state: building stage / outputting stage - state: AggregateHashTableState, - - _mode: PhantomData, -} - -struct HashAggregateAccumulator { - /// Aggregate expression used to create a fresh accumulator for related - /// hash tables, such as the partial-skip table. - aggregate_expr: Arc, - - /// Arguments to pass to this accumulator. - /// - /// Example: `CORR(x, y)` stores two expressions here, while `SUM(x)` stores one. - arguments: Vec>, - - /// Optional `FILTER` expression for this accumulator. - /// - /// Example: `SUM(x) FILTER (WHERE x > 10)` stores the `x > 10` predicate. - filter: Option>, - - /// Accumulator state for all groups for one aggregate expression. - accumulator: Box, -} - -struct EvaluatedHashAggregateAccumulator { - arguments: Vec, - filter: Option, -} - -/// Evaluated all group by keys and accumulator args. -/// -/// e.g., `select k+1, sum(v*v) from t group by (k+1)`, this function evaluates -/// `k+1`, `v*v` -struct EvaluatedAggregateBatch { - /// One entry per grouping set; each entry contains all evaluated group key - /// arrays for the current input batch. - grouping_set_args: Vec>, - - /// Evaluated arguments and filters, one entry per aggregate expression. - accumulator_args: Vec, -} - -/// Hash table state while grouped aggregation is consuming input. -/// -/// This owns the coupled state for: -/// - evaluating group keys, -/// - interning each distinct group, -/// - mapping each input row to its group index, -/// - evaluating aggregate inputs, -/// - updating per-group accumulator state. -struct BuildingHashTableState { - /// GROUP BY expressions evaluated for each input batch. - group_by: Arc, - - /// Interned group keys. Accumulator state is stored separately by group index. - group_values: Box, - - /// Group index for each row in the current input batch. - /// - /// Each value indexes into `group_values`, and the same index is used by every - /// accumulator to update that group's aggregate state. - batch_group_indices: Vec, - - /// One item per aggregate expression. - /// - /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input - /// expressions, optional filter, and accumulator state for all groups. - accumulators: Vec, -} - -enum AggregateHashTableState { - Building(BuildingHashTableState), - Outputting(BuildingHashTableState), - Done, -} - -impl HashAggregateAccumulator { - fn new( - aggregate_expr: Arc, - arguments: Vec>, - filter: Option>, - accumulator: Box, - ) -> Self { - Self { - aggregate_expr, - arguments, - filter, - accumulator, - } - } - - fn empty_like(&self) -> Result { - let accumulator = create_group_accumulator(&self.aggregate_expr)?; - Ok(Self::new( - Arc::clone(&self.aggregate_expr), - self.arguments.clone(), - self.filter.clone(), - accumulator, - )) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - let arguments = self - .arguments - .iter() - .map(|expr| { - expr.evaluate(batch) - .and_then(|value| value.into_array(batch.num_rows())) - }) - .collect::>()?; - - let filter = self - .filter - .as_ref() - .map(|filter| { - filter - .evaluate(batch) - .and_then(|value| value.into_array(batch.num_rows())) - }) - .transpose()?; - - Ok(EvaluatedHashAggregateAccumulator { arguments, filter }) - } - - fn update_batch( - &mut self, - values: &EvaluatedHashAggregateAccumulator, - group_indices: &[usize], - total_num_groups: usize, - ) -> Result<()> { - let filter = values.filter.as_ref().map(|filter| filter.as_boolean()); - self.accumulator.update_batch( - &values.arguments, - group_indices, - filter, - total_num_groups, - ) - } - - fn merge_batch( - &mut self, - values: &EvaluatedHashAggregateAccumulator, - group_indices: &[usize], - total_num_groups: usize, - ) -> Result<()> { - debug_assert!(values.filter.is_none()); - self.accumulator - .merge_batch(&values.arguments, group_indices, total_num_groups) - } - - fn evaluate_final(&mut self, emit_to: EmitTo) -> Result { - self.accumulator.evaluate(emit_to) - } - - fn state(&mut self, emit_to: EmitTo) -> Result> { - self.accumulator.state(emit_to) - } - - fn supports_convert_to_state(&self) -> bool { - self.accumulator.supports_convert_to_state() - } - - fn convert_to_state( - &mut self, - values: &EvaluatedHashAggregateAccumulator, - ) -> Result> { - let opt_filter = values.filter.as_ref().map(|filter| filter.as_boolean()); - self.accumulator - .convert_to_state(&values.arguments, opt_filter) - } - - fn null_arguments(&self, input_schema: &SchemaRef) -> Result> { - self.arguments - .iter() - .map(|expr| { - let data_type = expr.data_type(input_schema)?; - Ok(new_null_array(&data_type, 1)) - }) - .collect() - } -} - -impl AggregateHashTableState { - fn building(&self) -> &BuildingHashTableState { - let Self::Building(state) = self else { - unreachable!("hash aggregate table is not building") - }; - state - } - - fn building_mut(&mut self) -> &mut BuildingHashTableState { - let Self::Building(state) = self else { - unreachable!("hash aggregate table is not building") - }; - state - } -} - -impl AggregateHashTable { - fn new_with_filters( - agg: &AggregateExec, - partition: usize, - output_schema: SchemaRef, - batch_size: usize, - filters: Vec>>, - ) -> Result { - if batch_size == 0 { - return internal_err!("AggregateHashTable requires config batch_size >= 1"); - } - - let input_schema = agg.input().schema(); - let aggregate_arguments = aggregate_expressions( - &agg.aggr_expr, - &agg.mode, - agg.group_by.num_group_exprs(), - )?; - let accumulators: Vec<_> = agg - .aggr_expr - .iter() - .zip(aggregate_arguments) - .zip(filters) - .map(|((agg_expr, arguments), filter)| { - let accumulator = create_group_accumulator(agg_expr)?; - Ok(HashAggregateAccumulator::new( - Arc::clone(agg_expr), - arguments, - filter, - accumulator, - )) - }) - .collect::>()?; - - let group_schema = agg.group_by.group_schema(&input_schema)?; - let group_values = new_group_values(group_schema, &GroupOrdering::None)?; - - Ok(Self { - group_by_metrics: GroupByMetrics::new(&agg.metrics, partition), - input_schema, - output_schema, - batch_size, - state: AggregateHashTableState::Building(BuildingHashTableState { - group_by: Arc::clone(&agg.group_by), - group_values, - batch_group_indices: Default::default(), - accumulators, - }), - _mode: PhantomData, - }) - } - - /// See comments in [`EvaluatedAggregateBatch`] - fn evaluate_batch(&self, batch: &RecordBatch) -> Result { - let state = self.state.building(); - let timer = self.group_by_metrics.time_calculating_group_ids.timer(); - // outer vec: one per each grouping set - // inner vec: all group by exprs for the current grouping set - let grouping_set_args = evaluate_group_by(&state.group_by, batch)?; - drop(timer); - - let timer = self.group_by_metrics.aggregate_arguments_time.timer(); - // The evaluated args for each accumulator - let accumulator_args = self - .state - .building() - .accumulators - .iter() - .map(|acc| acc.evaluate(batch)) - .collect::>>()?; - drop(timer); - - Ok(EvaluatedAggregateBatch { - grouping_set_args, - accumulator_args, - }) - } - - pub(super) fn memory_size(&self) -> usize { - match &self.state { - AggregateHashTableState::Building(state) - | AggregateHashTableState::Outputting(state) => { - let acc = state - .accumulators - .iter() - .map(|acc| acc.accumulator.size()) - .sum::(); - - acc + state.group_values.size() - + state.batch_group_indices.allocated_size() - } - AggregateHashTableState::Done => 0, - } - } - - /// How many distinct groups has been accumulated now. - pub(super) fn building_group_count(&self) -> usize { - self.state.building().group_values.len() - } - - pub(super) fn is_building(&self) -> bool { - matches!(self.state, AggregateHashTableState::Building(_)) - } - - pub(super) fn is_done(&self) -> bool { - matches!(self.state, AggregateHashTableState::Done) - } - - fn start_outputting(&mut self) { - let AggregateHashTableState::Building(mut state) = - std::mem::replace(&mut self.state, AggregateHashTableState::Done) - else { - unreachable!("hash aggregate table is not building") - }; - - state.batch_group_indices = Vec::new(); - self.state = AggregateHashTableState::Outputting(state); - } -} - -fn emit_to_for_batch_size(batch_size: usize, group_count: usize) -> EmitTo { - debug_assert!(batch_size > 0); - if group_count <= batch_size { - EmitTo::All - } else { - EmitTo::First(batch_size) - } -} - -impl AggregateHashTable { - pub(super) fn new( - agg: &AggregateExec, - partition: usize, - output_schema: SchemaRef, - batch_size: usize, - ) -> Result { - Self::new_with_filters( - agg, - partition, - output_schema, - batch_size, - agg.filter_expr.iter().cloned().collect(), - ) - } - - /// Emits the next batch of aggregated group keys and aggregate states. - /// - /// The output batch size is determined by `self.batch_size`. - /// - /// Returns `Some(batch)` for each emitted batch, `None` when output is - /// exhausted, and an internal error if polled in the `Building` state. - pub(super) fn next_output_batch(&mut self) -> Result> { - let output_schema = Arc::clone(&self.output_schema); - let batch_size = self.batch_size; - match &mut self.state { - AggregateHashTableState::Outputting(state) => { - if state.group_values.is_empty() { - self.state = AggregateHashTableState::Done; - return Ok(None); - } - - let emit_to = - emit_to_for_batch_size(batch_size, state.group_values.len()); - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(emit_to)?; - - for acc in state.accumulators.iter_mut() { - output.extend(acc.state(emit_to)?); - } - let done = state.group_values.is_empty(); - drop(timer); - - let batch = RecordBatch::try_new(output_schema, output)?; - debug_assert!(batch.num_rows() > 0); - if done { - self.state = AggregateHashTableState::Done; - } - Ok(Some(batch)) - } - AggregateHashTableState::Done => Ok(None), - AggregateHashTableState::Building(_) => { - internal_err!("next_output_batch must be called in the outputting state") - } - } - } - - pub(super) fn can_skip_aggregation(&self) -> bool { - self.state - .building() - .accumulators - .iter() - .all(|acc| acc.supports_convert_to_state()) - } - - /// In skip-partial-aggregation optimization, when a decision has made to skip - /// partial stage, build a typed hash table only for aggregation state conversion - /// row-by-row. - pub(super) fn partial_skip_table(&self) -> Result> { - let state = self.state.building(); - let group_schema = state.group_by.group_schema(&self.input_schema)?; - let group_values = new_group_values(group_schema, &GroupOrdering::None)?; - let accumulators = state - .accumulators - .iter() - .map(HashAggregateAccumulator::empty_like) - .collect::>>()?; - - Ok(AggregateHashTable { - group_by_metrics: self.group_by_metrics.clone(), - input_schema: Arc::clone(&self.input_schema), - output_schema: Arc::clone(&self.output_schema), - batch_size: self.batch_size, - state: AggregateHashTableState::Building(BuildingHashTableState { - group_by: Arc::clone(&state.group_by), - group_values, - batch_group_indices: Default::default(), - accumulators, - }), - _mode: PhantomData, - }) - } - - pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { - let evaluated_batch = self.evaluate_batch(batch)?; - let state = self.state.building_mut(); - - let timer = self.group_by_metrics.aggregation_time.timer(); - for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; - let group_indices = &state.batch_group_indices; - let total_num_groups = state.group_values.len(); - - for (acc, values) in state - .accumulators - .iter_mut() - .zip(evaluated_batch.accumulator_args.iter()) - { - acc.update_batch(values, group_indices, total_num_groups)?; - } - } - drop(timer); - - Ok(()) - } - - pub(super) fn start_output(&mut self) -> Result<()> { - self.init_empty_grouping_sets()?; - self.start_outputting(); - Ok(()) - } - - /// Creates the required empty grouping-set rows when the input is empty. - /// - /// For example, this query must still produce one grand-total group even if - /// `t` has no rows: - /// - /// ```sql - /// SELECT COUNT(v) - /// FROM t - /// GROUP BY GROUPING SETS (()); - /// ``` - /// - /// The synthetic row is filtered out before accumulator update so aggregates - /// see the same state they would see for an empty input, rather than a real - /// null-valued row. - fn init_empty_grouping_sets(&mut self) -> Result<()> { - let state = self.state.building_mut(); - if !state.group_by.has_grouping_set() || !state.group_values.is_empty() { - return Ok(()); - } - - let max_ordinal = max_duplicate_ordinal(state.group_by.groups()); - let mut ordinals: HashMap<&[bool], usize> = HashMap::new(); - let group_schema = state.group_by.group_schema(&self.input_schema)?; - let n_expr = state.group_by.expr().len(); - let mut any_interned = false; - - for group in state.group_by.groups() { - let ordinal = { - let entry = ordinals.entry(group.as_slice()).or_insert(0); - let ordinal = *entry; - *entry += 1; - ordinal - }; - - if !group.iter().all(|&is_null| is_null) { - continue; - } - - let mut cols: Vec = group_schema - .fields() - .iter() - .take(n_expr) - .map(|field| new_null_array(field.data_type(), 1)) - .collect(); - cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); - - state - .group_values - .intern(&cols, &mut state.batch_group_indices)?; - any_interned = true; - } - - if any_interned { - let total_groups = state.group_values.len(); - let false_filter = BooleanArray::from(vec![false]); - for acc in state.accumulators.iter_mut() { - let null_args = acc.null_arguments(&self.input_schema)?; - let values = EvaluatedHashAggregateAccumulator { - arguments: null_args, - filter: Some(Arc::new(false_filter.clone())), - }; - acc.update_batch(&values, &[0], total_groups)?; - } - } - - Ok(()) - } -} - -impl AggregateHashTable { - pub(super) fn convert_batch_to_state( - &mut self, - batch: &RecordBatch, - ) -> Result { - let evaluated_batch = self.evaluate_batch(batch)?; - - assert_eq_or_internal_err!( - evaluated_batch.grouping_set_args.len(), - 1, - "group_values expected to have single element" - ); - let mut output = evaluated_batch - .grouping_set_args - .into_iter() - .next() - .unwrap_or_default(); - - let state = self.state.building_mut(); - for (acc, values) in state - .accumulators - .iter_mut() - .zip(evaluated_batch.accumulator_args.iter()) - { - output.extend(acc.convert_to_state(values)?); - } - - Ok(RecordBatch::try_new( - Arc::clone(&self.output_schema), - output, - )?) - } -} - -impl AggregateHashTable { - pub(super) fn new( - agg: &AggregateExec, - partition: usize, - output_schema: SchemaRef, - batch_size: usize, - ) -> Result { - Self::new_with_filters( - agg, - partition, - output_schema, - batch_size, - vec![None; agg.aggr_expr.len()], - ) - } - - /// Emits the next batch of aggregated group keys and aggregate states. - /// - /// The output batch size is determined by `self.batch_size`. - /// - /// Returns `Some(batch)` for each emitted batch, `None` when output is - /// exhausted, and an internal error if polled in the `Building` state. - pub(super) fn next_output_batch(&mut self) -> Result> { - let output_schema = Arc::clone(&self.output_schema); - let batch_size = self.batch_size; - match &mut self.state { - AggregateHashTableState::Outputting(state) => { - if state.group_values.is_empty() { - self.state = AggregateHashTableState::Done; - return Ok(None); - } - - let emit_to = - emit_to_for_batch_size(batch_size, state.group_values.len()); - let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = state.group_values.emit(emit_to)?; - - for acc in state.accumulators.iter_mut() { - output.push(acc.evaluate_final(emit_to)?); - } - let done = state.group_values.is_empty(); - drop(timer); - - let batch = RecordBatch::try_new(output_schema, output)?; - debug_assert!(batch.num_rows() > 0); - if done { - self.state = AggregateHashTableState::Done; - } - Ok(Some(batch)) - } - AggregateHashTableState::Done => Ok(None), - AggregateHashTableState::Building(_) => { - internal_err!("next_output_batch must be called in the outputting state") - } - } - } - - pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { - let evaluated_batch = self.evaluate_batch(batch)?; - let state = self.state.building_mut(); - - let timer = self.group_by_metrics.aggregation_time.timer(); - for group_values in &evaluated_batch.grouping_set_args { - state - .group_values - .intern(group_values, &mut state.batch_group_indices)?; - let group_indices = &state.batch_group_indices; - let total_num_groups = state.group_values.len(); - - for (acc, values) in state - .accumulators - .iter_mut() - .zip(evaluated_batch.accumulator_args.iter()) - { - acc.merge_batch(values, group_indices, total_num_groups)?; - } - } - drop(timer); - - Ok(()) - } - - pub(super) fn start_output(&mut self) -> Result<()> { - self.start_outputting(); - Ok(()) - } -} diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e1c598e02dfff..08468bffc0dd9 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -72,9 +72,9 @@ use itertools::Itertools; use topk::hash_table::is_supported_hash_key_type; use topk::heap::is_supported_heap_type; +mod aggregate_hash_table; pub mod group_values; mod hash_aggregate; -mod hash_table; mod no_grouping; pub mod order; mod row_hash; From 6feef688084036c9d0f1480ef3e7e817a26d82d9 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Sat, 20 Jun 2026 21:02:43 +0800 Subject: [PATCH 3/3] small comments update --- .../src/aggregates/aggregate_hash_table/common.rs | 3 ++- .../src/aggregates/aggregate_hash_table/final_table.rs | 1 + .../src/aggregates/aggregate_hash_table/partial_table.rs | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs index 0b1cb422d345f..f49edc8a9b04f 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/common.rs @@ -272,7 +272,8 @@ impl AggregateHashTableState { } } -impl AggregateHashTable { +/// Methods shared by all aggregate hash table modes. +impl AggregateHashTable { pub(super) fn new_with_filters( agg: &AggregateExec, partition: usize, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs index e318d22c1dc5e..e802d7934bce3 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/final_table.rs @@ -27,6 +27,7 @@ use super::common::{ AggregateHashTable, AggregateHashTableState, Final, emit_to_for_batch_size, }; +/// Methods specific to the aggregate hash table used in the final aggregation stage. impl AggregateHashTable { pub(in crate::aggregates) fn new( agg: &AggregateExec, diff --git a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs index 0afea3a33b43c..cc42ec5e67125 100644 --- a/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs +++ b/datafusion/physical-plan/src/aggregates/aggregate_hash_table/partial_table.rs @@ -34,6 +34,7 @@ use super::common::{ emit_to_for_batch_size, }; +/// Methods specific to the aggregate hash table used in the partial aggregation stage. impl AggregateHashTable { pub(in crate::aggregates) fn new( agg: &AggregateExec,