Skip to content

Commit be816c1

Browse files
committed
[CALCITE-7439] Qualify GROUP BY keys for DISTINCT over joins
1 parent c628e68 commit be816c1

2 files changed

Lines changed: 69 additions & 2 deletions

File tree

core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,12 +858,21 @@ private List<SqlNode> generateGroupList(Builder builder,
858858
+ aggregate.getGroupSet() + ", just possibly a different order";
859859

860860
final List<SqlNode> groupKeys = new ArrayList<>();
861+
final Join aggregateJoinInput =
862+
aggregate.getInput() instanceof Join ? (Join) aggregate.getInput() : null;
863+
final SqlJoin fromJoin =
864+
builder.select.getFrom() instanceof SqlJoin ? (SqlJoin) builder.select.getFrom() : null;
865+
final int leftFieldCount = aggregateJoinInput == null
866+
? -1
867+
: aggregateJoinInput.getLeft().getRowType().getFieldCount();
861868
for (int key : groupList) {
862-
final SqlNode field = builder.context.field(key);
869+
SqlNode field = builder.context.field(key);
870+
field = maybeQualifyJoinKey(field, key, fromJoin, leftFieldCount);
863871
groupKeys.add(field);
864872
}
865873
for (int key : sortedGroupList) {
866-
final SqlNode field = builder.context.field(key);
874+
SqlNode field =
875+
maybeQualifyJoinKey(builder.context.field(key), key, fromJoin, leftFieldCount);
867876
addSelect(selectList, field, aggregate.getRowType());
868877
}
869878
switch (aggregate.getGroupType()) {
@@ -905,6 +914,31 @@ private List<SqlNode> generateGroupList(Builder builder,
905914
}
906915
}
907916

917+
private SqlNode maybeQualifyJoinKey(SqlNode field, int key,
918+
@Nullable SqlJoin fromJoin, int leftFieldCount) {
919+
if (!(field instanceof SqlIdentifier)
920+
|| ((SqlIdentifier) field).names.size() != 1
921+
|| fromJoin == null) {
922+
return field;
923+
}
924+
925+
final SqlNode side;
926+
if (leftFieldCount < 0) {
927+
if (key != 0) {
928+
return field;
929+
}
930+
side = fromJoin.getLeft();
931+
} else {
932+
side = key < leftFieldCount ? fromJoin.getLeft() : fromJoin.getRight();
933+
}
934+
final String sideAlias = SqlValidatorUtil.alias(side);
935+
if (sideAlias == null) {
936+
return field;
937+
}
938+
939+
return new SqlIdentifier(ImmutableList.of(sideAlias, ((SqlIdentifier) field).getSimple()), POS);
940+
}
941+
908942
private static SqlNode groupItem(List<SqlNode> groupKeys,
909943
ImmutableBitSet groupSet, ImmutableBitSet wholeGroupSet) {
910944
final List<SqlNode> nodes = groupSet.asList().stream()

core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
import static org.hamcrest.CoreMatchers.is;
128128
import static org.hamcrest.CoreMatchers.notNullValue;
129129
import static org.hamcrest.MatcherAssert.assertThat;
130+
import static org.hamcrest.Matchers.containsString;
130131
import static org.hamcrest.Matchers.hasToString;
131132
import static org.junit.jupiter.api.Assertions.assertFalse;
132133
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -11887,6 +11888,38 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) {
1188711888
sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected);
1188811889
}
1188911890

11891+
/** Test case for
11892+
* <a href="https://issues.apache.org/jira/browse/CALCITE-7439">[CALCITE-7439]
11893+
* RelToSqlConverter emits ambiguous GROUP BY after LEFT JOIN USING with
11894+
* semi-join rewrite.</a>. */
11895+
@Test void testPostgresqlRoundTripDistinctLeftJoinInSubqueryWithSemiJoinRules() {
11896+
final String query = "WITH product_keys AS (\n"
11897+
+ " SELECT p.\"product_id\",\n"
11898+
+ " (SELECT MAX(p3.\"product_id\")\n"
11899+
+ " FROM \"foodmart\".\"product\" p3\n"
11900+
+ " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n"
11901+
+ " FROM \"foodmart\".\"product\" p\n"
11902+
+ ")\n"
11903+
+ "SELECT DISTINCT pk.\"product_id\"\n"
11904+
+ "FROM product_keys pk\n"
11905+
+ "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n"
11906+
+ "WHERE pk.\"product_id\" IN (\n"
11907+
+ " SELECT p4.\"product_id\"\n"
11908+
+ " FROM \"foodmart\".\"product\" p4\n"
11909+
+ ")";
11910+
11911+
final RuleSet rules =
11912+
RuleSets.ofList(CoreRules.PROJECT_SUB_QUERY_TO_CORRELATE,
11913+
CoreRules.FILTER_SUB_QUERY_TO_CORRELATE,
11914+
CoreRules.JOIN_SUB_QUERY_TO_CORRELATE,
11915+
CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE,
11916+
CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE,
11917+
CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE,
11918+
CoreRules.PROJECT_TO_SEMI_JOIN);
11919+
11920+
final String generated = sql(query).withPostgresql().optimize(rules, null).exec();
11921+
assertThat(generated, containsString("GROUP BY \"t2\".\"product_id\""));
11922+
}
1189011923
@Test void testNotBetween() {
1189111924
Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() {
1189211925
@Override public @Nullable SqlRexConvertlet get(SqlCall call) {

0 commit comments

Comments
 (0)