diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 10c5d2aa88..a464799f62 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -50,6 +50,7 @@ import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.ViewExpanders; import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; @@ -2769,11 +2770,36 @@ public RelNode visitFillNull(FillNull node, CalcitePlanContext context) { return context.relBuilder.peek(); } + /** Window {@code ORDER BY} keys from the current node's collation, or empty if it has none. */ + private static List deriveCollationOrderKeys(CalcitePlanContext context) { + RelBuilder relBuilder = context.relBuilder; + List collations = + relBuilder.getCluster().getMetadataQuery().collations(relBuilder.peek()); + if (collations == null || collations.isEmpty()) { + return List.of(); + } + List orderKeys = new ArrayList<>(); + for (RelFieldCollation fieldCollation : collations.get(0).getFieldCollations()) { + RexNode key = relBuilder.field(fieldCollation.getFieldIndex()); + if (fieldCollation.direction.isDescending()) { + key = relBuilder.desc(key); + } + if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST) { + key = relBuilder.nullsLast(key); + } else if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.FIRST) { + key = relBuilder.nullsFirst(key); + } + orderKeys.add(key); + } + return orderKeys; + } + @Override public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { // 1. resolve main plan visitChildren(node, context); - // 2. add row_number() column to main + // 2. add row_number() column to main, ordered by its collation so the zip is deterministic + List mainOrderKeys = deriveCollationOrderKeys(context); RexNode mainRowNumber = PlanUtils.makeOver( context, @@ -2781,7 +2807,7 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { null, List.of(), List.of(), - List.of(), + mainOrderKeys, WindowFrame.toCurrentRow()); context.relBuilder.projectPlus( context.relBuilder.alias(mainRowNumber, ROW_NUMBER_COLUMN_FOR_MAIN)); @@ -2791,7 +2817,8 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { transformPlanToAttachChild(node.getSubSearch(), relation); // 4. resolve subsearch plan node.getSubSearch().accept(this, context); - // 5. add row_number() column to subsearch + // 5. add row_number() column to subsearch, ordered by its collation + List subsearchOrderKeys = deriveCollationOrderKeys(context); RexNode subsearchRowNumber = PlanUtils.makeOver( context, @@ -2799,7 +2826,7 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { null, List.of(), List.of(), - List.of(), + subsearchOrderKeys, WindowFrame.toCurrentRow()); context.relBuilder.projectPlus( context.relBuilder.alias(subsearchRowNumber, ROW_NUMBER_COLUMN_FOR_SUBSEARCH)); @@ -2821,6 +2848,11 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { context.relBuilder.join( JoinAndLookupUtils.translateJoinType(Join.JoinType.FULL), joinCondition); + // sort by the row numbers (nulls last) so the output order is stable across backends + context.relBuilder.sort( + context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_MAIN)), + context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_SUBSEARCH))); + if (!node.isOverride()) { // 8. if override = false, drop both _row_number_ columns context.relBuilder.projectExcept( diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java index 484dda37b7..e9aaa46c6b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java @@ -22,13 +22,16 @@ public void testAppendcol() { String expectedLogical = "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7])\n" - + " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalSort(sort0=[$8], sort1=[$9], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0" + + " NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER (ORDER BY $0 NULLS" + + " LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -36,12 +39,15 @@ public void testAppendcol() { "SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`," + " `t`.`COMM`, `t`.`DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " ROW_NUMBER() OVER () `_row_number_main_`\n" + + " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "FULL JOIN (SELECT ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + "FULL JOIN (SELECT ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` =" - + " `t1`.`_row_number_subsearch_`"; + + " `t1`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -54,16 +60,18 @@ public void testAppendcol2() { String expectedLogical = "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7], left_col=[$8], right_col=[$10])\n" - + " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " LogicalSort(sort0=[$9], sort1=[$11], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], left_col=[$7], _row_number_main_=[ROW_NUMBER()" - + " OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " OVER (ORDER BY $0 NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER" + + " (ORDER BY $0 NULLS LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], right_col=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -71,14 +79,18 @@ public void testAppendcol2() { "SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`," + " `t`.`COMM`, `t`.`DEPTNO`, `t`.`left_col`, `t2`.`right_col`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " `DEPTNO` `left_col`, ROW_NUMBER() OVER () `_row_number_main_`\n" + + " `DEPTNO` `left_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " `DEPTNO` `right_col`\n" + "FROM `scott`.`EMP`) `t0`\n" + "WHERE `DEPTNO` = 20) `t2` ON `t`.`_row_number_main_` =" - + " `t2`.`_row_number_subsearch_`"; + + " `t2`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t2`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -91,14 +103,17 @@ public void testAppendcolOverride() { + " JOB=[CASE(=($8, $17), $11, $2)], MGR=[CASE(=($8, $17), $12, $3)]," + " HIREDATE=[CASE(=($8, $17), $13, $4)], SAL=[CASE(=($8, $17), $14, $5)]," + " COMM=[CASE(=($8, $17), $15, $6)], DEPTNO=[CASE(=($8, $17), $16, $7)])\n" - + " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalSort(sort0=[$8], sort1=[$17], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0" + + " NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER (ORDER" + + " BY $0 NULLS LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -116,13 +131,16 @@ public void testAppendcolOverride() { + " `t`.`COMM` END `COMM`, CASE WHEN `t`.`_row_number_main_` =" + " `t1`.`_row_number_subsearch_` THEN `t1`.`DEPTNO` ELSE `t`.`DEPTNO` END `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " ROW_NUMBER() OVER () `_row_number_main_`\n" + + " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" + "FULL JOIN (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`," - + " `DEPTNO`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + " `DEPTNO`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` =" - + " `t1`.`_row_number_subsearch_`"; + + " `t1`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -132,16 +150,17 @@ public void testAppendcolStats() { RelNode root = getRelNode(ppl); String expectedLogical = "LogicalProject(count()=[$0], DEPTNO=[$1], avg(SAL)=[$3])\n" - + " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n" - + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + + " LogicalSort(sort0=[$2], sort1=[$4], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n" + + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + " ()])\n" - + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" + + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = "" @@ -159,7 +178,10 @@ public void testAppendcolStats() { + "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, ROW_NUMBER() OVER ()" + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`"; + + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` =" + + " `t4`.`_row_number_subsearch_`\n" + + "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -171,17 +193,18 @@ public void testAppendcolStatsOverride() { RelNode root = getRelNode(ppl); String expectedLogical = "LogicalProject(count()=[$0], DEPTNO=[CASE(=($2, $5), $4, $1)], avg(SAL)=[$3])\n" - + " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n" - + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + + " LogicalSort(sort0=[$2], sort1=[$5], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n" + + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + " ()])\n" - + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0], _row_number_subsearch_=[ROW_NUMBER()" - + " OVER ()])\n" - + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0]," + + " _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" + + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = "" @@ -200,7 +223,10 @@ public void testAppendcolStatsOverride() { + "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, `DEPTNO`, ROW_NUMBER() OVER ()" + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`"; + + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` =" + + " `t4`.`_row_number_subsearch_`\n" + + "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } }