diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs
index 6565d4f187339..18ddc361a0692 100644
--- a/datafusion/optimizer/src/propagate_empty_relation.rs
+++ b/datafusion/optimizer/src/propagate_empty_relation.rs
@@ -23,7 +23,7 @@ use datafusion_common::JoinType;
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::{Column, DFSchemaRef, Result, ScalarValue, plan_err};
 use datafusion_expr::logical_plan::LogicalPlan;
-use datafusion_expr::{EmptyRelation, Expr, Projection, Union, cast, lit};
+use datafusion_expr::{EmptyRelation, Expr, GroupingSet, Projection, Union, cast, lit};
 
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
@@ -174,7 +174,13 @@ impl OptimizerRule for PropagateEmptyRelation {
                 }
             }
             LogicalPlan::Aggregate(ref agg) => {
+                // An aggregate over an empty input can be eliminated only when
+                // there is no empty grouping set. An empty grouping set `()`
+                // (from `GROUPING SETS(())`, `ROLLUP(...)`, or `CUBE(...)`)
+                // always produces exactly one row even on empty input, so it
+                // must not be replaced by an empty relation.
                 if !agg.group_expr.is_empty()
+                    && !has_empty_grouping_set(&agg.group_expr)
                     && let Some(empty_plan) = empty_child(&plan)?
                 {
                     return Ok(Transformed::yes(empty_plan));
@@ -315,6 +321,30 @@ fn build_null_padded_projection(
     )?))
 }
 
+/// Returns `true` if any grouping set in the list of GROUP BY expressions is
+/// the empty set `()`.
+///
+/// An empty grouping set acts as a "grand total" group: the aggregate must
+/// always produce **exactly one row** for it, even when the input is empty.
+/// This means an aggregate with an empty grouping set cannot be replaced by
+/// an empty relation.
+///
+/// The three forms that can contain an empty grouping set:
+/// - `GROUPING SETS (…, (), …)` — explicitly listed.
+/// - `ROLLUP(exprs)` — always expands to include `()`.
+/// - `CUBE(exprs)` — always expands to include `()`.
+fn has_empty_grouping_set(group_expr: &[Expr]) -> bool {
+    match group_expr.first() {
+        Some(Expr::GroupingSet(GroupingSet::GroupingSets(groups))) => {
+            groups.iter().any(|g| g.is_empty())
+        }
+        // Both ROLLUP and CUBE always include the empty grouping set ().
+        Some(Expr::GroupingSet(GroupingSet::Rollup(_)))
+        | Some(Expr::GroupingSet(GroupingSet::Cube(_))) => true,
+        _ => false,
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 76ecb3f1485a4..7cb549736cced 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -2063,11 +2063,11 @@ fn evaluate_optional(
 /// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
 /// UInt64` that can represent both parts. It matches the type returned by
 /// [`Aggregate::grouping_id_type`].
-fn group_id_array(
+pub(crate) fn group_id_array(
     group: &[bool],
     ordinal: usize,
     max_ordinal: usize,
-    batch: &RecordBatch,
+    num_rows: usize,
 ) -> Result<ArrayRef> {
     let n = group.len();
     if n > 64 {
@@ -2087,7 +2087,6 @@ fn group_id_array(
         (acc << 1) | if is_null { 1 } else { 0 }
     });
     let full_id = semantic_id | ((ordinal as u64) << n);
-    let num_rows = batch.num_rows();
     if total_bits <= 8 {
         Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
     } else if total_bits <= 16 {
@@ -2106,7 +2105,7 @@ fn group_id_array(
 /// ordinal 0, the second gets 1, and so on. If the same `Vec<bool>` appears
 /// three times the ordinals are 0, 1, 2 and this function returns 2.
 /// Returns 0 when no grouping set is duplicated.
-fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
+pub(crate) fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
     let mut counts: HashMap<&[bool], usize> = HashMap::new();
     for group in groups {
         *counts.entry(group).or_insert(0) += 1;
@@ -2160,7 +2159,7 @@ pub fn evaluate_group_by(
                     group,
                     current_ordinal,
                     max_ordinal,
-                    batch,
+                    batch.num_rows(),
                 )?);
             }
             Ok(group_values)
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 056a7f171a516..a65aaf9134fe8 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -27,7 +27,8 @@ use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_val
 use crate::aggregates::order::GroupOrderingFull;
 use crate::aggregates::{
     AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy,
-    create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
+    create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_id_array,
+    max_duplicate_ordinal,
 };
 use crate::metrics::{BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput};
 use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
@@ -360,6 +361,7 @@ pub(crate) struct GroupedHashAggregateStream {
     // the execution.
     // ========================================================================
     schema: SchemaRef,
+    input_schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
 
@@ -661,6 +663,7 @@ impl GroupedHashAggregateStream {
 
         Ok(GroupedHashAggregateStream {
             schema: agg_schema,
+            input_schema: agg.input().schema(),
             input,
             mode: agg.mode,
             accumulators,
@@ -1125,6 +1128,104 @@ impl GroupedHashAggregateStream {
         Ok(Some(batch))
     }
 
+    /// Registers groups for empty grouping sets when no input rows were seen.
+    ///
+    /// `GROUP BY GROUPING SETS (())` must always produce one row even when there
+    /// are no input rows (standard SQL semantics for a "grand total" group).
+    /// Mixed grouping sets like `GROUPING SETS (a, ())` also produce one row for
+    /// the empty set `()` on empty input.
+    ///
+    /// This method interns the group keys and primes the accumulators so they
+    /// produce their zero-row aggregate values (e.g. `NULL` for `SUM`,
+    /// `0` for `COUNT`).
+    fn init_empty_grouping_sets(&mut self) -> Result<()> {
+        if !self.group_by.has_grouping_set() || !self.group_values.is_empty() {
+            return Ok(());
+        }
+
+        let max_ordinal = max_duplicate_ordinal(self.group_by.groups());
+        let mut ordinals: std::collections::HashMap<&[bool], usize> =
+            std::collections::HashMap::new();
+        let group_schema = self.group_by.group_schema(&self.input_schema)?;
+        let n_expr = self.group_by.expr().len();
+        let mut any_interned = false;
+
+        for group in self.group_by.groups() {
+            let ordinal = {
+                let entry = ordinals.entry(group.as_slice()).or_insert(0);
+                let o = *entry;
+                *entry += 1;
+                o
+            };
+
+            if !group.iter().all(|&is_null| is_null) {
+                continue;
+            }
+
+            // Build the group key: one NULL per group-by expression, then the grouping_id.
+            let mut cols: Vec<ArrayRef> = group_schema
+                .fields()
+                .iter()
+                .take(n_expr)
+                .map(|f| new_null_array(f.data_type(), 1))
+                .collect();
+            cols.push(group_id_array(group, ordinal, max_ordinal, 1)?);
+
+            let starting_groups = self.group_values.len();
+            self.group_values
+                .intern(&cols, &mut self.current_group_indices)?;
+            let total_groups = self.group_values.len();
+            if total_groups > starting_groups {
+                self.group_ordering.new_groups(
+                    &cols,
+                    &self.current_group_indices,
+                    total_groups,
+                )?;
+            }
+            any_interned = true;
+        }
+
+        if any_interned {
+            // Prime each accumulator for the registered group count with no data.
+            //
+            // We build 1-row null arrays for each aggregate argument and pass them
+            // with an all-false filter. The filter ensures no row is accumulated
+            // into any group, which keeps every group in its "zero" initial state
+            // (NULL for SUM/AVG/MIN/MAX, 0 for COUNT).
+            //
+            // Using a 1-row batch rather than 0 rows is required to avoid a fast
+            // path in `NullState::accumulate` that treats "0 nulls in a 0-row
+            // array" as "all groups have been seen", which would cause SUM to
+            // return 0 instead of NULL.
+            //
+            // Argument types are inferred directly from the expression metadata so
+            // we never need to construct a full `RecordBatch`.
+            let total_groups = self.group_values.len();
+            let null_args: Vec<Vec<ArrayRef>> = self
+                .aggregate_arguments
+                .iter()
+                .map(|args| {
+                    args.iter()
+                        .map(|expr| {
+                            let dt = expr.data_type(&self.input_schema)?;
+                            Ok(new_null_array(&dt, 1))
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .collect::<Result<Vec<_>>>()?;
+            let false_filter = BooleanArray::from(vec![false]);
+            for (acc, args) in self.accumulators.iter_mut().zip(null_args.iter()) {
+                if self.mode.input_mode() == AggregateInputMode::Raw {
+                    acc.update_batch(args, &[0], Some(&false_filter), total_groups)?;
+                } else {
+                    acc.merge_batch(args, &[0], Some(&false_filter), total_groups)?;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
     /// Emit all intermediate aggregation states, sort them, and store them on disk.
     /// This process helps in reducing memory pressure by allowing the data to be
     /// read back with streaming merge.
@@ -1223,6 +1324,7 @@ impl GroupedHashAggregateStream {
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
             // Input has been entirely processed without spilling to disk.
+            self.init_empty_grouping_sets()?;
 
             // Flush any remaining group values.
             let batch = self.emit(EmitTo::All, false)?;
diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt
index 3d38576bdbf5f..8082eb74088a2 100644
--- a/datafusion/sqllogictest/test_files/grouping.slt
+++ b/datafusion/sqllogictest/test_files/grouping.slt
@@ -224,3 +224,22 @@ query I
 SELECT SUM(v1) FROM generate_series(10) AS t1(v1) GROUP BY GROUPING SETS(())
 ----
 55
+
+# grouping_sets_empty_input: GROUPING SETS (()) must produce one NULL row on empty input
+# (standard SQL: the empty grouping set always defines exactly one group)
+query I
+SELECT SUM(v1) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS(())
+----
+NULL
+
+# grouping_sets_empty_input_count: COUNT returns 0 for the empty group, not a missing row
+query I
+SELECT COUNT(*) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS(())
+----
+0
+
+# grouping_sets_mixed_empty_and_non_empty: only the empty set (()) produces a row on empty input
+query II
+SELECT SUM(v1), COUNT(*) FROM generate_series(10) AS t1(v1) WHERE false GROUP BY GROUPING SETS((), (v1))
+----
+NULL 0