From dbd05b56b4df7174b0fabcaf459c64e2c0a2d9d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 4 May 2026 12:57:16 +0200 Subject: [PATCH 1/5] Optimize ClickBench q17 aggregate limit --- benchmarks/src/bin/dfbench.rs | 7 +- benchmarks/src/bin/imdb.rs | 7 +- .../limited_distinct_aggregation.rs | 70 ++++++++++- .../src/limited_distinct_aggregation.rs | 41 +++++-- .../physical-plan/src/aggregates/mod.rs | 25 ++-- .../physical-plan/src/aggregates/row_hash.rs | 114 ++++++++++++++++-- .../sqllogictest/test_files/clickbench.slt | 4 +- 7 files changed, 223 insertions(+), 45 deletions(-) diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 3b1f54291e75c..01a10db04a731 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -20,12 +20,7 @@ use datafusion::error::Result; use clap::{Parser, Subcommand}; -#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] -compile_error!( - "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" -); - -#[cfg(feature = "snmalloc")] +#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index e86735f87b8f1..026f5b55d1e21 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -21,12 +21,7 @@ use clap::{Parser, Subcommand}; use datafusion::error::Result; use datafusion_benchmarks::imdb; -#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] -compile_error!( - "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" -); - -#[cfg(feature = "snmalloc")] +#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs 
b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index c523b4a752a82..31ced5786a7be 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -28,11 +28,12 @@ use crate::physical_optimizer::test_utils::{ use arrow::datatypes::DataType; use arrow::{compute::SortOptions, util::pretty::pretty_format_batches}; use datafusion::prelude::SessionContext; -use datafusion_common::Result; +use datafusion_common::{Result, config::ConfigOptions}; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{self, cast, col}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::{ ExecutionPlan, aggregates::{AggregateExec, AggregateMode}, @@ -332,7 +333,7 @@ fn test_has_aggregate_expression() -> Result<()> { let schema = source.schema(); let agg = TestAggregate::new_count_star(); - // `SELECT FROM DataSourceExec LIMIT 10;`, Single AggregateExec + // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 10;`, Single AggregateExec let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema, vec!["a".to_string()]), @@ -345,7 +346,7 @@ fn test_has_aggregate_expression() -> Result<()> { Arc::new(single_agg), 10, // fetch ); - // expected not to push the limit to the AggregateExec + // expected to push the limit to the AggregateExec let plan: Arc = Arc::new(limit_exec); let formatted = get_optimized_plan(&plan)?; let actual = formatted.trim(); @@ -353,13 +354,74 @@ fn test_has_aggregate_expression() -> Result<()> { actual, @r" LocalLimitExec: fetch=10 - AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)], lim=[10] DataSourceExec: partitions=1, partition_sizes=[1] " ); Ok(()) } +#[tokio::test] 
+async fn test_partial_final_with_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 4;`, + // Partial/Final AggregateExec. Both stages can keep the same deterministic + // top-k group keys. + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + let plan: Arc = Arc::new(limit_exec); + let formatted = get_optimized_plan(&plan)?; + let actual = formatted.trim(); + assert_snapshot!( + actual, + @r" + LocalLimitExec: fetch=4 + AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4] + AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4] + DataSourceExec: partitions=1, partition_sizes=[1] + " + ); + let optimized = + datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation::new() + .optimize(Arc::clone(&plan), &ConfigOptions::new())?; + let expected = run_plan_and_format(optimized).await?; + assert_snapshot!( + expected, + @r" + +---+----------+ + | a | COUNT(*) | + +---+----------+ + | | 1 | + | 1 | 2 | + | 2 | 1 | + | 4 | 1 | + +---+----------+ + " + ); + Ok(()) +} + #[test] fn test_has_filter() -> Result<()> { let source = mock_data()?; diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs 
b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 852dc2a2a9434..331205a726761 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! A special-case optimizer rule that pushes limit into a grouped aggregation -//! which has no aggregate expressions or sorting requirements +//! A special-case optimizer rule that pushes limit into unordered grouped +//! aggregation when the query only needs an arbitrary subset of groups. use std::sync::Arc; -use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -32,9 +32,10 @@ use crate::PhysicalOptimizerRule; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all -/// rows in the group to be processed for correctness. Example queries fitting this description are: +/// groups to be produced for correctness. 
Example queries fitting this description are: /// - `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` /// - `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +/// - `SELECT l_orderkey, COUNT(*) FROM lineitem GROUP BY l_orderkey LIMIT 10;` #[derive(Debug)] pub struct LimitedDistinctAggregation {} @@ -48,21 +49,43 @@ impl LimitedDistinctAggregation { aggr: &AggregateExec, limit: usize, ) -> Option> { - // rules for transforming this Aggregate are held in this method - if !aggr.is_unordered_unfiltered_group_by_distinct() { + if aggr.is_unordered_unfiltered_group_by_distinct() { + let new_aggr = aggr.with_new_limit_options(Some(LimitOptions::new(limit))); + return Some(Arc::new(new_aggr)); + } + + if !Self::can_limit_aggregate(aggr) { return None; } - // We found what we want: clone, copy the limit down, and return modified node let new_aggr = aggr.with_new_limit_options(Some(LimitOptions::new(limit))); Some(Arc::new(new_aggr)) } + fn can_limit_aggregate(aggr: &AggregateExec) -> bool { + if !aggr.is_unordered_unfiltered_group_by() { + return false; + } + if aggr.aggr_expr().is_empty() { + return false; + } + if !aggr.group_expr().is_single() { + return false; + } + matches!( + aggr.mode(), + AggregateMode::Partial + | AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned + ) + } + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when - /// there is a group by, but no sorting, no aggregate expressions, and no filters in the - /// aggregation + /// there is a group by, but no sorting or filters in the aggregation fn transform_limit(plan: Arc) -> Option> { let limit: usize; let mut global_fetch: Option = None; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 76ecb3f1485a4..8ecbfe84f580e 100644 --- 
a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -915,6 +915,7 @@ impl AggregateExec { // grouping by an expression that has a sort/limit upstream if let Some(config) = self.limit_options + && config.descending().is_some() && !self.is_unordered_unfiltered_group_by_distinct() { return Ok(StreamType::GroupedPriorityQueue( @@ -934,11 +935,9 @@ impl AggregateExec { agg_expr.get_minmax_desc() } - /// true, if this Aggregate has a group-by with no required or explicit ordering, - /// no filtering and no aggregate expressions - /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule - /// on an AggregateExec. - pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + /// true if this Aggregate has a group-by with no required or explicit + /// ordering and no aggregate filters. + pub fn is_unordered_unfiltered_group_by(&self) -> bool { if self .limit_options() .and_then(|config| config.descending) @@ -950,12 +949,7 @@ impl AggregateExec { if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; } - // ensure there are no aggregate expressions - if !self.aggr_expr().is_empty() { - return false; - } - // ensure there are no filters on aggregate expressions; the above check - // may preclude this case + // ensure there are no filters on aggregate expressions if self.filter_expr().iter().any(|e| e.is_some()) { return false; } @@ -974,6 +968,15 @@ impl AggregateExec { true } + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions. + /// + /// This method qualifies the distinct-only use of the + /// LimitedDistinctAggregation rewrite rule on an AggregateExec. 
+ pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + self.aggr_expr().is_empty() && self.is_unordered_unfiltered_group_by() + } + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. pub fn compute_properties( input: &Arc, diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 056a7f171a516..62dd667013c0b 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -36,6 +36,7 @@ use crate::{PhysicalExpr, aggregates, metrics}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; +use arrow::compute::{SortColumn, SortOptions, lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; use datafusion_common::{ DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err, @@ -372,6 +373,9 @@ pub(crate) struct GroupedHashAggregateStream { /// argument. aggregate_arguments: Vec>>, + /// Number of intermediate state columns produced by each accumulator. + state_field_counts: Vec, + /// Optional filter expression to evaluate, one for each for /// accumulator. If present, only those rows for which the filter /// evaluate to true should be included in the aggregate results. @@ -386,10 +390,11 @@ pub(crate) struct GroupedHashAggregateStream { /// max rows in output RecordBatches batch_size: usize, - /// Optional soft limit on the number of `group_values` in a batch - /// If the number of `group_values` in a single batch exceeds this value, - /// the `GroupedHashAggregateStream` operation immediately switches to - /// output mode and emits all groups. + /// Optional soft limit on the number of `group_values`. + /// + /// Distinct-style aggregates can stop once this many groups have been seen. + /// Aggregates with accumulator state keep the top `limit` group keys after + /// each input batch. 
group_values_soft_limit: Option, // ======================================================================== @@ -477,6 +482,10 @@ impl GroupedHashAggregateStream { let timer = baseline_metrics.elapsed_compute().timer(); let aggregate_exprs = Arc::clone(&agg.aggr_expr); + let state_field_counts = aggregate_exprs + .iter() + .map(|expr| expr.state_fields().map(|fields| fields.len())) + .collect::>>()?; // arguments for each aggregate, one vec of expressions per // aggregate @@ -626,6 +635,7 @@ impl GroupedHashAggregateStream { // aggregate state conversion // - there is only one GROUP BY expressions set let skip_aggregation_probe = if agg.mode == AggregateMode::Partial + && agg.limit_options().is_none() && matches!(group_ordering, GroupOrdering::None) && accumulators .iter() @@ -665,6 +675,7 @@ impl GroupedHashAggregateStream { mode: agg.mode, accumulators, aggregate_arguments, + state_field_counts, filter_expressions, group_by: agg_group_by, reservation, @@ -738,9 +749,11 @@ impl Stream for GroupedHashAggregateStream { assert!(!self.input_done); - // If the number of group values equals or exceeds the soft limit, - // emit all groups and switch to producing output - if self.hit_soft_group_limit() { + // Distinct-style aggregation can stop once enough groups have been + // found. Aggregates with accumulator state must keep reading input + // so the selected groups' aggregate values remain exact. + if self.accumulators.is_empty() && self.hit_soft_group_limit() + { timer.done(); self.set_input_done_and_produce_output()?; // make sure the exec_state just set is not overwritten below @@ -1006,6 +1019,8 @@ impl GroupedHashAggregateStream { } } + self.prune_to_group_key_topk()?; + Ok(()) } @@ -1086,6 +1101,91 @@ impl GroupedHashAggregateStream { reservation_result } + /// Keep only the smallest group keys for unordered `GROUP BY ... LIMIT` + /// aggregation. + /// + /// This is safe in partial aggregation because every partition uses the + /// same deterministic key order. 
Any globally top-k key is also in the + /// local top-k for every partition where it appears, so its partial state is + /// never dropped. + fn prune_to_group_key_topk(&mut self) -> Result<()> { + if !self.should_prune_to_group_key_topk() { + return Ok(()); + } + + let limit = self.group_values_soft_limit.unwrap(); + if limit == 0 { + self.clear_all(); + self.update_memory_reservation()?; + return Ok(()); + } + if self.group_values.len() <= limit { + return Ok(()); + } + + let Some(batch) = self.emit(EmitTo::All, true)? else { + return Ok(()); + }; + self.clear_shrink(0); + + let group_count = self.group_by.num_group_exprs(); + let sort_columns = batch + .columns() + .iter() + .take(group_count) + .map(|values| SortColumn { + values: Arc::clone(values), + options: Some(SortOptions::default()), + }) + .collect::>(); + let indices = lexsort_to_indices(&sort_columns, Some(limit))?; + let columns = batch + .columns() + .iter() + .map(|array| Ok(take(array.as_ref(), &indices, None)?)) + .collect::>>()?; + let batch = RecordBatch::try_new(Arc::clone(batch.schema_ref()), columns)?; + + let group_values = batch + .columns() + .iter() + .take(group_count) + .cloned() + .collect::>(); + self.group_values + .intern(&group_values, &mut self.current_group_indices)?; + let group_indices = &self.current_group_indices; + let total_num_groups = self.group_values.len(); + + let mut column_index = group_count; + for (acc, state_field_count) in self + .accumulators + .iter_mut() + .zip(self.state_field_counts.iter().copied()) + { + let next_column_index = column_index + state_field_count; + let values = batch.columns()[column_index..next_column_index].to_vec(); + acc.merge_batch(&values, group_indices, None, total_num_groups)?; + column_index = next_column_index; + } + assert_eq_or_internal_err!( + column_index, + batch.num_columns(), + "Mismatch rebuilding limited aggregate state" + ); + + self.update_memory_reservation()?; + + Ok(()) + } + + fn should_prune_to_group_key_topk(&self) 
-> bool { + self.group_values_soft_limit.is_some() + && !self.accumulators.is_empty() + && !self.spill_state.is_stream_merging + && matches!(self.group_ordering, GroupOrdering::None) + } + /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index c79701e347109..29c9810e874fa 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -495,9 +495,9 @@ logical_plan physical_plan 01)ProjectionExec: expr=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase, count(Int64(1))@2 as count(*)] 02)--CoalescePartitionsExec: fetch=10 -03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))] +03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))], lim=[10] 04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1 -05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))] +05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))], lim=[10] 06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet query ITI rowsort From 5c40e583dcd6fb3e40c89b7068144ee5eb517b9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 4 May 2026 13:10:03 +0200 Subject: [PATCH 2/5] Remove benchmark allocator cfg changes --- benchmarks/src/bin/dfbench.rs | 7 ++++++- benchmarks/src/bin/imdb.rs | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git 
a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 01a10db04a731..3b1f54291e75c 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -20,7 +20,12 @@ use datafusion::error::Result; use clap::{Parser, Subcommand}; -#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] +#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] +compile_error!( + "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" +); + +#[cfg(feature = "snmalloc")] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs index 026f5b55d1e21..e86735f87b8f1 100644 --- a/benchmarks/src/bin/imdb.rs +++ b/benchmarks/src/bin/imdb.rs @@ -21,7 +21,12 @@ use clap::{Parser, Subcommand}; use datafusion::error::Result; use datafusion_benchmarks::imdb; -#[cfg(all(feature = "snmalloc", not(feature = "mimalloc")))] +#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] +compile_error!( + "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" +); + +#[cfg(feature = "snmalloc")] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; From d61174244a287156679ee7a9b523807900b22f52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 4 May 2026 13:48:30 +0200 Subject: [PATCH 3/5] Simplify aggregate limit prefiltering --- .../limited_distinct_aggregation.rs | 70 +---------- datafusion/optimizer/src/push_down_limit.rs | 91 +++++++++++++- .../src/limited_distinct_aggregation.rs | 41 ++----- .../physical-plan/src/aggregates/mod.rs | 25 ++-- .../physical-plan/src/aggregates/row_hash.rs | 114 ++---------------- .../sqllogictest/test_files/aggregate.slt | 19 ++- .../sqllogictest/test_files/clickbench.slt | 24 +++- 7 files changed, 151 insertions(+), 233 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs 
b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index 31ced5786a7be..c523b4a752a82 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -28,12 +28,11 @@ use crate::physical_optimizer::test_utils::{ use arrow::datatypes::DataType; use arrow::{compute::SortOptions, util::pretty::pretty_format_batches}; use datafusion::prelude::SessionContext; -use datafusion_common::{Result, config::ConfigOptions}; +use datafusion_common::Result; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{self, cast, col}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::{ ExecutionPlan, aggregates::{AggregateExec, AggregateMode}, @@ -333,7 +332,7 @@ fn test_has_aggregate_expression() -> Result<()> { let schema = source.schema(); let agg = TestAggregate::new_count_star(); - // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 10;`, Single AggregateExec + // `SELECT FROM DataSourceExec LIMIT 10;`, Single AggregateExec let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema, vec!["a".to_string()]), @@ -346,7 +345,7 @@ fn test_has_aggregate_expression() -> Result<()> { Arc::new(single_agg), 10, // fetch ); - // expected to push the limit to the AggregateExec + // expected not to push the limit to the AggregateExec let plan: Arc = Arc::new(limit_exec); let formatted = get_optimized_plan(&plan)?; let actual = formatted.trim(); @@ -354,74 +353,13 @@ fn test_has_aggregate_expression() -> Result<()> { actual, @r" LocalLimitExec: fetch=10 - AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)], lim=[10] + AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)] DataSourceExec: partitions=1, partition_sizes=[1] " ); Ok(()) } -#[tokio::test] 
-async fn test_partial_final_with_aggregate_expression() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // `SELECT a, COUNT(*) FROM DataSourceExec GROUP BY a LIMIT 4;`, - // Partial/Final AggregateExec. Both stages can keep the same deterministic - // top-k group keys. - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ - vec![None], /* filter_expr */ - Arc::new(partial_agg), /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(final_agg), - 4, // fetch - ); - let plan: Arc = Arc::new(limit_exec); - let formatted = get_optimized_plan(&plan)?; - let actual = formatted.trim(); - assert_snapshot!( - actual, - @r" - LocalLimitExec: fetch=4 - AggregateExec: mode=Final, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4] - AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[COUNT(*)], lim=[4] - DataSourceExec: partitions=1, partition_sizes=[1] - " - ); - let optimized = - datafusion_physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation::new() - .optimize(Arc::clone(&plan), &ConfigOptions::new())?; - let expected = run_plan_and_format(optimized).await?; - assert_snapshot!( - expected, - @r" - +---+----------+ - | a | COUNT(*) | - +---+----------+ - | | 1 | - | 1 | 2 | - | 2 | 1 | - | 4 | 1 | - +---+----------+ - " - ); - Ok(()) -} - #[test] fn test_has_filter() -> Result<()> { let source = mock_data()?; diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 
4a26cd5884f6b..b604bca101252 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -18,16 +18,17 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan use std::cmp::min; +use std::collections::HashSet; use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::Result; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; -use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; -use datafusion_expr::{FetchType, SkipType, lit}; +use datafusion_common::{NullEquality, Result}; +use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{Expr, FetchType, LogicalPlanBuilder, SkipType, lit}; /// Optimization rule that tries to push down `LIMIT`. //. It will push down through projection, limits (taking the smaller limit) @@ -47,7 +48,6 @@ impl OptimizerRule for PushDownLimit { true } - #[expect(clippy::only_used_in_recursion)] fn rewrite( &self, plan: LogicalPlan, @@ -123,6 +123,21 @@ impl OptimizerRule for PushDownLimit { make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) })), + LogicalPlan::Aggregate(aggregate) + if config + .options() + .optimizer + .enable_distinct_aggregation_soft_limit => + { + if let Some(aggregate) = + prefilter_limited_aggregate(aggregate.clone(), fetch + skip)? + { + transformed_limit(skip, fetch, aggregate) + } else { + original_limit(skip, fetch, LogicalPlan::Aggregate(aggregate)) + } + } + LogicalPlan::Sort(mut sort) => { let new_fetch = { let sort_fetch = skip + fetch; @@ -237,6 +252,74 @@ fn transformed_limit( Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) } +/// Rewrite `LIMIT K (GROUP BY keys, aggs)` into a key preselection followed +/// by a semi join. 
This keeps the aggregate itself ordinary while letting the +/// join's dynamic filter push the selected key set into the second input scan. +fn prefilter_limited_aggregate( + aggregate: Aggregate, + limit: usize, +) -> Result> { + if limit == 0 || aggregate.aggr_expr.is_empty() || aggregate.group_expr.is_empty() { + return Ok(None); + } + if is_key_prefiltered_aggregate(&aggregate) { + return Ok(None); + } + + let mut seen_columns = HashSet::with_capacity(aggregate.group_expr.len()); + let mut join_columns = Vec::with_capacity(aggregate.group_expr.len()); + for expr in &aggregate.group_expr { + let Expr::Column(column) = expr else { + return Ok(None); + }; + if !seen_columns.insert(column.clone()) { + return Ok(None); + } + join_columns.push(column.clone()); + } + + let key_input = aggregate.input.as_ref().clone(); + let keys = LogicalPlanBuilder::from(key_input) + .aggregate(aggregate.group_expr.clone(), Vec::::new())? + .limit(0, Some(limit))? + .build()?; + + let filtered_input = LogicalPlanBuilder::from(keys) + .join_detailed( + aggregate.input.as_ref().clone(), + JoinType::RightSemi, + (join_columns.clone(), join_columns), + None, + NullEquality::NullEqualsNull, + )? 
+ .build()?; + + Aggregate::try_new( + Arc::new(filtered_input), + aggregate.group_expr, + aggregate.aggr_expr, + ) + .map(LogicalPlan::Aggregate) + .map(Some) +} + +fn is_key_prefiltered_aggregate(aggregate: &Aggregate) -> bool { + let LogicalPlan::Join(join) = aggregate.input.as_ref() else { + return false; + }; + if join.join_type != JoinType::RightSemi { + return false; + } + let LogicalPlan::Limit(limit) = join.left.as_ref() else { + return false; + }; + let LogicalPlan::Aggregate(keys) = limit.input.as_ref() else { + return false; + }; + + keys.aggr_expr.is_empty() && keys.group_expr == aggregate.group_expr +} + /// Adds a limit to the inputs of a join, if possible fn push_down_join(mut join: Join, limit: usize) -> Transformed { use JoinType::*; diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 331205a726761..852dc2a2a9434 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! A special-case optimizer rule that pushes limit into unordered grouped -//! aggregation when the query only needs an arbitrary subset of groups. +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! 
which has no aggregate expressions or sorting requirements use std::sync::Arc; -use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, LimitOptions}; +use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -32,10 +32,9 @@ use crate::PhysicalOptimizerRule; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all -/// groups to be produced for correctness. Example queries fitting this description are: +/// rows in the group to be processed for correctness. Example queries fitting this description are: /// - `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` /// - `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` -/// - `SELECT l_orderkey, COUNT(*) FROM lineitem GROUP BY l_orderkey LIMIT 10;` #[derive(Debug)] pub struct LimitedDistinctAggregation {} @@ -49,43 +48,21 @@ impl LimitedDistinctAggregation { aggr: &AggregateExec, limit: usize, ) -> Option> { - if aggr.is_unordered_unfiltered_group_by_distinct() { - let new_aggr = aggr.with_new_limit_options(Some(LimitOptions::new(limit))); - return Some(Arc::new(new_aggr)); - } - - if !Self::can_limit_aggregate(aggr) { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { return None; } + // We found what we want: clone, copy the limit down, and return modified node let new_aggr = aggr.with_new_limit_options(Some(LimitOptions::new(limit))); Some(Arc::new(new_aggr)) } - fn can_limit_aggregate(aggr: &AggregateExec) -> bool { - if !aggr.is_unordered_unfiltered_group_by() { - return false; - } - if aggr.aggr_expr().is_empty() { - return false; - } - if !aggr.group_expr().is_single() { - return false; - } - matches!( - aggr.mode(), - AggregateMode::Partial - | AggregateMode::Final - | 
AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned - ) - } - /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when - /// there is a group by, but no sorting or filters in the aggregation + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation fn transform_limit(plan: Arc) -> Option> { let limit: usize; let mut global_fetch: Option = None; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8ecbfe84f580e..76ecb3f1485a4 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -915,7 +915,6 @@ impl AggregateExec { // grouping by an expression that has a sort/limit upstream if let Some(config) = self.limit_options - && config.descending().is_some() && !self.is_unordered_unfiltered_group_by_distinct() { return Ok(StreamType::GroupedPriorityQueue( @@ -935,9 +934,11 @@ impl AggregateExec { agg_expr.get_minmax_desc() } - /// true if this Aggregate has a group-by with no required or explicit - /// ordering and no aggregate filters. - pub fn is_unordered_unfiltered_group_by(&self) -> bool { + /// true, if this Aggregate has a group-by with no required or explicit ordering, + /// no filtering and no aggregate expressions + /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule + /// on an AggregateExec. 
+ pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { if self .limit_options() .and_then(|config| config.descending) @@ -949,7 +950,12 @@ impl AggregateExec { if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; } - // ensure there are no filters on aggregate expressions + // ensure there are no aggregate expressions + if !self.aggr_expr().is_empty() { + return false; + } + // ensure there are no filters on aggregate expressions; the above check + // may preclude this case if self.filter_expr().iter().any(|e| e.is_some()) { return false; } @@ -968,15 +974,6 @@ impl AggregateExec { true } - /// true, if this Aggregate has a group-by with no required or explicit ordering, - /// no filtering and no aggregate expressions. - /// - /// This method qualifies the distinct-only use of the - /// LimitedDistinctAggregation rewrite rule on an AggregateExec. - pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { - self.aggr_expr().is_empty() && self.is_unordered_unfiltered_group_by() - } - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. pub fn compute_properties( input: &Arc, diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 62dd667013c0b..056a7f171a516 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -36,7 +36,6 @@ use crate::{PhysicalExpr, aggregates, metrics}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; -use arrow::compute::{SortColumn, SortOptions, lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; use datafusion_common::{ DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err, @@ -373,9 +372,6 @@ pub(crate) struct GroupedHashAggregateStream { /// argument. 
aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
 
-    /// Number of intermediate state columns produced by each accumulator.
-    state_field_counts: Vec<usize>,
-
     /// Optional filter expression to evaluate, one for each for
     /// accumulator. If present, only those rows for which the filter
     /// evaluate to true should be included in the aggregate results.
@@ -390,11 +386,10 @@
     /// max rows in output RecordBatches
     batch_size: usize,
 
-    /// Optional soft limit on the number of `group_values`.
-    ///
-    /// Distinct-style aggregates can stop once this many groups have been seen.
-    /// Aggregates with accumulator state keep the top `limit` group keys after
-    /// each input batch.
+    /// Optional soft limit on the number of `group_values` in a batch
+    /// If the number of `group_values` in a single batch exceeds this value,
+    /// the `GroupedHashAggregateStream` operation immediately switches to
+    /// output mode and emits all groups.
     group_values_soft_limit: Option<usize>,
 
     // ========================================================================
@@ -482,10 +477,6 @@ impl GroupedHashAggregateStream {
         let timer = baseline_metrics.elapsed_compute().timer();
 
         let aggregate_exprs = Arc::clone(&agg.aggr_expr);
-        let state_field_counts = aggregate_exprs
-            .iter()
-            .map(|expr| expr.state_fields().map(|fields| fields.len()))
-            .collect::<Result<Vec<_>>>()?;
 
         // arguments for each aggregate, one vec of expressions per
         // aggregate
@@ -635,7 +626,6 @@ impl GroupedHashAggregateStream {
         // aggregate state conversion
         // - there is only one GROUP BY expressions set
         let skip_aggregation_probe = if agg.mode == AggregateMode::Partial
-            && agg.limit_options().is_none()
             && matches!(group_ordering, GroupOrdering::None)
             && accumulators
                 .iter()
@@ -675,7 +665,6 @@
             mode: agg.mode,
             accumulators,
             aggregate_arguments,
-            state_field_counts,
             filter_expressions,
             group_by: agg_group_by,
             reservation,
@@ -749,11 +738,9 @@ impl Stream for GroupedHashAggregateStream {
assert!(!self.input_done); - // Distinct-style aggregation can stop once enough groups have been - // found. Aggregates with accumulator state must keep reading input - // so the selected groups' aggregate values remain exact. - if self.accumulators.is_empty() && self.hit_soft_group_limit() - { + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { timer.done(); self.set_input_done_and_produce_output()?; // make sure the exec_state just set is not overwritten below @@ -1019,8 +1006,6 @@ impl GroupedHashAggregateStream { } } - self.prune_to_group_key_topk()?; - Ok(()) } @@ -1101,91 +1086,6 @@ impl GroupedHashAggregateStream { reservation_result } - /// Keep only the smallest group keys for unordered `GROUP BY ... LIMIT` - /// aggregation. - /// - /// This is safe in partial aggregation because every partition uses the - /// same deterministic key order. Any globally top-k key is also in the - /// local top-k for every partition where it appears, so its partial state is - /// never dropped. - fn prune_to_group_key_topk(&mut self) -> Result<()> { - if !self.should_prune_to_group_key_topk() { - return Ok(()); - } - - let limit = self.group_values_soft_limit.unwrap(); - if limit == 0 { - self.clear_all(); - self.update_memory_reservation()?; - return Ok(()); - } - if self.group_values.len() <= limit { - return Ok(()); - } - - let Some(batch) = self.emit(EmitTo::All, true)? 
else {
-            return Ok(());
-        };
-        self.clear_shrink(0);
-
-        let group_count = self.group_by.num_group_exprs();
-        let sort_columns = batch
-            .columns()
-            .iter()
-            .take(group_count)
-            .map(|values| SortColumn {
-                values: Arc::clone(values),
-                options: Some(SortOptions::default()),
-            })
-            .collect::<Vec<_>>();
-        let indices = lexsort_to_indices(&sort_columns, Some(limit))?;
-        let columns = batch
-            .columns()
-            .iter()
-            .map(|array| Ok(take(array.as_ref(), &indices, None)?))
-            .collect::<Result<Vec<_>>>()?;
-        let batch = RecordBatch::try_new(Arc::clone(batch.schema_ref()), columns)?;
-
-        let group_values = batch
-            .columns()
-            .iter()
-            .take(group_count)
-            .cloned()
-            .collect::<Vec<_>>();
-        self.group_values
-            .intern(&group_values, &mut self.current_group_indices)?;
-        let group_indices = &self.current_group_indices;
-        let total_num_groups = self.group_values.len();
-
-        let mut column_index = group_count;
-        for (acc, state_field_count) in self
-            .accumulators
-            .iter_mut()
-            .zip(self.state_field_counts.iter().copied())
-        {
-            let next_column_index = column_index + state_field_count;
-            let values = batch.columns()[column_index..next_column_index].to_vec();
-            acc.merge_batch(&values, group_indices, None, total_num_groups)?;
-            column_index = next_column_index;
-        }
-        assert_eq_or_internal_err!(
-            column_index,
-            batch.num_columns(),
-            "Mismatch rebuilding limited aggregate state"
-        );
-
-        self.update_memory_reservation()?;
-
-        Ok(())
-    }
-
-    fn should_prune_to_group_key_topk(&self) -> bool {
-        self.group_values_soft_limit.is_some()
-            && !self.accumulators.is_empty()
-            && !self.spill_state.is_stream_merging
-            && matches!(self.group_ordering, GroupOrdering::None)
-    }
-
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
     fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Option<RecordBatch>> {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 00ca0482a31e1..e10f2a163f9f3 100644
---
a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -6607,7 +6607,7 @@ SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c 14 17 -# An aggregate expression causes the limit to not be pushed to the aggregation +# An aggregate expression prefilters the input through a limited group-key set query TT EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; ---- @@ -6615,15 +6615,26 @@ logical_plan 01)Projection: max(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 02)--Limit: skip=0, fetch=5 03)----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[max(aggregate_test_100.c1)]] -04)------TableScan: aggregate_test_100 projection=[c1, c2, c3] +04)------RightSemi Join: aggregate_test_100.c2 = aggregate_test_100.c2, aggregate_test_100.c3 = aggregate_test_100.c3 +05)--------Limit: skip=0, fetch=5 +06)----------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]] +07)------------TableScan: aggregate_test_100 projection=[c2, c3] +08)--------TableScan: aggregate_test_100 projection=[c1, c2, c3] physical_plan 01)ProjectionExec: expr=[max(aggregate_test_100.c1)@2 as max(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] 02)--GlobalLimitExec: skip=0, fetch=5 03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[max(aggregate_test_100.c1)] 04)------CoalescePartitionsExec 05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[max(aggregate_test_100.c1)] -06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true +06)----------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(c2@0, c2@1), (c3@1, c3@2)], NullsEqual: true +07)------------GlobalLimitExec: skip=0, 
fetch=5 +08)--------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5] +09)----------------CoalescePartitionsExec +10)------------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5] +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true +13)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true # TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns # in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index 29c9810e874fa..2a8c9443b9ef1 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -490,15 +490,27 @@ logical_plan 01)Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*) 02)--Limit: skip=0, fetch=10 03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[count(Int64(1))]] -04)------SubqueryAlias: hits -05)--------TableScan: hits_raw projection=[UserID, SearchPhrase] +04)------RightSemi Join: hits.UserID = hits.UserID, hits.SearchPhrase = hits.SearchPhrase +05)--------Limit: skip=0, fetch=10 +06)----------Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[]] +07)------------SubqueryAlias: hits +08)--------------TableScan: hits_raw projection=[UserID, SearchPhrase] +09)--------SubqueryAlias: hits +10)----------TableScan: hits_raw projection=[UserID, 
SearchPhrase] physical_plan 01)ProjectionExec: expr=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase, count(Int64(1))@2 as count(*)] 02)--CoalescePartitionsExec: fetch=10 -03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))], lim=[10] -04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1 -05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))], lim=[10] -06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet +03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))] +04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=4 +05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(UserID@0, UserID@0), (SearchPhrase@1, SearchPhrase@1)], NullsEqual: true +08)--------------CoalescePartitionsExec: fetch=10 +09)----------------AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10] +10)------------------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1 +11)--------------------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10] +12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet +13)--------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet, predicate=DynamicFilter [ empty ] query ITI rowsort SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10; From 8104f783b4e015e83656a4b97c44bb891890eaa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 4 May 2026 13:57:59 +0200 Subject: [PATCH 4/5] Update aggregate limit optimizer tests --- datafusion/optimizer/src/eliminate_limit.rs | 6 +++- datafusion/optimizer/src/push_down_limit.rs | 32 +++++++++++++++------ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 1ec3c856080eb..d01e88e9a4398 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -210,7 +210,11 @@ mod tests { Sort: test.a ASC NULLS LAST, fetch=3 Limit: skip=0, fetch=2 Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]] - TableScan: test + RightSemi Join: test.a = test.a + Limit: skip=0, fetch=2 + Aggregate: groupBy=[[test.a]], aggr=[[]] + TableScan: test + TableScan: test " ) } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index b604bca101252..b9c017a669d01 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -666,7 +666,7 @@ mod test { } #[test] - fn limit_doesnt_push_down_aggregation() -> Result<()> { + fn limit_prefilters_aggregation() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -674,13 +674,17 @@ mod test { .limit(0, Some(1000))? 
.build()?; - // Limit should *not* push down aggregate node + // Limit preselects group keys before running the aggregate assert_optimized_plan_equal!( plan, @r" Limit: skip=0, fetch=1000 Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] - TableScan: test + RightSemi Join: test.a = test.a + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a]], aggr=[[]] + TableScan: test + TableScan: test " ) } @@ -758,14 +762,20 @@ mod test { .limit(0, Some(10))? .build()?; - // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation + // Limit should use deeper LIMIT 1000 and preselect group keys for the + // aggregate using the outer LIMIT 10. assert_optimized_plan_equal!( plan, @r" Limit: skip=0, fetch=10 Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] - Limit: skip=0, fetch=1000 - TableScan: test, fetch=1000 + RightSemi Join: test.a = test.a + Limit: skip=0, fetch=10 + Aggregate: groupBy=[[test.a]], aggr=[[]] + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 + Limit: skip=0, fetch=1000 + TableScan: test, fetch=1000 " ) } @@ -869,7 +879,7 @@ mod test { } #[test] - fn limit_doesnt_push_down_with_offset_aggregation() -> Result<()> { + fn limit_with_offset_prefilters_aggregation() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -877,13 +887,17 @@ mod test { .limit(10, Some(1000))? 
.build()?; - // Limit should *not* push down aggregate node + // Limit preselects enough group keys to satisfy offset and fetch assert_optimized_plan_equal!( plan, @r" Limit: skip=10, fetch=1000 Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]] - TableScan: test + RightSemi Join: test.a = test.a + Limit: skip=0, fetch=1010 + Aggregate: groupBy=[[test.a]], aggr=[[]] + TableScan: test + TableScan: test " ) } From 2d272063055899eee414d22d72184ad879a6a220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Mon, 4 May 2026 14:34:49 +0200 Subject: [PATCH 5/5] fix: avoid limit prefilter for FD group keys --- datafusion/optimizer/src/push_down_limit.rs | 60 +++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index b9c017a669d01..1c72c838940f5 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -26,7 +26,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; -use datafusion_common::{NullEquality, Result}; +use datafusion_common::{NullEquality, Result, get_required_group_by_exprs_indices}; use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, Limit, LogicalPlan}; use datafusion_expr::{Expr, FetchType, LogicalPlanBuilder, SkipType, lit}; @@ -265,6 +265,9 @@ fn prefilter_limited_aggregate( if is_key_prefiltered_aggregate(&aggregate) { return Ok(None); } + if has_functionally_reducible_group_exprs(&aggregate) { + return Ok(None); + } let mut seen_columns = HashSet::with_capacity(aggregate.group_expr.len()); let mut join_columns = Vec::with_capacity(aggregate.group_expr.len()); @@ -303,6 +306,28 @@ fn prefilter_limited_aggregate( .map(Some) } +fn has_functionally_reducible_group_exprs(aggregate: &Aggregate) -> bool { + if aggregate + .input + .schema() + .functional_dependencies() + 
.is_empty()
+    {
+        return false;
+    }
+
+    let group_expr_names = aggregate
+        .group_expr
+        .iter()
+        .map(|expr| expr.schema_name().to_string())
+        .collect::<Vec<String>>();
+
+    get_required_group_by_exprs_indices(aggregate.input.schema(), &group_expr_names)
+        .is_some_and(|required_indices| {
+            required_indices.len() < aggregate.group_expr.len()
+        })
+}
+
 fn is_key_prefiltered_aggregate(aggregate: &Aggregate) -> bool {
     let LogicalPlan::Join(join) = aggregate.input.as_ref() else {
         return false;
@@ -362,10 +387,11 @@ mod test {
     use crate::test::*;
     use crate::OptimizerContext;
 
-    use datafusion_common::DFSchemaRef;
+    use arrow::datatypes::Schema;
+    use datafusion_common::{Constraint, Constraints, DFSchemaRef};
     use datafusion_expr::{
         Expr, Extension, UserDefinedLogicalNodeCore, col, exists,
-        logical_plan::builder::LogicalPlanBuilder,
+        logical_plan::builder::{LogicalPlanBuilder, table_source_with_constraints},
     };
     use datafusion_functions_aggregate::expr_fn::max;
 
@@ -689,6 +715,34 @@ mod test {
+    #[test]
+    fn limit_does_not_prefilter_fd_reducible_aggregation() -> Result<()> {
+        let constraints =
+            Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
+        let table_source = table_source_with_constraints(
+            &Schema::new(test_table_scan_fields()),
+            constraints,
+        );
+        let table_scan = LogicalPlanBuilder::scan("test", table_source, None)?.build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b"), col("c")], vec![max(col("b"))])?
+            .limit(0, Some(1000))?
+            .build()?;
+
+        // SQL planning may add functionally dependent fields as implicit group
+        // keys. Do not turn those redundant keys into semijoin predicates before
+        // projection optimization has a chance to simplify them.
+ assert_optimized_plan_equal!( + plan, + @r" + Limit: skip=0, fetch=1000 + Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[max(test.b)]] + TableScan: test + " + ) + } + #[test] fn limit_should_push_down_union() -> Result<()> { let table_scan = test_table_scan()?;