From 41ef5f7de1ac6471b3a2d5a8d1d4bbed85fe13d9 Mon Sep 17 00:00:00 2001 From: Kai Huang Date: Wed, 27 May 2026 13:34:15 -0700 Subject: [PATCH] fix(core): make appendcol row ordering deterministic on parallel engines appendcol lowers to a FULL JOIN of two ROW_NUMBER() OVER () windows (empty PARTITION BY / ORDER BY) on _row_number_main_ = _row_number_subsearch_, with no trailing sort. That positional zip is only correct on a serial, order-preserving executor: a bare ROW_NUMBER() OVER () assigns sequence numbers in input order and the join preserves it. On a parallel/distributed backend the row-number assignment is arbitrary and the hash join drops ordering, so columns get zipped onto the wrong rows and downstream `head` slices a non-deterministic subset. Fix visitAppendCol to not depend on implicit input-order preservation: - derive an explicit window ORDER BY from each child's collation (deriveCollationOrderKeys), so ROW_NUMBER assignment follows the upstream sort; falls back to the prior bare OVER () when the input has no collation (positional correspondence is undefined without a sort). - add a trailing sort by the row-number columns after the join (NULLS LAST, same pattern as streamstats) so output order is deterministic regardless of how the backend executes the join. No behavior change on the serial v2/Calcite engine; makes the lowering correct on parallel backends. Updates CalcitePPLAppendcolTest expected plans/SparkSQL. Signed-off-by: Kai Huang --- .../sql/calcite/CalciteRelNodeVisitor.java | 40 +++++- .../ppl/calcite/CalcitePPLAppendcolTest.java | 132 +++++++++++------- 2 files changed, 115 insertions(+), 57 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 10c5d2aa888..a464799f620 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -50,6 +50,7 @@ import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.ViewExpanders; import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; @@ -2769,11 +2770,36 @@ public RelNode visitFillNull(FillNull node, CalcitePlanContext context) { return context.relBuilder.peek(); } + /** Window {@code ORDER BY} keys from the current node's collation, or empty if it has none. */ + private static List deriveCollationOrderKeys(CalcitePlanContext context) { + RelBuilder relBuilder = context.relBuilder; + List collations = + relBuilder.getCluster().getMetadataQuery().collations(relBuilder.peek()); + if (collations == null || collations.isEmpty()) { + return List.of(); + } + List orderKeys = new ArrayList<>(); + for (RelFieldCollation fieldCollation : collations.get(0).getFieldCollations()) { + RexNode key = relBuilder.field(fieldCollation.getFieldIndex()); + if (fieldCollation.direction.isDescending()) { + key = relBuilder.desc(key); + } + if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST) { + key = relBuilder.nullsLast(key); + } else if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.FIRST) { + key = relBuilder.nullsFirst(key); + } + orderKeys.add(key); + } + return orderKeys; + } + @Override public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { // 1. resolve main plan visitChildren(node, context); - // 2. add row_number() column to main + // 2. add row_number() column to main, ordered by its collation so the zip is deterministic + List mainOrderKeys = deriveCollationOrderKeys(context); RexNode mainRowNumber = PlanUtils.makeOver( context, @@ -2781,7 +2807,7 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { null, List.of(), List.of(), - List.of(), + mainOrderKeys, WindowFrame.toCurrentRow()); context.relBuilder.projectPlus( context.relBuilder.alias(mainRowNumber, ROW_NUMBER_COLUMN_FOR_MAIN)); @@ -2791,7 +2817,8 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { transformPlanToAttachChild(node.getSubSearch(), relation); // 4. resolve subsearch plan node.getSubSearch().accept(this, context); - // 5. add row_number() column to subsearch + // 5. add row_number() column to subsearch, ordered by its collation + List subsearchOrderKeys = deriveCollationOrderKeys(context); RexNode subsearchRowNumber = PlanUtils.makeOver( context, @@ -2799,7 +2826,7 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { null, List.of(), List.of(), - List.of(), + subsearchOrderKeys, WindowFrame.toCurrentRow()); context.relBuilder.projectPlus( context.relBuilder.alias(subsearchRowNumber, ROW_NUMBER_COLUMN_FOR_SUBSEARCH)); @@ -2821,6 +2848,11 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) { context.relBuilder.join( JoinAndLookupUtils.translateJoinType(Join.JoinType.FULL), joinCondition); + // sort by the row numbers (nulls last) so the output order is stable across backends + context.relBuilder.sort( + context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_MAIN)), + context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_SUBSEARCH))); + if (!node.isOverride()) { // 8. if override = false, drop both _row_number_ columns context.relBuilder.projectExcept( diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java index 484dda37b7d..e9aaa46c6b6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAppendcolTest.java @@ -22,13 +22,16 @@ public void testAppendcol() { String expectedLogical = "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7])\n" - + " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalSort(sort0=[$8], sort1=[$9], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0" + + " NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER (ORDER BY $0 NULLS" + + " LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -36,12 +39,15 @@ public void testAppendcol() { "SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`," + " `t`.`COMM`, `t`.`DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " ROW_NUMBER() OVER () `_row_number_main_`\n" + + " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "FULL JOIN (SELECT ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + "FULL JOIN (SELECT ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` =" - + " `t1`.`_row_number_subsearch_`"; + + " `t1`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -54,16 +60,18 @@ public void testAppendcol2() { String expectedLogical = "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + " COMM=[$6], DEPTNO=[$7], left_col=[$8], right_col=[$10])\n" - + " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " LogicalSort(sort0=[$9], sort1=[$11], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], left_col=[$7], _row_number_main_=[ROW_NUMBER()" - + " OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " OVER (ORDER BY $0 NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER" + + " (ORDER BY $0 NULLS LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + " SAL=[$5], COMM=[$6], DEPTNO=[$7], right_col=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -71,14 +79,18 @@ public void testAppendcol2() { "SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`," + " `t`.`COMM`, `t`.`DEPTNO`, `t`.`left_col`, `t2`.`right_col`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " `DEPTNO` `left_col`, ROW_NUMBER() OVER () `_row_number_main_`\n" + + " `DEPTNO` `left_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" - + "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + " `DEPTNO` `right_col`\n" + "FROM `scott`.`EMP`) `t0`\n" + "WHERE `DEPTNO` = 20) `t2` ON `t`.`_row_number_main_` =" - + " `t2`.`_row_number_subsearch_`"; + + " `t2`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t2`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -91,14 +103,17 @@ public void testAppendcolOverride() { + " JOB=[CASE(=($8, $17), $11, $2)], MGR=[CASE(=($8, $17), $12, $3)]," + " HIREDATE=[CASE(=($8, $17), $13, $4)], SAL=[CASE(=($8, $17), $14, $5)]," + " COMM=[CASE(=($8, $17), $15, $6)], DEPTNO=[CASE(=($8, $17), $16, $7)])\n" - + " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," - + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalFilter(condition=[=($7, 20)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalSort(sort0=[$8], sort1=[$17], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0" + + " NULLS LAST)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER (ORDER" + + " BY $0 NULLS LAST)])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); verifyResultCount(root, 14); @@ -116,13 +131,16 @@ public void testAppendcolOverride() { + " `t`.`COMM` END `COMM`, CASE WHEN `t`.`_row_number_main_` =" + " `t1`.`_row_number_subsearch_` THEN `t1`.`DEPTNO` ELSE `t`.`DEPTNO` END `DEPTNO`\n" + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," - + " ROW_NUMBER() OVER () `_row_number_main_`\n" + + " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n" + "FROM `scott`.`EMP`) `t`\n" + "FULL JOIN (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`," - + " `DEPTNO`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n" + + " `DEPTNO`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)" + + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" + "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` =" - + " `t1`.`_row_number_subsearch_`"; + + " `t1`.`_row_number_subsearch_`\n" + + "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -132,16 +150,17 @@ public void testAppendcolStats() { RelNode root = getRelNode(ppl); String expectedLogical = "LogicalProject(count()=[$0], DEPTNO=[$1], avg(SAL)=[$3])\n" - + " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n" - + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + + " LogicalSort(sort0=[$2], sort1=[$4], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n" + + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + " ()])\n" - + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" - + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" + + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = "" @@ -159,7 +178,10 @@ public void testAppendcolStats() { + "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, ROW_NUMBER() OVER ()" + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`"; + + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` =" + + " `t4`.`_row_number_subsearch_`\n" + + "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } @@ -171,17 +193,18 @@ public void testAppendcolStatsOverride() { RelNode root = getRelNode(ppl); String expectedLogical = "LogicalProject(count()=[$0], DEPTNO=[CASE(=($2, $5), $4, $1)], avg(SAL)=[$3])\n" - + " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n" - + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + + " LogicalSort(sort0=[$2], sort1=[$5], dir0=[ASC], dir1=[ASC])\n" + + " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n" + + " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER" + " ()])\n" - + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" - + " LogicalProject(DEPTNO=[$7])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n" - + " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0], _row_number_subsearch_=[ROW_NUMBER()" - + " OVER ()])\n" - + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" - + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + " LogicalAggregate(group=[{0}], count()=[COUNT()])\n" + + " LogicalProject(DEPTNO=[$7])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0]," + + " _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n" + + " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n" + + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; verifyLogical(root, expectedLogical); String expectedResult = "" @@ -200,7 +223,10 @@ public void testAppendcolStatsOverride() { + "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, `DEPTNO`, ROW_NUMBER() OVER ()" + " `_row_number_subsearch_`\n" + "FROM `scott`.`EMP`\n" - + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`"; + + "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` =" + + " `t4`.`_row_number_subsearch_`\n" + + "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS" + + " LAST"; verifyPPLToSparkSQL(root, expectedSparkSql); } }