Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;

Expand All @@ -45,14 +49,21 @@
* {@code SELECT sal, sal FROM emp GROUP BY sal}.
*
* <p>Currently supports the following aggregate functions when their
* arguments exist in the aggregate's group set:
* arguments exist in the aggregate's group set or are deterministic
* expressions involving only group set columns and constants:
* <ul>
* <li>{@code MAX}</li>
* <li>{@code MIN}</li>
* <li>{@code AVG}</li>
* <li>{@code ANY_VALUE}</li>
* </ul>
*
* <p>Note: This optimization preserves NULL semantics. {@code MAX},
* {@code MIN}, {@code AVG} and {@code ANY_VALUE} ignore NULL inputs and
* return NULL when every input in the group is NULL. Within a group, a group
* key (or an expression over group keys) has a single value, so the aggregate
* result and the replacement expression are both NULL or both that value.
*
* @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS
*/
@Value.Enclosing
Expand All @@ -74,6 +85,8 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {

final List<AggregateCall> newCalls = new ArrayList<>();
final List<RexNode> projects = new ArrayList<>();
final List<String> fieldNames =
new ArrayList<>(aggregate.getRowType().getFieldNames());

// Pass through group keys.
for (int i = 0; i < groupCount; i++) {
Expand Down Expand Up @@ -108,12 +121,13 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
aggregate.getGroupSets(),
newCalls);
relBuilder.push(newAggregate);
relBuilder.project(projects);
relBuilder.project(projects, fieldNames);
call.transformTo(relBuilder.build());
}

/**
* Tries to reduce an aggregate call to a reference to a group-by key.
* Tries to reduce an aggregate call to a reference to a group-by key
* or to an expression involving only group-by keys and constants.
*
* @return the reduced expression, or null if cannot reduce
*/
Expand All @@ -129,14 +143,6 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
|| call.collation != RelCollations.EMPTY) {
return null;
}
final List<Integer> argList = call.getArgList();
if (argList.size() != 1) {
return null;
}
final int arg = argList.get(0);
if (!aggregate.getGroupSet().get(arg)) {
return null;
}
final SqlKind kind = call.getAggregation().getKind();
switch (kind) {
case AVG:
Expand All @@ -147,12 +153,118 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
default:
return null;
}
final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
if (!ref.getType().equals(call.getType())) {
ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
final List<Integer> argList = call.getArgList();
if (argList.size() != 1) {
return null;
}
final int arg = argList.get(0);

// Case 1: argument directly references a group-by key
if (aggregate.getGroupSet().get(arg)) {
final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
if (!ref.getType().equals(call.getType())) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the tests cover the case where a cast is inserted?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing this out. I've added a new test, testAggregateFunctionOfGroupByKeysWithCast, to verify that when a CAST is inserted because of a type mismatch between the GROUP BY key and the aggregate result, the optimization preserves the CAST.

ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
}
return ref;
}

// Case 2: argument is an expression in a Project below the Aggregate
RelNode input = aggregate.getInput();
if (input instanceof HepRelVertex) {
input = ((HepRelVertex) input).getCurrentRel();
}
if (!(input instanceof Project)) {
return null;
}
final Project project = (Project) input;
if (arg < 0 || arg >= project.getProjects().size()) {
return null;
}
final RexNode expr = project.getProjects().get(arg);
if (!RexUtil.isDeterministic(expr)) {
return null;
}
// Check that all columns referenced in the expression are group-by keys.
// This ensures that the expression value is constant within each group.
final @Nullable RexNode translated =
translateToGroupRefs(expr, project, aggregate);
if (translated == null) {
return null;
}
if (!translated.getType().equals(call.getType())) {
return rexBuilder.makeCast(call.getParserPosition(), call.getType(), translated);
}
return translated;
}

/**
 * Translates an expression so that its {@link RexInputRef}s reference
 * the group keys of the aggregate rather than the input to the project.
 *
 * <p>Expressions that contain a windowed aggregate ({@link RexOver}) are
 * never translated. Such an expression can reference only group-by keys and
 * still depend on other rows of the project's input, for example
 * {@code COUNT(*) OVER (PARTITION BY sal)}. Moving it above the aggregate
 * would evaluate the window over the aggregate's output rows instead, and
 * may change the result.
 *
 * @param expr expression from {@code project} that is the argument of the
 *     aggregate call
 * @param project project that is the input of {@code aggregate}
 * @param aggregate aggregate whose group keys the result must reference
 * @return the translated expression, or null if the expression contains a
 *     windowed aggregate or references columns that are not group-by keys
 */
private static @Nullable RexNode translateToGroupRefs(
    RexNode expr, Project project, Aggregate aggregate) {
  // Referencing only group keys does not make a window function constant
  // per group in a way that survives re-evaluation above the aggregate.
  if (RexOver.containsOver(expr)) {
    return null;
  }
  final GroupRefTranslator translator =
      new GroupRefTranslator(project.getProjects(), aggregate);
  final RexNode result = expr.accept(translator);
  // The shuttle records a failure instead of throwing; discard the partial
  // result in that case.
  return translator.failed ? null : result;
}

/**
* Shuttle that translates input refs to aggregate group key refs.
*
* <p>For each {@link RexInputRef} in the expression being examined:
* <ol>
* <li>find the first project column that is a direct pass-through of the
* referenced input column;
* <li>check that this project column is in the aggregate's group set; if
* there is no such column, or it is not a group key, translation fails;
* <li>otherwise replace the reference with a reference to the corresponding
* group key field of the aggregate's output.
* </ol>
*
* <p>This ensures the expression references only columns that are constant
* within each group.
*/
private static class GroupRefTranslator extends RexShuttle {
private final List<RexNode> projects;
private final Aggregate aggregate;
private boolean failed = false;

GroupRefTranslator(List<RexNode> projects, Aggregate aggregate) {
this.projects = projects;
this.aggregate = aggregate;
}

@Override public RexNode visitInputRef(RexInputRef inputRef) {
if (failed) {
return inputRef;
}
final int inputIndex = inputRef.getIndex();
// Look for a project column that is a direct pass-through of this input.
// For example, if a project has SAL=[$5], and the expression references $5,
// we need to map it to the corresponding group key.
int projectOutputIndex = -1;
for (int i = 0; i < projects.size(); i++) {
final RexNode projExpr = projects.get(i);
if (projExpr instanceof RexInputRef
&& ((RexInputRef) projExpr).getIndex() == inputIndex) {
projectOutputIndex = i;
break;
}
}
// The input column must be available through a project column that is in
// the GROUP BY set. If not found, the input is embedded in a computed
// expression, which means the optimization cannot proceed safely.
if (projectOutputIndex < 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a few negative tests that exercise these paths?
For example, expressions that mix group-by columns with other columns.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added three negative test cases to exercise the validation logic.

|| !aggregate.getGroupSet().get(projectOutputIndex)) {
failed = true;
return inputRef;
}
final int groupIndex =
aggregate.getGroupSet().asList().indexOf(projectOutputIndex);
return RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
}
return ref;
}

/** Rule configuration. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,100 @@ private static RelOptFixture sql(String sql) {
sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysNullExpression() {
  // The aggregate argument is an expression over the only group key, COMM,
  // read from the EMPNULLABLES table.
  final String query = "select comm, max(comm + 1) as max_plus\n"
      + "from empnullables group by comm";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysNullGroupKey() {
  // The aggregate argument is the group key COMM itself, read from the
  // EMPNULLABLES table.
  final String query = "select comm, max(comm) as comm_max\n"
      + "from empnullables group by comm";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysNoChange() {
  // COMM is not one of the group keys (SAL, DEPTNO), so the plan is expected
  // to stay unchanged.
  final String query = "select sal, max(comm) as comm_max\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged();
}

@Test void testAggregateFunctionOfGroupByKeysDeterministicExpression() {
  // The aggregate argument combines the group key SAL with a constant.
  final String query = "select sal, max(sal + 1) as max_plus\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysUnaryMinus() {
  // The aggregate argument is a unary expression over the group key SAL.
  final String query = "select sal, max(-sal) as max_neg\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysBinaryExpression() {
  // The aggregate argument multiplies the group key SAL by a constant.
  final String query = "select sal, max(sal * 2) as max_double\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysMultipleGroupKeys() {
  // The aggregate argument references both group keys, SAL and DEPTNO.
  final String query = "select sal, max(sal + deptno) as max_sum\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysNestedExpression() {
  // A nested expression such as (sal + 1) * 2 is still reducible: every
  // input reference inside it (here only SAL) is rewritten to the matching
  // group key reference.
  final String query = "select sal, max((sal + 1) * 2) as max_expr\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysWithCastWider() {
// Test case where a cast is needed because the group key type differs
// from the aggregate result type. Cast to a wider type (BIGINT) is safe.
// The rule should preserve the cast.
String sql = "select cast(sal as bigint) as sal_big, max(sal) as sal_max\n"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This points to a possible problem: what if the cast is to a narrower type, e.g. TINYINT, which could lose information?

Copy link
Copy Markdown
Member Author

@xuzifu666 xuzifu666 May 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I've extended the test coverage to address this. Based on the results, no information is lost, which matches expectations.

+ "from emp group by sal";
sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionOfGroupByKeysWithCastNarrower() {
  // The select list narrows the group key SAL to SMALLINT, which may lose
  // information for large values; that is the query author's explicit
  // choice. The rule is expected to still fire and to keep the cast, leaving
  // any data loss to normal SQL semantics.
  final String query = "select cast(sal as smallint) as sal_small, max(sal) as sal_max\n"
      + "from emp group by sal";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).check();
}

@Test void testAggregateFunctionWithMixedColumnsNoOptimization() {
  // Negative case: SAL is a group key but COMM is not, so SAL + COMM is not
  // constant within a group and the plan must stay as it is.
  final String query = "select sal, max(sal + comm) as max_sum\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged();
}

@Test void testAggregateFunctionWithNonGroupByColumnNoOptimization() {
  // Negative case: the aggregate argument COMM is not in the GROUP BY set
  // at all, so the plan must stay as it is.
  final String query = "select sal, max(comm) as comm_max\n"
      + "from emp group by sal";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged();
}

@Test void testAggregateFunctionWithMixedGroupByColumnsNoOptimization() {
  // Negative case: SAL is a group key but EMPNO is not, so SAL + EMPNO may
  // vary within a group and the plan must stay as it is.
  final String query = "select sal, max(sal + empno) as max_sum\n"
      + "from emp group by sal, deptno";
  sql(query).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged();
}

/**
 * Runs once after all tests in this class and asks the fixture's
 * {@code diffRepos} to check the actual and reference files.
 */
@AfterAll static void checkActualAndReferenceFiles() {
fixture().diffRepos.checkActualAndReferenceFiles();
}
Expand Down
Loading
Loading