-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[CALCITE-7492] Support expression that has a constant value within the group involving only GROUP BY keys as aggregate arguments #4961
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
07dc578
45adc2f
09309f2
2419f66
5b6bb6a
68a85da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,15 +18,19 @@ | |
|
|
||
| import org.apache.calcite.plan.RelOptRuleCall; | ||
| import org.apache.calcite.plan.RelRule; | ||
| import org.apache.calcite.plan.hep.HepRelVertex; | ||
| import org.apache.calcite.rel.RelCollations; | ||
| import org.apache.calcite.rel.RelNode; | ||
| import org.apache.calcite.rel.core.Aggregate; | ||
| import org.apache.calcite.rel.core.AggregateCall; | ||
| import org.apache.calcite.rel.core.Project; | ||
| import org.apache.calcite.rel.core.RelFactories; | ||
| import org.apache.calcite.rel.logical.LogicalAggregate; | ||
| import org.apache.calcite.rex.RexBuilder; | ||
| import org.apache.calcite.rex.RexInputRef; | ||
| import org.apache.calcite.rex.RexNode; | ||
| import org.apache.calcite.rex.RexShuttle; | ||
| import org.apache.calcite.rex.RexUtil; | ||
| import org.apache.calcite.sql.SqlKind; | ||
| import org.apache.calcite.tools.RelBuilder; | ||
|
|
||
|
|
@@ -45,14 +49,21 @@ | |
| * {@code SELECT sal, sal FROM emp GROUP BY sal}. | ||
| * | ||
| * <p>Currently supports the following aggregate functions when their | ||
| * arguments exist in the aggregate's group set: | ||
| * arguments exist in the aggregate's group set or are deterministic | ||
| * expressions involving only group set columns and constants: | ||
| * <ul> | ||
| * <li>{@code MAX}</li> | ||
| * <li>{@code MIN}</li> | ||
| * <li>{@code AVG}</li> | ||
| * <li>{@code ANY_VALUE}</li> | ||
| * </ul> | ||
| * | ||
| * <p>Note: This optimization preserves NULL semantics correctly. For aggregate | ||
| * functions like MAX, MIN, and ANY_VALUE, NULL values in the source columns or | ||
| * expressions are handled the same way before and after the transformation: | ||
| * nulls are ignored by the aggregation, and if all grouped values are NULL, | ||
| * the result is NULL. | ||
| * | ||
| * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS | ||
| */ | ||
| @Value.Enclosing | ||
|
|
@@ -74,6 +85,8 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { | |
|
|
||
| final List<AggregateCall> newCalls = new ArrayList<>(); | ||
| final List<RexNode> projects = new ArrayList<>(); | ||
| final List<String> fieldNames = | ||
| new ArrayList<>(aggregate.getRowType().getFieldNames()); | ||
|
|
||
| // Pass through group keys. | ||
| for (int i = 0; i < groupCount; i++) { | ||
|
|
@@ -108,12 +121,13 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { | |
| aggregate.getGroupSets(), | ||
| newCalls); | ||
| relBuilder.push(newAggregate); | ||
| relBuilder.project(projects); | ||
| relBuilder.project(projects, fieldNames); | ||
| call.transformTo(relBuilder.build()); | ||
| } | ||
|
|
||
| /** | ||
| * Tries to reduce an aggregate call to a reference to a group-by key. | ||
| * Tries to reduce an aggregate call to a reference to a group-by key | ||
| * or to an expression involving only group-by keys and constants. | ||
| * | ||
| * @return the reduced expression, or null if cannot reduce | ||
| */ | ||
|
|
@@ -129,14 +143,6 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { | |
| || call.collation != RelCollations.EMPTY) { | ||
| return null; | ||
| } | ||
| final List<Integer> argList = call.getArgList(); | ||
| if (argList.size() != 1) { | ||
| return null; | ||
| } | ||
| final int arg = argList.get(0); | ||
| if (!aggregate.getGroupSet().get(arg)) { | ||
| return null; | ||
| } | ||
| final SqlKind kind = call.getAggregation().getKind(); | ||
| switch (kind) { | ||
| case AVG: | ||
|
|
@@ -147,12 +153,118 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { | |
| default: | ||
| return null; | ||
| } | ||
| final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg); | ||
| RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); | ||
| if (!ref.getType().equals(call.getType())) { | ||
| ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref); | ||
| final List<Integer> argList = call.getArgList(); | ||
| if (argList.size() != 1) { | ||
| return null; | ||
| } | ||
| final int arg = argList.get(0); | ||
|
|
||
| // Case 1: argument directly references a group-by key | ||
| if (aggregate.getGroupSet().get(arg)) { | ||
| final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg); | ||
| RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); | ||
| if (!ref.getType().equals(call.getType())) { | ||
| ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref); | ||
| } | ||
| return ref; | ||
| } | ||
|
|
||
| // Case 2: argument is an expression in a Project below the Aggregate | ||
| RelNode input = aggregate.getInput(); | ||
| if (input instanceof HepRelVertex) { | ||
| input = ((HepRelVertex) input).getCurrentRel(); | ||
| } | ||
| if (!(input instanceof Project)) { | ||
| return null; | ||
| } | ||
| final Project project = (Project) input; | ||
| if (arg < 0 || arg >= project.getProjects().size()) { | ||
| return null; | ||
| } | ||
| final RexNode expr = project.getProjects().get(arg); | ||
| if (!RexUtil.isDeterministic(expr)) { | ||
| return null; | ||
| } | ||
| // Check that all columns referenced in the expression are group-by keys. | ||
| // This ensures that the expression value is constant within each group. | ||
| final @Nullable RexNode translated = | ||
| translateToGroupRefs(expr, project, aggregate); | ||
| if (translated == null) { | ||
| return null; | ||
| } | ||
| if (!translated.getType().equals(call.getType())) { | ||
| return rexBuilder.makeCast(call.getParserPosition(), call.getType(), translated); | ||
| } | ||
| return translated; | ||
| } | ||
|
|
||
| /** | ||
| * Translates an expression so that its {@link RexInputRef}s reference | ||
| * the group keys of the aggregate rather than the input to the project. | ||
| * | ||
| * @return the translated expression, or null if the expression references | ||
| * columns that are not group-by keys | ||
| */ | ||
| private static @Nullable RexNode translateToGroupRefs( | ||
| RexNode expr, Project project, Aggregate aggregate) { | ||
| final List<RexNode> projects = project.getProjects(); | ||
| final GroupRefTranslator translator = new GroupRefTranslator(projects, aggregate); | ||
| final RexNode result = expr.accept(translator); | ||
| return translator.failed ? null : result; | ||
| } | ||
|
|
||
| /** | ||
| * Shuttle that translates input refs to aggregate group key refs. | ||
| * | ||
| * <p>For each column reference in the expression being examined: | ||
| * 1. If the expression is a direct pass-through of a project column, | ||
| * check if that project column is in the GROUP BY set | ||
| * 2. If the expression contains references to input columns, | ||
| * verify that those input columns are in the GROUP BY set | ||
| * 3. Map to the corresponding group key index in the aggregate | ||
| * | ||
| * <p>This ensures the expression references only columns that are constant | ||
| * within each group. | ||
| */ | ||
| private static class GroupRefTranslator extends RexShuttle { | ||
| private final List<RexNode> projects; | ||
| private final Aggregate aggregate; | ||
| private boolean failed = false; | ||
|
|
||
| GroupRefTranslator(List<RexNode> projects, Aggregate aggregate) { | ||
| this.projects = projects; | ||
| this.aggregate = aggregate; | ||
| } | ||
|
|
||
| @Override public RexNode visitInputRef(RexInputRef inputRef) { | ||
| if (failed) { | ||
| return inputRef; | ||
| } | ||
| final int inputIndex = inputRef.getIndex(); | ||
| // Look for a project column that is a direct pass-through of this input. | ||
| // For example, if a project has SAL=[$5], and the expression references $5, | ||
| // we need to map it to the corresponding group key. | ||
| int projectOutputIndex = -1; | ||
| for (int i = 0; i < projects.size(); i++) { | ||
| final RexNode projExpr = projects.get(i); | ||
| if (projExpr instanceof RexInputRef | ||
| && ((RexInputRef) projExpr).getIndex() == inputIndex) { | ||
| projectOutputIndex = i; | ||
| break; | ||
| } | ||
| } | ||
| // The input column must be available through a project column that is in | ||
| // the GROUP BY set. If not found, the input is embedded in a computed | ||
| // expression, which means the optimization cannot proceed safely. | ||
| if (projectOutputIndex < 0 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add a few negative tests that exercise these paths?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added three negative test cases to exercise the validation logic. |
||
| || !aggregate.getGroupSet().get(projectOutputIndex)) { | ||
| failed = true; | ||
| return inputRef; | ||
| } | ||
| final int groupIndex = | ||
| aggregate.getGroupSet().asList().indexOf(projectOutputIndex); | ||
| return RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); | ||
| } | ||
| return ref; | ||
| } | ||
|
|
||
| /** Rule configuration. */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,12 +55,100 @@ private static RelOptFixture sql(String sql) { | |
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysNullExpression() { | ||
| String sql = "select comm, max(comm + 1) as max_plus\n" | ||
| + "from empnullables group by comm"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysNullGroupKey() { | ||
| String sql = "select comm, max(comm) as comm_max\n" | ||
| + "from empnullables group by comm"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysNoChange() { | ||
| String sql = "select sal, max(comm) as comm_max\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysDeterministicExpression() { | ||
| String sql = "select sal, max(sal + 1) as max_plus\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysUnaryMinus() { | ||
| String sql = "select sal, max(-sal) as max_neg\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysBinaryExpression() { | ||
| String sql = "select sal, max(sal * 2) as max_double\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysMultipleGroupKeys() { | ||
| String sql = "select sal, max(sal + deptno) as max_sum\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysNestedExpression() { | ||
| // Nested expressions like (sal + 1) * 2 can be optimized by mapping the | ||
| // input references (sal) to group key references. The shuttle translates | ||
| // all input refs in the expression to their corresponding group keys. | ||
| String sql = "select sal, max((sal + 1) * 2) as max_expr\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysWithCastWider() { | ||
| // Test case where a cast is needed because the group key type differs | ||
| // from the aggregate result type. Cast to a wider type (BIGINT) is safe. | ||
| // The rule should preserve the cast. | ||
| String sql = "select cast(sal as bigint) as sal_big, max(sal) as sal_max\n" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This points to a possible problem: what if the cast is to a narrower type, e.g. TINYINT, which could lose information?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I've enhanced the test coverage to address this problem.Based on the results, information is not lost, which aligns with expectations. |
||
| + "from emp group by sal"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionOfGroupByKeysWithCastNarrower() { | ||
| // Test case where a cast is needed and the type is narrower than the source. | ||
| // Casting to SMALLINT could potentially lose information if sal has larger values, | ||
| // but this is the user's explicit choice. The rule should still optimize and | ||
| // preserve the cast, allowing SQL semantics to handle any data loss. | ||
| String sql = "select cast(sal as smallint) as sal_small, max(sal) as sal_max\n" | ||
| + "from emp group by sal"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionWithMixedColumnsNoOptimization() { | ||
| // Negative test: expression references both group-by and non-group-by columns. | ||
| // The rule should NOT optimize because the expression is not constant within the group. | ||
| String sql = "select sal, max(sal + comm) as max_sum\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionWithNonGroupByColumnNoOptimization() { | ||
| // Negative test: expression references only non-group-by columns. | ||
| // The rule should NOT optimize because the column is not in the GROUP BY set. | ||
| String sql = "select sal, max(comm) as comm_max\n" | ||
| + "from emp group by sal"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); | ||
| } | ||
|
|
||
| @Test void testAggregateFunctionWithMixedGroupByColumnsNoOptimization() { | ||
| // Negative test: expression contains GROUP BY column but also references | ||
| // a column from elsewhere that is not in GROUP BY. | ||
| String sql = "select sal, max(sal + empno) as max_sum\n" | ||
| + "from emp group by sal, deptno"; | ||
| sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); | ||
| } | ||
|
|
||
| @AfterAll static void checkActualAndReferenceFiles() { | ||
| fixture().diffRepos.checkActualAndReferenceFiles(); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do the tests cover this case, when a cast is inserted?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for pointing out the case. I've added a new test:
testAggregateFunctionOfGroupByKeysWithCastto verify that when a CAST is inserted due to type mismatch between the GROUP BY key and the aggregate result, the optimization correctly preserves the CAST operation.