diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs
index 1ec3c856080eb..d01e88e9a4398 100644
--- a/datafusion/optimizer/src/eliminate_limit.rs
+++ b/datafusion/optimizer/src/eliminate_limit.rs
@@ -210,7 +210,11 @@ mod tests {
         Sort: test.a ASC NULLS LAST, fetch=3
           Limit: skip=0, fetch=2
             Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
-              TableScan: test
+              RightSemi Join: test.a = test.a
+                Limit: skip=0, fetch=2
+                  Aggregate: groupBy=[[test.a]], aggr=[[]]
+                    TableScan: test
+                TableScan: test
         "
     )
 }
diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 4a26cd5884f6b..1c72c838940f5 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -18,16 +18,17 @@
 //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan
 
 use std::cmp::min;
+use std::collections::HashSet;
 use std::sync::Arc;
 
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::Result;
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::utils::combine_limit;
-use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan};
-use datafusion_expr::{FetchType, SkipType, lit};
+use datafusion_common::{NullEquality, Result, get_required_group_by_exprs_indices};
+use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, Limit, LogicalPlan};
+use datafusion_expr::{Expr, FetchType, LogicalPlanBuilder, SkipType, lit};
 
 /// Optimization rule that tries to push down `LIMIT`.
 /// It will push down through projection, limits (taking the smaller limit)
@@ -47,7 +48,6 @@ impl OptimizerRule for PushDownLimit {
         true
     }
 
-    #[expect(clippy::only_used_in_recursion)]
     fn rewrite(
         &self,
         plan: LogicalPlan,
@@ -123,6 +123,21 @@ impl OptimizerRule for PushDownLimit {
                 make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join)))
             })),
 
+            LogicalPlan::Aggregate(aggregate)
+                if config
+                    .options()
+                    .optimizer
+                    .enable_distinct_aggregation_soft_limit =>
+            {
+                if let Some(aggregate) =
+                    prefilter_limited_aggregate(aggregate.clone(), fetch + skip)?
+                {
+                    transformed_limit(skip, fetch, aggregate)
+                } else {
+                    original_limit(skip, fetch, LogicalPlan::Aggregate(aggregate))
+                }
+            }
+
             LogicalPlan::Sort(mut sort) => {
                 let new_fetch = {
                     let sort_fetch = skip + fetch;
@@ -237,6 +252,99 @@ fn transformed_limit(
     Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input))))
 }
 
+/// Rewrite `LIMIT K (GROUP BY keys, aggs)` into a key preselection followed
+/// by a semi join. This keeps the aggregate itself ordinary while letting the
+/// join's dynamic filter push the selected key set into the second input scan.
+fn prefilter_limited_aggregate(
+    aggregate: Aggregate,
+    limit: usize,
+) -> Result<Option<LogicalPlan>> {
+    if limit == 0 || aggregate.aggr_expr.is_empty() || aggregate.group_expr.is_empty() {
+        return Ok(None);
+    }
+    if is_key_prefiltered_aggregate(&aggregate) {
+        return Ok(None);
+    }
+    if has_functionally_reducible_group_exprs(&aggregate) {
+        return Ok(None);
+    }
+
+    let mut seen_columns = HashSet::with_capacity(aggregate.group_expr.len());
+    let mut join_columns = Vec::with_capacity(aggregate.group_expr.len());
+    for expr in &aggregate.group_expr {
+        let Expr::Column(column) = expr else {
+            return Ok(None);
+        };
+        if !seen_columns.insert(column.clone()) {
+            return Ok(None);
+        }
+        join_columns.push(column.clone());
+    }
+
+    let key_input = aggregate.input.as_ref().clone();
+    let keys = LogicalPlanBuilder::from(key_input)
+        .aggregate(aggregate.group_expr.clone(), Vec::<Expr>::new())?
+        .limit(0, Some(limit))?
+        .build()?;
+
+    let filtered_input = LogicalPlanBuilder::from(keys)
+        .join_detailed(
+            aggregate.input.as_ref().clone(),
+            JoinType::RightSemi,
+            (join_columns.clone(), join_columns),
+            None,
+            NullEquality::NullEqualsNull,
+        )?
+        .build()?;
+
+    Aggregate::try_new(
+        Arc::new(filtered_input),
+        aggregate.group_expr,
+        aggregate.aggr_expr,
+    )
+    .map(LogicalPlan::Aggregate)
+    .map(Some)
+}
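For intuition, `prefilter_limited_aggregate` is the logical-plan analogue of the following SQL rewrite (a sketch over a hypothetical table `t` with group key `a`; note that the real rule joins with `NullEquality::NullEqualsNull`, so unlike the plain `IN` shown here it also keeps a group whose key is NULL):

```sql
-- Before: compute every group, then keep 5 of them.
SELECT a, max(b) FROM t GROUP BY a LIMIT 5;

-- After, conceptually: preselect 5 distinct keys, restrict the scan to rows
-- carrying those keys via a semi join, then aggregate only the surviving rows.
SELECT a, max(b)
FROM t
WHERE a IN (SELECT a FROM t GROUP BY a LIMIT 5)  -- modulo NULL-key handling
GROUP BY a
LIMIT 5;
```

Both forms return 5 arbitrary groups, which is only sound because no `ORDER BY` sits below the limit. The preselecting aggregate carries no aggregate expressions, so the existing limited-distinct machinery can stop it early once 5 distinct keys have been seen (visible as the `lim=[5]`/`lim=[10]` annotations on the `AggregateExec`s in the sqllogictest plans below).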
+
+fn has_functionally_reducible_group_exprs(aggregate: &Aggregate) -> bool {
+    if aggregate
+        .input
+        .schema()
+        .functional_dependencies()
+        .is_empty()
+    {
+        return false;
+    }
+
+    let group_expr_names = aggregate
+        .group_expr
+        .iter()
+        .map(|expr| expr.schema_name().to_string())
+        .collect::<Vec<_>>();
+
+    get_required_group_by_exprs_indices(aggregate.input.schema(), &group_expr_names)
+        .is_some_and(|required_indices| {
+            required_indices.len() < aggregate.group_expr.len()
+        })
+}
+
+fn is_key_prefiltered_aggregate(aggregate: &Aggregate) -> bool {
+    let LogicalPlan::Join(join) = aggregate.input.as_ref() else {
+        return false;
+    };
+    if join.join_type != JoinType::RightSemi {
+        return false;
+    }
+    let LogicalPlan::Limit(limit) = join.left.as_ref() else {
+        return false;
+    };
+    let LogicalPlan::Aggregate(keys) = limit.input.as_ref() else {
+        return false;
+    };
+
+    keys.aggr_expr.is_empty() && keys.group_expr == aggregate.group_expr
+}
+
 /// Adds a limit to the inputs of a join, if possible
 fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
     use JoinType::*;
@@ -279,10 +387,11 @@ mod test {
     use crate::test::*;
     use crate::OptimizerContext;
 
-    use datafusion_common::DFSchemaRef;
+    use arrow::datatypes::Schema;
+    use datafusion_common::{Constraint, Constraints, DFSchemaRef};
     use datafusion_expr::{
         Expr, Extension, UserDefinedLogicalNodeCore, col, exists,
-        logical_plan::builder::LogicalPlanBuilder,
+        logical_plan::builder::{LogicalPlanBuilder, table_source_with_constraints},
     };
     use datafusion_functions_aggregate::expr_fn::max;
@@ -583,7 +692,7 @@ mod test {
     }
 
     #[test]
-    fn limit_doesnt_push_down_aggregation() -> Result<()> {
+    fn limit_prefilters_aggregation() -> Result<()> {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![max(col("b"))])?
             .limit(0, Some(1000))?
             .build()?;
 
-        // Limit should *not* push down aggregate node
+        // Limit preselects group keys before running the aggregate
         assert_optimized_plan_equal!(
             plan,
             @r"
         Limit: skip=0, fetch=1000
           Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
-            TableScan: test
-        "
+            RightSemi Join: test.a = test.a
+              Limit: skip=0, fetch=1000
+                Aggregate: groupBy=[[test.a]], aggr=[[]]
+                  TableScan: test
+              TableScan: test
+        "
         )
     }
 
+    #[test]
+    fn limit_does_not_prefilter_fd_reducible_aggregation() -> Result<()> {
+        let constraints =
+            Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
+        let table_source = table_source_with_constraints(
+            &Schema::new(test_table_scan_fields()),
+            constraints,
+        );
+        let table_scan = LogicalPlanBuilder::scan("test", table_source, None)?.build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("a"), col("b"), col("c")], vec![max(col("b"))])?
+            .limit(0, Some(1000))?
+            .build()?;
+
+        // SQL planning may add functionally dependent fields as implicit group
+        // keys. Do not turn those redundant keys into semijoin predicates before
+        // projection optimization has a chance to simplify them.
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        Limit: skip=0, fetch=1000
+          Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[max(test.b)]]
+            TableScan: test
+        "
+        )
+    }
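The functional-dependency guard is easiest to see with a concrete schema. A hypothetical example mirroring the test above (`id` is a primary key, so `b` and `c` are functionally dependent on it):

```sql
CREATE TABLE t (id INT PRIMARY KEY, b INT, c INT);

-- The planner may expand the grouping to GROUP BY id, b, c because b and c
-- are determined by id. A three-column semi-join prefilter built from those
-- redundant keys would be wasted work, so the rule bails out and lets
-- projection optimization reduce the group keys first.
SELECT id, b, c, max(b) FROM t GROUP BY id LIMIT 1000;
```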
@@ -675,14 +816,20 @@ mod test {
             .limit(0, Some(10))?
             .build()?;
 
-        // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation
+        // Limit should use deeper LIMIT 1000 and preselect group keys for the
+        // aggregate using the outer LIMIT 10.
         assert_optimized_plan_equal!(
             plan,
             @r"
         Limit: skip=0, fetch=10
           Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
-            Limit: skip=0, fetch=1000
-              TableScan: test, fetch=1000
+            RightSemi Join: test.a = test.a
+              Limit: skip=0, fetch=10
+                Aggregate: groupBy=[[test.a]], aggr=[[]]
+                  Limit: skip=0, fetch=1000
+                    TableScan: test, fetch=1000
+              Limit: skip=0, fetch=1000
+                TableScan: test, fetch=1000
         "
         )
     }
@@ -786,7 +933,7 @@ mod test {
     }
 
     #[test]
-    fn limit_doesnt_push_down_with_offset_aggregation() -> Result<()> {
+    fn limit_with_offset_prefilters_aggregation() -> Result<()> {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
             .aggregate(vec![col("a")], vec![max(col("b"))])?
             .limit(10, Some(1000))?
             .build()?;
 
-        // Limit should *not* push down aggregate node
+        // Limit preselects enough group keys to satisfy offset and fetch
         assert_optimized_plan_equal!(
             plan,
             @r"
         Limit: skip=10, fetch=1000
           Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
-            TableScan: test
+            RightSemi Join: test.a = test.a
+              Limit: skip=0, fetch=1010
+                Aggregate: groupBy=[[test.a]], aggr=[[]]
+                  TableScan: test
+              TableScan: test
         "
         )
     }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 00ca0482a31e1..e10f2a163f9f3 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -6607,7 +6607,7 @@ SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c
 14
 17
 
-# An aggregate expression causes the limit to not be pushed to the aggregation
+# With an aggregate expression present, the limit prefilters the input through a limited group-key set
 query TT
 EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5;
 ----
 logical_plan
 01)Projection: max(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3
 02)--Limit: skip=0, fetch=5
 03)----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[max(aggregate_test_100.c1)]]
-04)------TableScan: aggregate_test_100 projection=[c1, c2, c3]
+04)------RightSemi Join: aggregate_test_100.c2 = aggregate_test_100.c2, aggregate_test_100.c3 = aggregate_test_100.c3
+05)--------Limit: skip=0, fetch=5
+06)----------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]]
+07)------------TableScan: aggregate_test_100 projection=[c2, c3]
+08)--------TableScan: aggregate_test_100 projection=[c1, c2, c3]
 physical_plan
 01)ProjectionExec: expr=[max(aggregate_test_100.c1)@2 as max(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3]
 02)--GlobalLimitExec: skip=0, fetch=5
 03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[max(aggregate_test_100.c1)]
 04)------CoalescePartitionsExec
 05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[max(aggregate_test_100.c1)]
-06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true
+06)----------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(c2@0, c2@1), (c3@1, c3@2)], NullsEqual: true
+07)------------GlobalLimitExec: skip=0, fetch=5
+08)--------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5]
+09)----------------CoalescePartitionsExec
+10)------------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5]
+11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true
+13)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+14)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true
 
 # TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns
 # in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt
index c79701e347109..2a8c9443b9ef1 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -490,15 +490,27 @@ logical_plan
 01)Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*)
 02)--Limit: skip=0, fetch=10
 03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[count(Int64(1))]]
-04)------SubqueryAlias: hits
-05)--------TableScan: hits_raw projection=[UserID, SearchPhrase]
+04)------RightSemi Join: hits.UserID = hits.UserID, hits.SearchPhrase = hits.SearchPhrase
+05)--------Limit: skip=0, fetch=10
+06)----------Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[]]
+07)------------SubqueryAlias: hits
+08)--------------TableScan: hits_raw projection=[UserID, SearchPhrase]
+09)--------SubqueryAlias: hits
+10)----------TableScan: hits_raw projection=[UserID, SearchPhrase]
 physical_plan
 01)ProjectionExec: expr=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase, count(Int64(1))@2 as count(*)]
 02)--CoalescePartitionsExec: fetch=10
 03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))]
-04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1
+04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=4
 05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))]
-06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(UserID@0, UserID@0), (SearchPhrase@1, SearchPhrase@1)], NullsEqual: true
+08)--------------CoalescePartitionsExec: fetch=10
+09)----------------AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10]
+10)------------------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1
+11)--------------------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10]
+12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet
+13)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet, predicate=DynamicFilter [ empty ]
 
 query ITI rowsort
 SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10;
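The new `LogicalPlan::Aggregate` arm is gated on the existing `datafusion.optimizer.enable_distinct_aggregation_soft_limit` option, so the rewrite can be switched off per session. A sqllogictest-style sketch (expected plan output elided; with the option off, EXPLAIN should again show the removed `-` plan lines above):

```sql
# Disable the group-key prefilter rewrite for this session.
statement ok
set datafusion.optimizer.enable_distinct_aggregation_soft_limit = false;

query TT
EXPLAIN SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10;

# Restore the default.
statement ok
set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true;
```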