diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java index 10d30d620eb..987a25fea2e 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java @@ -18,15 +18,19 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.hep.HepRelVertex; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.tools.RelBuilder; @@ -45,7 +49,8 @@ * {@code SELECT sal, sal FROM emp GROUP BY sal}. * *

Currently supports the following aggregate functions when their - * arguments exist in the aggregate's group set: + * arguments exist in the aggregate's group set or are deterministic + * expressions involving only group set columns and constants: *

* + *

Note: This optimization preserves NULL semantics correctly. For aggregate + * functions like MAX, MIN, and ANY_VALUE, NULL values in the source columns or + * expressions are handled the same way before and after the transformation: + * nulls are ignored by the aggregation, and if all grouped values are NULL, + * the result is NULL. + * * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS */ @Value.Enclosing @@ -74,6 +85,8 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { final List newCalls = new ArrayList<>(); final List projects = new ArrayList<>(); + final List fieldNames = + new ArrayList<>(aggregate.getRowType().getFieldNames()); // Pass through group keys. for (int i = 0; i < groupCount; i++) { @@ -108,12 +121,13 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { aggregate.getGroupSets(), newCalls); relBuilder.push(newAggregate); - relBuilder.project(projects); + relBuilder.project(projects, fieldNames); call.transformTo(relBuilder.build()); } /** - * Tries to reduce an aggregate call to a reference to a group-by key. + * Tries to reduce an aggregate call to a reference to a group-by key + * or to an expression involving only group-by keys and constants. 
* * @return the reduced expression, or null if cannot reduce */ @@ -129,14 +143,6 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { || call.collation != RelCollations.EMPTY) { return null; } - final List argList = call.getArgList(); - if (argList.size() != 1) { - return null; - } - final int arg = argList.get(0); - if (!aggregate.getGroupSet().get(arg)) { - return null; - } final SqlKind kind = call.getAggregation().getKind(); switch (kind) { case AVG: @@ -147,12 +153,118 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) { default: return null; } - final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg); - RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); - if (!ref.getType().equals(call.getType())) { - ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref); + final List argList = call.getArgList(); + if (argList.size() != 1) { + return null; + } + final int arg = argList.get(0); + + // Case 1: argument directly references a group-by key + if (aggregate.getGroupSet().get(arg)) { + final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg); + RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); + if (!ref.getType().equals(call.getType())) { + ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref); + } + return ref; + } + + // Case 2: argument is an expression in a Project below the Aggregate + RelNode input = aggregate.getInput(); + if (input instanceof HepRelVertex) { + input = ((HepRelVertex) input).getCurrentRel(); + } + if (!(input instanceof Project)) { + return null; + } + final Project project = (Project) input; + if (arg < 0 || arg >= project.getProjects().size()) { + return null; + } + final RexNode expr = project.getProjects().get(arg); + if (!RexUtil.isDeterministic(expr)) { + return null; + } + // Check that all columns referenced in the expression are group-by keys. 
+ // This ensures that the expression value is constant within each group. + final @Nullable RexNode translated = + translateToGroupRefs(expr, project, aggregate); + if (translated == null) { + return null; + } + if (!translated.getType().equals(call.getType())) { + return rexBuilder.makeCast(call.getParserPosition(), call.getType(), translated); + } + return translated; + } + + /** + * Translates an expression so that its {@link RexInputRef}s reference + * the group keys of the aggregate rather than the input to the project. + * + * @return the translated expression, or null if the expression references + * columns that are not group-by keys + */ + private static @Nullable RexNode translateToGroupRefs( + RexNode expr, Project project, Aggregate aggregate) { + final List projects = project.getProjects(); + final GroupRefTranslator translator = new GroupRefTranslator(projects, aggregate); + final RexNode result = expr.accept(translator); + return translator.failed ? null : result; + } + + /** + * Shuttle that translates input refs to aggregate group key refs. + * + *

For each input reference in the expression being examined: + * 1. Find a project column that is a direct pass-through of the referenced + * input column (a bare {@link RexInputRef} with the same index) + * 2. Verify that this project column is in the GROUP BY set; if no such + * column exists, the translation fails + * 3. Map the reference to the corresponding group key index in the aggregate + * + *

This ensures the expression references only columns that are constant + * within each group. + */ + private static class GroupRefTranslator extends RexShuttle { + private final List projects; + private final Aggregate aggregate; + private boolean failed = false; + + GroupRefTranslator(List projects, Aggregate aggregate) { + this.projects = projects; + this.aggregate = aggregate; + } + + @Override public RexNode visitInputRef(RexInputRef inputRef) { + if (failed) { + return inputRef; + } + final int inputIndex = inputRef.getIndex(); + // Look for a project column that is a direct pass-through of this input. + // For example, if a project has SAL=[$5], and the expression references $5, + // we need to map it to the corresponding group key. + int projectOutputIndex = -1; + for (int i = 0; i < projects.size(); i++) { + final RexNode projExpr = projects.get(i); + if (projExpr instanceof RexInputRef + && ((RexInputRef) projExpr).getIndex() == inputIndex) { + projectOutputIndex = i; + break; + } + } + // The input column must be available through a project column that is in + // the GROUP BY set. If not found, the input is embedded in a computed + // expression, which means the optimization cannot proceed safely. + if (projectOutputIndex < 0 + || !aggregate.getGroupSet().get(projectOutputIndex)) { + failed = true; + return inputRef; + } + final int groupIndex = + aggregate.getGroupSet().asList().indexOf(projectOutputIndex); + return RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList()); } - return ref; } /** Rule configuration. 
*/ diff --git a/core/src/test/java/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.java b/core/src/test/java/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.java index 739ce1cb0b9..d4de2092aac 100644 --- a/core/src/test/java/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.java +++ b/core/src/test/java/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.java @@ -55,12 +55,100 @@ private static RelOptFixture sql(String sql) { sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); } + @Test void testAggregateFunctionOfGroupByKeysNullExpression() { + String sql = "select comm, max(comm + 1) as max_plus\n" + + "from empnullables group by comm"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysNullGroupKey() { + String sql = "select comm, max(comm) as comm_max\n" + + "from empnullables group by comm"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + @Test void testAggregateFunctionOfGroupByKeysNoChange() { String sql = "select sal, max(comm) as comm_max\n" + "from emp group by sal, deptno"; sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); } + @Test void testAggregateFunctionOfGroupByKeysDeterministicExpression() { + String sql = "select sal, max(sal + 1) as max_plus\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysUnaryMinus() { + String sql = "select sal, max(-sal) as max_neg\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysBinaryExpression() { + String sql = "select sal, max(sal * 2) as max_double\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void 
testAggregateFunctionOfGroupByKeysMultipleGroupKeys() { + String sql = "select sal, max(sal + deptno) as max_sum\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysNestedExpression() { + // Nested expressions like (sal + 1) * 2 can be optimized by mapping the + // input references (sal) to group key references. The shuttle translates + // all input refs in the expression to their corresponding group keys. + String sql = "select sal, max((sal + 1) * 2) as max_expr\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysWithCastWider() { + // Test case where a cast is needed because the group key type differs + // from the aggregate result type. Cast to a wider type (BIGINT) is safe. + // The rule should preserve the cast. + String sql = "select cast(sal as bigint) as sal_big, max(sal) as sal_max\n" + + "from emp group by sal"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionOfGroupByKeysWithCastNarrower() { + // Test case where a cast is needed and the type is narrower than the source. + // Casting to SMALLINT could potentially lose information if sal has larger values, + // but this is the user's explicit choice. The rule should still optimize and + // preserve the cast, allowing SQL semantics to handle any data loss. + String sql = "select cast(sal as smallint) as sal_small, max(sal) as sal_max\n" + + "from emp group by sal"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check(); + } + + @Test void testAggregateFunctionWithMixedColumnsNoOptimization() { + // Negative test: expression references both group-by and non-group-by columns. + // The rule should NOT optimize because the expression is not constant within the group. 
+ String sql = "select sal, max(sal + comm) as max_sum\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); + } + + @Test void testAggregateFunctionWithNonGroupByColumnNoOptimization() { + // Negative test: expression references only non-group-by columns. + // The rule should NOT optimize because the column is not in the GROUP BY set. + String sql = "select sal, max(comm) as comm_max\n" + + "from emp group by sal"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); + } + + @Test void testAggregateFunctionWithMixedGroupByColumnsNoOptimization() { + // Negative test: expression contains GROUP BY column but also references + // a column from elsewhere that is not in GROUP BY. + String sql = "select sal, max(sal + empno) as max_sum\n" + + "from emp group by sal, deptno"; + sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged(); + } + @AfterAll static void checkActualAndReferenceFiles() { fixture().diffRepos.checkActualAndReferenceFiles(); } diff --git a/core/src/test/resources/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.xml b/core/src/test/resources/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.xml index e7eb9d5a927..2083a969d72 100644 --- a/core/src/test/resources/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/AggregateReduceFunctionsOnGroupKeysRuleTest.xml @@ -33,10 +33,102 @@ LogicalProject(SAL=[$0], SAL_MAX=[$2], SAL_MIN=[$3], SAL_AVG=[$4], SAL_VAL=[$5]) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -51,6 +143,48 @@ LogicalProject(SAL=[$0], COMM_MAX=[$2]) LogicalAggregate(group=[{0, 1}], COMM_MAX=[MAX($2)]) LogicalProject(SAL=[$5], DEPTNO=[$7], COMM=[$6]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + + + + + + + + + + + + @@ -70,10 +204,120 @@ 
LogicalProject(SAL=[$0], SAL_MAX=[$2], COMM_SUM=[$3]) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +