diff --git a/pinot-common/src/main/proto/plan.proto b/pinot-common/src/main/proto/plan.proto index 352cd92f4c19..68edeb46911f 100644 --- a/pinot-common/src/main/proto/plan.proto +++ b/pinot-common/src/main/proto/plan.proto @@ -243,6 +243,15 @@ enum WindowFrameType { RANGE = 1; } +// SQL standard `<window frame exclusion>` clause. Names are wire-stable; renumbering breaks mixed-version brokers/servers. +// `EXCLUDE_NO_OTHERS = 0` so that nodes serialized before this field existed deserialize to the default behavior. +enum WindowExclusion { + EXCLUDE_NO_OTHERS = 0; + EXCLUDE_CURRENT_ROW = 1; + EXCLUDE_GROUP = 2; + EXCLUDE_TIES = 3; +} + message WindowNode { repeated int32 keys = 1; repeated Collation collations = 2; @@ -251,6 +260,7 @@ message WindowNode { int32 lowerBound = 5; int32 upperBound = 6; repeated Literal constants = 7; + WindowExclusion exclude = 8; } // A node that doesn't carry semantic information. diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java index c60027778a4e..1b4821fb8aeb 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/explain/PlanNodeMerger.java @@ -435,6 +435,9 @@ public PlanNode visitWindow(WindowNode node, PlanNode context) { if (node.getUpperBound() != otherNode.getUpperBound()) { return null; } + if (node.getExclude() != otherNode.getExclude()) { + return null; + } if (!node.getConstants().equals(otherNode.getConstants())) { return null; } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java index eb8899dc8fba..3290b6e22502 100644 --- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/PlanNodeToRelConverter.java @@ -377,7 +377,7 @@ public Void visitWindow(WindowNode node, Void context) { Window.Group group = new Window.Group(keys, isRow, getWindowBound(node.getLowerBound()), getWindowBound(node.getUpperBound()), - RexWindowExclusion.EXCLUDE_NO_OTHER, orderKeys, aggCalls); + toRexWindowExclusion(node.getExclude()), orderKeys, aggCalls); List constants = node.getConstants().stream().map(constant -> RexExpressionUtils.toRexLiteral(_builder, constant)) @@ -395,6 +395,21 @@ public Void visitWindow(WindowNode node, Void context) { return null; } + private static RexWindowExclusion toRexWindowExclusion(WindowNode.WindowExclusion exclude) { + switch (exclude) { + case NO_OTHERS: + return RexWindowExclusion.EXCLUDE_NO_OTHER; + case CURRENT_ROW: + return RexWindowExclusion.EXCLUDE_CURRENT_ROW; + case GROUP: + return RexWindowExclusion.EXCLUDE_GROUP; + case TIES: + return RexWindowExclusion.EXCLUDE_TIES; + default: + throw new IllegalStateException("Unsupported WindowExclusion: " + exclude); + } + } + private RexWindowBound getWindowBound(int bound) { if (bound == Integer.MIN_VALUE) { return RexWindowBounds.UNBOUNDED_PRECEDING; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java index aaeda2328716..be8192e344b7 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java @@ -629,18 +629,12 @@ private ValueNode convertLogicalValues(LogicalValues node) { convertInputs(node.getInputs()), literalRows); } - /** - * TODO: Add support for exclude clauses ({@link 
org.apache.calcite.rex.RexWindowExclusion}) - */ private WindowNode convertLogicalWindow(LogicalWindow node) { // Only a single Window Group should exist per WindowNode. Preconditions.checkState(node.groups.size() == 1, "Only a single window group is allowed, got: %s", node.groups.size()); Window.Group windowGroup = node.groups.get(0); - Preconditions.checkState(windowGroup.exclude == RexWindowExclusion.EXCLUDE_NO_OTHER, - "EXCLUDE clauses for window functions are not currently supported"); - int numAggregates = windowGroup.aggCalls.size(); List aggCalls = new ArrayList<>(numAggregates); for (int i = 0; i < numAggregates; i++) { @@ -684,7 +678,18 @@ private WindowNode convertLogicalWindow(LogicalWindow node) { } return new WindowNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()), convertInputs(node.getInputs()), windowGroup.keys.asList(), windowGroup.orderKeys.getFieldCollations(), - aggCalls, windowFrameType, lowerBound, upperBound, constants); + aggCalls, windowFrameType, lowerBound, upperBound, fromRexWindowExclusion(windowGroup.exclude), constants); + } + + public static WindowNode.WindowExclusion fromRexWindowExclusion(RexWindowExclusion exclude) { + if (exclude == RexWindowExclusion.EXCLUDE_CURRENT_ROW) { + return WindowNode.WindowExclusion.CURRENT_ROW; + } else if (exclude == RexWindowExclusion.EXCLUDE_GROUP) { + return WindowNode.WindowExclusion.GROUP; + } else if (exclude == RexWindowExclusion.EXCLUDE_TIES) { + return WindowNode.WindowExclusion.TIES; + } + return WindowNode.WindowExclusion.NO_OTHERS; } private SortNode convertLogicalSort(LogicalSort node) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelToPlanNodeConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelToPlanNodeConverter.java index 7217d4676ad5..1de8193c252f 100644 --- 
a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelToPlanNodeConverter.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/v2/PRelToPlanNodeConverter.java @@ -47,6 +47,7 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.DatabaseUtils; import org.apache.pinot.common.utils.request.RequestUtils; +import org.apache.pinot.query.planner.logical.RelToPlanNodeConverter; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.logical.RexExpressionUtils; import org.apache.pinot.query.planner.physical.v2.nodes.PhysicalAggregate; @@ -188,7 +189,8 @@ public static WindowNode convertWindow(Window node) { } return new WindowNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()), NodeHint.fromRelHints(node.getHints()), new ArrayList<>(), windowGroup.keys.asList(), windowGroup.orderKeys.getFieldCollations(), - aggCalls, windowFrameType, lowerBound, upperBound, constants); + aggCalls, windowFrameType, lowerBound, upperBound, + RelToPlanNodeConverter.fromRexWindowExclusion(windowGroup.exclude), constants); } public static SortNode convertSort(Sort node) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java index 6263c277adfd..bd91330b67e0 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/WindowNode.java @@ -35,11 +35,12 @@ public class WindowNode extends BasePlanNode { // Integer.MAX_VALUE represents UNBOUNDED FOLLOWING which is only allowed for the upper bound (ensured by Calcite). 
private final int _lowerBound; private final int _upperBound; + private final WindowExclusion _exclude; private final List _constants; public WindowNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, List inputs, List keys, List collations, List aggCalls, WindowFrameType windowFrameType, - int lowerBound, int upperBound, List constants) { + int lowerBound, int upperBound, WindowExclusion exclude, List constants) { super(stageId, dataSchema, nodeHint, inputs); _keys = keys; _collations = collations; @@ -47,6 +48,7 @@ public WindowNode(int stageId, DataSchema dataSchema, NodeHint nodeHint, List getConstants() { return _constants; } @@ -91,7 +97,7 @@ public T visit(PlanNodeVisitor visitor, C context) { @Override public PlanNode withInputs(List inputs) { return new WindowNode(_stageId, _dataSchema, _nodeHint, inputs, _keys, _collations, _aggCalls, _windowFrameType, - _lowerBound, _upperBound, _constants); + _lowerBound, _upperBound, _exclude, _constants); } @Override @@ -108,13 +114,14 @@ public boolean equals(Object o) { WindowNode that = (WindowNode) o; return _lowerBound == that._lowerBound && _upperBound == that._upperBound && Objects.equals(_aggCalls, that._aggCalls) && Objects.equals(_keys, that._keys) && Objects.equals(_collations, that._collations) - && _windowFrameType == that._windowFrameType && Objects.equals(_constants, that._constants); + && _windowFrameType == that._windowFrameType && _exclude == that._exclude + && Objects.equals(_constants, that._constants); } @Override public int hashCode() { return Objects.hash(super.hashCode(), _aggCalls, _keys, _collations, _windowFrameType, _lowerBound, _upperBound, - _constants); + _exclude, _constants); } /** @@ -125,4 +132,18 @@ public int hashCode() { public enum WindowFrameType { ROWS, RANGE } + + /** + * Enum to denote the frame exclusion option (SQL standard {@code EXCLUDE} clause). + * {@link #NO_OTHERS} is the default and means no rows are excluded. 
+ * {@link #CURRENT_ROW} excludes only the current row from the frame. + * {@link #GROUP} excludes the current row and all its ordering peers. + * {@link #TIES} excludes the ordering peers of the current row but keeps the current row. + * + * <p>
The constant names are part of the wire protocol via {@code Plan.WindowExclusion} and must remain stable across + * mixed-version brokers and servers. + */ + public enum WindowExclusion { + NO_OTHERS, CURRENT_ROW, GROUP, TIES + } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java index 7f2144124d36..3909c5cf2b22 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeDeserializer.java @@ -213,7 +213,8 @@ private static WindowNode deserializeWindowNode(Plan.PlanNode protoNode) { extractInputs(protoNode), protoWindowNode.getKeysList(), convertCollations(protoWindowNode.getCollationsList()), convertFunctionCalls(protoWindowNode.getAggCallsList()), convertWindowFrameType(protoWindowNode.getWindowFrameType()), protoWindowNode.getLowerBound(), - protoWindowNode.getUpperBound(), convertLiterals(protoWindowNode.getConstantsList())); + protoWindowNode.getUpperBound(), convertWindowExclusion(protoWindowNode.getExclude()), + convertLiterals(protoWindowNode.getConstantsList())); } private static ExplainedNode deserializeExplainedNode(Plan.PlanNode protoNode) { @@ -496,4 +497,19 @@ private static WindowNode.WindowFrameType convertWindowFrameType(Plan.WindowFram throw new IllegalStateException("Unsupported WindowFrameType: " + windowFrameType); } } + + private static WindowNode.WindowExclusion convertWindowExclusion(Plan.WindowExclusion exclude) { + switch (exclude) { + case EXCLUDE_NO_OTHERS: + return WindowNode.WindowExclusion.NO_OTHERS; + case EXCLUDE_CURRENT_ROW: + return WindowNode.WindowExclusion.CURRENT_ROW; + case EXCLUDE_GROUP: + return WindowNode.WindowExclusion.GROUP; + case EXCLUDE_TIES: + return WindowNode.WindowExclusion.TIES; + default: + throw new 
IllegalStateException("Unsupported WindowExclusion: " + exclude); + } + } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java index 99665d5c29bc..01dc25704e31 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/PlanNodeSerializer.java @@ -261,6 +261,7 @@ public Void visitWindow(WindowNode node, Plan.PlanNode.Builder builder) { .setWindowFrameType(convertWindowFrameType(node.getWindowFrameType())) .setLowerBound(node.getLowerBound()) .setUpperBound(node.getUpperBound()) + .setExclude(convertWindowExclusion(node.getExclude())) .addAllConstants(convertLiterals(node.getConstants())) .build(); builder.setWindowNode(windowNode); @@ -477,5 +478,20 @@ private static Plan.WindowFrameType convertWindowFrameType(WindowNode.WindowFram throw new IllegalStateException("Unsupported WindowFrameType: " + windowFrameType); } } + + private static Plan.WindowExclusion convertWindowExclusion(WindowNode.WindowExclusion exclude) { + switch (exclude) { + case NO_OTHERS: + return Plan.WindowExclusion.EXCLUDE_NO_OTHERS; + case CURRENT_ROW: + return Plan.WindowExclusion.EXCLUDE_CURRENT_ROW; + case GROUP: + return Plan.WindowExclusion.EXCLUDE_GROUP; + case TIES: + return Plan.WindowExclusion.EXCLUDE_TIES; + default: + throw new IllegalStateException("Unsupported WindowExclusion: " + exclude); + } + } } } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java index 842bbf97c150..39decfabfd45 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java @@ -624,12 
+624,6 @@ public void testWindowFunctions() { + "CURRENT ROW) FROM a"; e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(ntileQueryWithNoArg)); assertTrue(e.getMessage().contains("expecting 1 argument")); - - String excludeCurrentRowQuery = - "SELECT col1, col2, SUM(col3) OVER (PARTITION BY col1 ORDER BY col2 ROWS BETWEEN UNBOUNDED PRECEDING AND " - + "CURRENT ROW EXCLUDE CURRENT ROW) FROM a"; - e = expectThrows(RuntimeException.class, () -> _queryEnvironment.planQuery(excludeCurrentRowQuery)); - assertTrue(e.getMessage().contains("EXCLUDE clauses for window functions are not currently supported")); } @Test diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java index f7583d1b4262..a923f62f17e3 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java @@ -123,7 +123,8 @@ public WindowAggregateOperator(OpChainExecutionContext context, MultiStageOperat for (int i = 0; i < numKeys; i++) { _keys[i] = keys.get(i); } - WindowFrame windowFrame = new WindowFrame(node.getWindowFrameType(), node.getLowerBound(), node.getUpperBound()); + WindowFrame windowFrame = + new WindowFrame(node.getWindowFrameType(), node.getLowerBound(), node.getUpperBound(), node.getExclude()); Preconditions.checkState( windowFrame.isRowType() || ((windowFrame.isUnboundedPreceding() || windowFrame.isLowerBoundCurrentRow()) && ( windowFrame.isUnboundedFollowing() || windowFrame.isUpperBoundCurrentRow())), diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFrame.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFrame.java index f1d8a7a902f3..2a110defc69b 100644 
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFrame.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFrame.java @@ -23,7 +23,8 @@ /** * Defines the window frame to be used for a window function. The 'lowerBound' and 'upperBound' indicate the frame - * boundaries to be used. The frame can be of two types: ROWS or RANGE. + * boundaries to be used. The frame can be of two types: ROWS or RANGE. Optionally an {@code EXCLUDE} clause may + * specify a subset of rows around the current row to be excluded from the frame. */ public class WindowFrame { // Enum to denote the FRAME type, can be either ROWS or RANGE types @@ -33,11 +34,14 @@ public class WindowFrame { // Integer.MAX_VALUE represents UNBOUNDED FOLLOWING which is only allowed for the upper bound (ensured by Calcite). private final int _lowerBound; private final int _upperBound; + private final WindowNode.WindowExclusion _exclude; - public WindowFrame(WindowNode.WindowFrameType type, int lowerBound, int upperBound) { + public WindowFrame(WindowNode.WindowFrameType type, int lowerBound, int upperBound, + WindowNode.WindowExclusion exclude) { _type = type; _lowerBound = lowerBound; _upperBound = upperBound; + _exclude = exclude; } public boolean isUnboundedPreceding() { @@ -68,8 +72,17 @@ public int getUpperBound() { return _upperBound; } + public WindowNode.WindowExclusion getExclude() { + return _exclude; + } + + public boolean isExcludeNoOthers() { + return _exclude == WindowNode.WindowExclusion.NO_OTHERS; + } + @Override public String toString() { - return "WindowFrame{" + "type=" + _type + ", lowerBound=" + _lowerBound + ", upperBound=" + _upperBound + '}'; + return "WindowFrame{type=" + _type + ", lowerBound=" + _lowerBound + ", upperBound=" + _upperBound + ", exclude=" + + _exclude + '}'; } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java index 19aabcfe9dd4..032189c89c9d 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/WindowFunction.java @@ -19,9 +19,11 @@ package org.apache.pinot.query.runtime.operator.window; import java.util.List; +import java.util.Objects; import org.apache.calcite.rel.RelFieldCollation; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.planner.plannode.WindowNode; import org.apache.pinot.query.runtime.operator.WindowAggregateOperator; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; @@ -66,4 +68,133 @@ public WindowFunction(RexExpression.FunctionCall aggCall, DataSchema inputSchema protected Object extractValueFromRow(Object[] row) { return _inputRef == -1 ? _literal : (row == null ? null : row[_inputRef]); } + + /** + * Returns whether peer-group information is required to apply the given EXCLUDE clause. {@code CURRENT_ROW} only + * touches the current row, so callers can skip the O(n) peer-boundary computation when the frame doesn't otherwise + * depend on peer bounds (i.e. ROWS frames). + */ + protected static boolean needsPeerBoundaries(WindowNode.WindowExclusion exclude) { + return exclude == WindowNode.WindowExclusion.GROUP || exclude == WindowNode.WindowExclusion.TIES; + } + + /** + * Fills {@code peerStart} and {@code peerEnd} with the inclusive bounds of each row's peer group based on the + * ORDER BY keys. Rows are peers iff they share the same ORDER BY values. With no ORDER BY clause every row is a peer + * of every other row in the partition. 
+ */ + protected void computePeerBoundaries(List rows, int[] peerStart, int[] peerEnd) { + int numRows = rows.size(); + if (_orderKeys.length == 0) { + for (int i = 0; i < numRows; i++) { + peerStart[i] = 0; + peerEnd[i] = numRows - 1; + } + return; + } + int groupStart = 0; + for (int i = 1; i <= numRows; i++) { + if (i == numRows || !samePeerKey(rows.get(i - 1), rows.get(i))) { + int peerLast = i - 1; + for (int j = groupStart; j < i; j++) { + peerStart[j] = groupStart; + peerEnd[j] = peerLast; + } + groupStart = i; + } + } + } + + private boolean samePeerKey(Object[] a, Object[] b) { + for (int k : _orderKeys) { + if (!Objects.equals(a[k], b[k])) { + return false; + } + } + return true; + } + + /** + * Returns the first index in {@code [fs, fe]} that is not excluded by the SQL EXCLUDE clause for the current row + * {@code i} whose peer group is {@code [pStart, pEnd]}. Returns {@code -1} if no such index exists. + */ + protected static int firstNonExcluded(int fs, int fe, int i, int pStart, int pEnd, + WindowNode.WindowExclusion exclude) { + if (fs > fe) { + return -1; + } + switch (exclude) { + case CURRENT_ROW: + return fs == i ? (fs + 1 > fe ? -1 : fs + 1) : fs; + case GROUP: + if (fs < pStart || fs > pEnd) { + return fs; + } + return pEnd + 1 > fe ? -1 : pEnd + 1; + case TIES: + if (fs == i || fs < pStart || fs > pEnd) { + return fs; + } + if (i >= fs && i <= fe) { + return i; + } + return pEnd + 1 > fe ? -1 : pEnd + 1; + default: + return fs; + } + } + + /** + * Returns the inclusive lower index of the base frame for the row at index {@code i} given its peer group + * {@code [pStart, pEnd]} and the total {@code numRows} in the partition. + */ + protected int frameStartForRow(int i, int pStart, int numRows) { + if (_windowFrame.isRowType()) { + int lb = _windowFrame.getLowerBound(); + return lb == Integer.MIN_VALUE ? 0 : Math.max(0, lb + i); + } + return _windowFrame.isUnboundedPreceding() ? 
0 : pStart; + } + + /** + * Returns the inclusive upper index of the base frame for the row at index {@code i} given its peer group + * {@code [pStart, pEnd]} and the total {@code numRows} in the partition. + */ + protected int frameEndForRow(int i, int pEnd, int numRows) { + if (_windowFrame.isRowType()) { + int ub = _windowFrame.getUpperBound(); + return ub == Integer.MAX_VALUE ? numRows - 1 : Math.min(numRows - 1, ub + i); + } + return _windowFrame.isUnboundedFollowing() ? numRows - 1 : pEnd; + } + + /** + * Returns the last index in {@code [fs, fe]} that is not excluded by the SQL EXCLUDE clause for the current row + * {@code i} whose peer group is {@code [pStart, pEnd]}. Returns {@code -1} if no such index exists. + */ + protected static int lastNonExcluded(int fs, int fe, int i, int pStart, int pEnd, + WindowNode.WindowExclusion exclude) { + if (fs > fe) { + return -1; + } + switch (exclude) { + case CURRENT_ROW: + return fe == i ? (fe - 1 < fs ? -1 : fe - 1) : fe; + case GROUP: + if (fe < pStart || fe > pEnd) { + return fe; + } + return pStart - 1 < fs ? -1 : pStart - 1; + case TIES: + if (fe == i || fe < pStart || fe > pEnd) { + return fe; + } + if (i >= fs && i <= fe) { + return i; + } + return pStart - 1 < fs ? 
-1 : pStart - 1; + default: + return fe; + } + } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java index 2bb7d6adffa6..b82be9c1be21 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/AggregateWindowFunction.java @@ -27,6 +27,7 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.data.table.Key; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.planner.plannode.WindowNode; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; import org.apache.pinot.query.runtime.operator.window.WindowFrame; import org.apache.pinot.query.runtime.operator.window.WindowFunction; @@ -40,18 +41,21 @@ public AggregateWindowFunction(RexExpression.FunctionCall aggCall, DataSchema in List collations, WindowFrame windowFrame) { super(aggCall, inputSchema, collations, windowFrame); _functionName = aggCall.getFunctionName(); + // Removal support is required for sliding ROWS frames and whenever an EXCLUDE clause forces per-row corrections. 
+ boolean nonDefaultExclude = !windowFrame.isExcludeNoOthers(); + boolean supportRemoval = nonDefaultExclude || (windowFrame.isRowType() && !( + windowFrame.isUnboundedPreceding() && windowFrame.isUnboundedFollowing())); _windowValueAggregator = WindowValueAggregatorFactory.getWindowValueAggregator(_functionName, _dataType, - windowFrame.isRowType() && !(_windowFrame.isUnboundedPreceding() && _windowFrame.isUnboundedFollowing())); + supportRemoval, nonDefaultExclude); } @Override public final List processRows(List rows) { _windowValueAggregator.clear(); - if (_windowFrame.isRowType()) { - return processRowsWindow(rows); - } else { - return processRangeWindow(rows); + if (_windowFrame.isExcludeNoOthers()) { + return _windowFrame.isRowType() ? processRowsWindow(rows) : processRangeWindow(rows); } + return _windowFrame.isRowType() ? processRowsWindowWithExclude(rows) : processRangeWindowWithExclude(rows); } /** @@ -172,4 +176,203 @@ private List processRangeWindow(List rows) { throw new IllegalStateException("RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); } } + + /** + * ROWS frame with a non-default EXCLUDE clause. Loads the base frame into the aggregator and removes / re-adds the + * excluded values per row. Peer-group boundaries are precomputed once per partition. 
+ */ + private List processRowsWindowWithExclude(List rows) { + int numRows = rows.size(); + WindowNode.WindowExclusion exclude = _windowFrame.getExclude(); + int[] peerStart = null; + int[] peerEnd = null; + if (needsPeerBoundaries(exclude)) { + peerStart = new int[numRows]; + peerEnd = new int[numRows]; + computePeerBoundaries(rows, peerStart, peerEnd); + } + + int lowerBound = _windowFrame.getLowerBound(); + int upperBound = Math.min(_windowFrame.getUpperBound(), numRows - 1); + + for (int i = Math.max(0, lowerBound); i <= upperBound; i++) { + _windowValueAggregator.addValue(extractValueFromRow(rows.get(i))); + } + + List result = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + if (lowerBound >= numRows) { + for (int j = i; j < numRows; j++) { + result.add(null); + } + return result; + } + int frameStart = Math.max(0, lowerBound); + int frameEnd = upperBound; + int pStart = peerStart != null ? peerStart[i] : i; + int pEnd = peerEnd != null ? peerEnd[i] : i; + + applyExclude(rows, i, frameStart, frameEnd, pStart, pEnd, exclude, true); + result.add(_windowValueAggregator.getCurrentAggregatedValue()); + applyExclude(rows, i, frameStart, frameEnd, pStart, pEnd, exclude, false); + + if (lowerBound >= 0) { + _windowValueAggregator.removeValue(extractValueFromRow(rows.get(lowerBound))); + } + lowerBound++; + if (upperBound < numRows - 1) { + upperBound++; + if (upperBound >= 0) { + _windowValueAggregator.addValue(extractValueFromRow(rows.get(upperBound))); + } + } + } + return result; + } + + /** + * RANGE frame with a non-default EXCLUDE clause. The frame for each row is determined by its peer group; we maintain + * the aggregator state corresponding to the base frame and apply per-row EXCLUDE corrections. 
+ */ + private List processRangeWindowWithExclude(List rows) { + int numRows = rows.size(); + int[] peerStart = new int[numRows]; + int[] peerEnd = new int[numRows]; + computePeerBoundaries(rows, peerStart, peerEnd); + + boolean lowerCurrentRow = _windowFrame.isLowerBoundCurrentRow(); + boolean upperCurrentRow = _windowFrame.isUpperBoundCurrentRow(); + WindowNode.WindowExclusion exclude = _windowFrame.getExclude(); + List result = new ArrayList<>(numRows); + + if (_windowFrame.isUnboundedPreceding() && _windowFrame.isUnboundedFollowing()) { + // Frame = whole partition for every row + for (Object[] row : rows) { + _windowValueAggregator.addValue(extractValueFromRow(row)); + } + for (int i = 0; i < numRows; i++) { + applyExclude(rows, i, 0, numRows - 1, peerStart[i], peerEnd[i], exclude, true); + result.add(_windowValueAggregator.getCurrentAggregatedValue()); + applyExclude(rows, i, 0, numRows - 1, peerStart[i], peerEnd[i], exclude, false); + } + return result; + } + + if (_windowFrame.isUnboundedPreceding() && upperCurrentRow) { + // Frame for row i = [0, peerEnd[i]]; aggregator is built peer-group by peer-group + int loaded = 0; + int i = 0; + while (i < numRows) { + int end = peerEnd[i]; + for (int j = loaded; j <= end; j++) { + _windowValueAggregator.addValue(extractValueFromRow(rows.get(j))); + } + loaded = end + 1; + while (i <= end) { + applyExclude(rows, i, 0, end, peerStart[i], peerEnd[i], exclude, true); + result.add(_windowValueAggregator.getCurrentAggregatedValue()); + applyExclude(rows, i, 0, end, peerStart[i], peerEnd[i], exclude, false); + i++; + } + } + return result; + } + + if (lowerCurrentRow && _windowFrame.isUnboundedFollowing()) { + // Frame for row i = [peerStart[i], numRows-1]; build up the aggregator peer group by peer group, from the + // rightmost peer toward the leftmost. After adding peer g, the aggregator contains [peerStart_g, numRows-1]. 
+ Object[] perRow = new Object[numRows]; + int i = numRows - 1; + while (i >= 0) { + int start = peerStart[i]; + for (int j = start; j <= i; j++) { + _windowValueAggregator.addValue(extractValueFromRow(rows.get(j))); + } + for (int j = start; j <= i; j++) { + applyExclude(rows, j, start, numRows - 1, peerStart[j], peerEnd[j], exclude, true); + perRow[j] = _windowValueAggregator.getCurrentAggregatedValue(); + applyExclude(rows, j, start, numRows - 1, peerStart[j], peerEnd[j], exclude, false); + } + i = start - 1; + } + Collections.addAll(result, perRow); + return result; + } + + if (lowerCurrentRow && upperCurrentRow) { + // Frame for row i = peer group of i; load each peer group separately + int i = 0; + while (i < numRows) { + int start = peerStart[i]; + int end = peerEnd[i]; + for (int j = start; j <= end; j++) { + _windowValueAggregator.addValue(extractValueFromRow(rows.get(j))); + } + while (i <= end) { + applyExclude(rows, i, start, end, start, end, exclude, true); + result.add(_windowValueAggregator.getCurrentAggregatedValue()); + applyExclude(rows, i, start, end, start, end, exclude, false); + i++; + } + for (int j = start; j <= end; j++) { + _windowValueAggregator.removeValue(extractValueFromRow(rows.get(j))); + } + } + return result; + } + + throw new IllegalStateException("RANGE window frame with offset PRECEDING / FOLLOWING is not supported"); + } + + /** + * Removes (when {@code remove} is true) or re-adds (otherwise) the rows in the EXCLUDE set, restricted to the base + * frame {@code [frameStart, frameEnd]}. The exclude set is derived from the current row {@code i} and its peer group + * {@code [pStart, pEnd]} as defined by SQL's EXCLUDE clause. 
+ */ + private void applyExclude(List rows, int i, int frameStart, int frameEnd, int pStart, int pEnd, + WindowNode.WindowExclusion exclude, boolean remove) { + switch (exclude) { + case CURRENT_ROW: + if (i >= frameStart && i <= frameEnd) { + Object value = extractValueFromRow(rows.get(i)); + if (remove) { + _windowValueAggregator.removeValue(value); + } else { + _windowValueAggregator.addValue(value); + } + } + break; + case GROUP: { + int from = Math.max(pStart, frameStart); + int to = Math.min(pEnd, frameEnd); + for (int j = from; j <= to; j++) { + Object value = extractValueFromRow(rows.get(j)); + if (remove) { + _windowValueAggregator.removeValue(value); + } else { + _windowValueAggregator.addValue(value); + } + } + break; + } + case TIES: { + int from = Math.max(pStart, frameStart); + int to = Math.min(pEnd, frameEnd); + for (int j = from; j <= to; j++) { + if (j == i) { + continue; + } + Object value = extractValueFromRow(rows.get(j)); + if (remove) { + _windowValueAggregator.removeValue(value); + } else { + _windowValueAggregator.addValue(value); + } + } + break; + } + default: + throw new IllegalStateException("Unsupported WindowExclusion: " + exclude); + } + } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/SortedMultisetMinMaxWindowValueAggregator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/SortedMultisetMinMaxWindowValueAggregator.java new file mode 100644 index 000000000000..7ecdb762b65b --- /dev/null +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/SortedMultisetMinMaxWindowValueAggregator.java @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.runtime.operator.window.aggregate; + +import java.util.TreeMap; +import javax.annotation.Nullable; + + +/** + * MIN / MAX window aggregator backed by a sorted multiset (a {@link TreeMap} of value to occurrence count). Unlike the + * monotonic-deque aggregators used for sliding-window MIN / MAX, this aggregator supports removal of arbitrary values + * in any order — necessary for window frames with a non-default {@code EXCLUDE} clause. Add / remove / query are all + * {@code O(log K)} where {@code K} is the number of distinct values currently in the window. + * + *

All non-null values added to a single instance must implement {@link Comparable} and share a runtime type that is + * mutually comparable; mixing types (e.g. {@code Integer} and {@code Long}) is undefined. Callers obtain instances + * through {@link WindowValueAggregatorFactory}, which guarantees this because window function values come from a single + * typed input column. {@code removeValue} for a value not currently present is a no-op. + */ +public class SortedMultisetMinMaxWindowValueAggregator implements WindowValueAggregator { + + private final TreeMap _counts = new TreeMap<>(SortedMultisetMinMaxWindowValueAggregator::compare); + private final boolean _isMin; + + public SortedMultisetMinMaxWindowValueAggregator(boolean isMin) { + _isMin = isMin; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static int compare(Object a, Object b) { + return ((Comparable) a).compareTo(b); + } + + @Override + public void addValue(@Nullable Object value) { + if (value != null) { + _counts.merge(value, 1, Integer::sum); + } + } + + @Override + public void removeValue(@Nullable Object value) { + if (value == null) { + return; + } + Integer count = _counts.get(value); + if (count == null) { + return; + } + if (count == 1) { + _counts.remove(value); + } else { + _counts.put(value, count - 1); + } + } + + @Nullable + @Override + public Object getCurrentAggregatedValue() { + if (_counts.isEmpty()) { + return null; + } + return _isMin ? 
_counts.firstKey() : _counts.lastKey(); + } + + @Override + public void clear() { + _counts.clear(); + } +} diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorFactory.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorFactory.java index 6ab348d5b253..4cc90bb6106c 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorFactory.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorFactory.java @@ -38,6 +38,17 @@ private WindowValueAggregatorFactory() { */ public static WindowValueAggregator getWindowValueAggregator(String functionName, DataSchema.ColumnDataType columnDataType, boolean supportRemoval) { + return getWindowValueAggregator(functionName, columnDataType, supportRemoval, false); + } + + /** + * Returns a window value aggregator for the given function/type. When {@code arbitraryRemoval} is set, MIN / MAX + * aggregators that normally use a monotonic deque (which assumes values are removed in arrival order) are replaced + * with a sorted-multiset implementation so callers may remove any value at any time — required for window frames + * with a non-default {@code EXCLUDE} clause. 
+ */ + public static WindowValueAggregator getWindowValueAggregator(String functionName, + DataSchema.ColumnDataType columnDataType, boolean supportRemoval, boolean arbitraryRemoval) { DataSchema.ColumnDataType storedType = columnDataType.getStoredType(); switch (functionName) { // NOTE: Keep both 'SUM0' and '$SUM0' for backward compatibility where 'SUM0' is SqlKind and '$SUM0' is function @@ -49,9 +60,11 @@ public static WindowValueAggregator getWindowValueAggregator(String func case "AVG": return new AvgWindowValueAggregator(); case "MIN": - return createMinAggregator(storedType, supportRemoval); + return arbitraryRemoval ? new SortedMultisetMinMaxWindowValueAggregator(true) + : createMinAggregator(storedType, supportRemoval); case "MAX": - return createMaxAggregator(storedType, supportRemoval); + return arbitraryRemoval ? new SortedMultisetMinMaxWindowValueAggregator(false) + : createMaxAggregator(storedType, supportRemoval); case "COUNT": return new CountWindowValueAggregator(); case "BOOLAND": diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java index 7baebe8b41f6..8ad0786fa61d 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/FirstValueWindowFunction.java @@ -29,6 +29,7 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.data.table.Key; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.planner.plannode.WindowNode; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; import org.apache.pinot.query.runtime.operator.window.WindowFrame; @@ -42,6 +43,9 @@ public FirstValueWindowFunction(RexExpression.FunctionCall 
aggCall, DataSchema i @Override public List processRows(List rows) { + if (!_windowFrame.isExcludeNoOthers()) { + return processWithExclude(rows); + } if (_windowFrame.isRowType()) { if (_ignoreNulls) { return processRowsWindowIgnoreNulls(rows); @@ -57,6 +61,39 @@ public List processRows(List rows) { } } + /** + * FIRST_VALUE for a non-default EXCLUDE clause. Computes the first non-excluded index for each row in O(1) using + * the precomputed peer-group boundaries; if {@code IGNORE NULLS} is set, advances past null values respecting the + * exclude rules. + */ + private List processWithExclude(List rows) { + int numRows = rows.size(); + WindowNode.WindowExclusion exclude = _windowFrame.getExclude(); + int[] peerStart = null; + int[] peerEnd = null; + if (needsPeerBoundaries(exclude) || !_windowFrame.isRowType()) { + peerStart = new int[numRows]; + peerEnd = new int[numRows]; + computePeerBoundaries(rows, peerStart, peerEnd); + } + + List result = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + int pStart = peerStart != null ? peerStart[i] : i; + int pEnd = peerEnd != null ? peerEnd[i] : i; + int fs = frameStartForRow(i, pStart, numRows); + int fe = frameEndForRow(i, pEnd, numRows); + int idx = firstNonExcluded(fs, fe, i, pStart, pEnd, exclude); + if (_ignoreNulls) { + while (idx != -1 && extractValueFromRow(rows.get(idx)) == null) { + idx = firstNonExcluded(idx + 1, fe, i, pStart, pEnd, exclude); + } + } + result.add(idx == -1 ? 
null : extractValueFromRow(rows.get(idx))); + } + return result; + } + private List processRowsWindow(List rows) { int numRows = rows.size(); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java index 8da5fa6fb13c..900a04456716 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/window/value/LastValueWindowFunction.java @@ -29,6 +29,7 @@ import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.data.table.Key; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.planner.plannode.WindowNode; import org.apache.pinot.query.runtime.operator.utils.AggregationUtils; import org.apache.pinot.query.runtime.operator.window.WindowFrame; @@ -42,6 +43,9 @@ public LastValueWindowFunction(RexExpression.FunctionCall aggCall, DataSchema in @Override public List processRows(List rows) { + if (!_windowFrame.isExcludeNoOthers()) { + return processWithExclude(rows); + } if (_windowFrame.isRowType()) { if (_ignoreNulls) { return processRowsWindowIgnoreNulls(rows); @@ -57,6 +61,39 @@ public List processRows(List rows) { } } + /** + * LAST_VALUE for a non-default EXCLUDE clause. Computes the last non-excluded index for each row in O(1) using the + * precomputed peer-group boundaries; if {@code IGNORE NULLS} is set, scans backward past null values respecting the + * exclude rules. 
+ */ + private List processWithExclude(List rows) { + int numRows = rows.size(); + WindowNode.WindowExclusion exclude = _windowFrame.getExclude(); + int[] peerStart = null; + int[] peerEnd = null; + if (needsPeerBoundaries(exclude) || !_windowFrame.isRowType()) { + peerStart = new int[numRows]; + peerEnd = new int[numRows]; + computePeerBoundaries(rows, peerStart, peerEnd); + } + + List result = new ArrayList<>(numRows); + for (int i = 0; i < numRows; i++) { + int pStart = peerStart != null ? peerStart[i] : i; + int pEnd = peerEnd != null ? peerEnd[i] : i; + int fs = frameStartForRow(i, pStart, numRows); + int fe = frameEndForRow(i, pEnd, numRows); + int idx = lastNonExcluded(fs, fe, i, pStart, pEnd, exclude); + if (_ignoreNulls) { + while (idx != -1 && extractValueFromRow(rows.get(idx)) == null) { + idx = lastNonExcluded(fs, idx - 1, i, pStart, pEnd, exclude); + } + } + result.add(idx == -1 ? null : extractValueFromRow(rows.get(idx))); + } + return result; + } + private List processRowsWindow(List rows) { int numRows = rows.size(); if (_windowFrame.isUnboundedFollowing() && _windowFrame.getLowerBound() <= 0) { diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index 0ab3cab16ae6..a6775121b35d 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -3061,7 +3061,7 @@ private WindowAggregateOperator getOperator(DataSchema inputSchema, DataSchema r MultiStageOperator input) { return new WindowAggregateOperator(OperatorTestUtil.getTracingContext(), input, inputSchema, new WindowNode(-1, resultSchema, nodeHint, List.of(), keys, collations, aggCalls, windowFrameType, lowerBound, - upperBound, List.of())); + upperBound, 
WindowNode.WindowExclusion.NO_OTHERS, List.of())); } private WindowAggregateOperator getOperator(DataSchema inputSchema, DataSchema resultSchema, List keys, diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorTest.java index 94d46d5d875c..8dcce397c993 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/window/aggregate/WindowValueAggregatorTest.java @@ -753,4 +753,55 @@ public void testComparableMaxPreservesLongType() { assertTrue(result instanceof Long); assertEquals(result, largeVal); } + + // ======================== + // SortedMultisetMinMaxWindowValueAggregator (used for MIN/MAX under EXCLUDE) + // ======================== + + @Test + public void testSortedMultisetMin() { + WindowValueAggregator agg = new SortedMultisetMinMaxWindowValueAggregator(true); + assertNull(agg.getCurrentAggregatedValue()); + agg.addValue(5); + agg.addValue(2); + agg.addValue(2); + agg.addValue(8); + assertEquals(agg.getCurrentAggregatedValue(), 2); + // Removing one occurrence of the duplicate min leaves the other + agg.removeValue(2); + assertEquals(agg.getCurrentAggregatedValue(), 2); + agg.removeValue(2); + assertEquals(agg.getCurrentAggregatedValue(), 5); + // Out-of-order removal still works (we remove the max next) + agg.removeValue(8); + assertEquals(agg.getCurrentAggregatedValue(), 5); + agg.removeValue(5); + assertNull(agg.getCurrentAggregatedValue()); + } + + @Test + public void testSortedMultisetMaxIgnoresNullsAndUnknownRemovals() { + WindowValueAggregator agg = new SortedMultisetMinMaxWindowValueAggregator(false); + agg.addValue(null); + agg.addValue(7); + agg.addValue(3); + 
assertEquals(agg.getCurrentAggregatedValue(), 7); + // Removing a value never added is a no-op + agg.removeValue(99); + agg.removeValue(null); + assertEquals(agg.getCurrentAggregatedValue(), 7); + agg.clear(); + assertNull(agg.getCurrentAggregatedValue()); + } + + @Test + public void testSortedMultisetWithBigDecimal() { + WindowValueAggregator agg = new SortedMultisetMinMaxWindowValueAggregator(true); + agg.addValue(new BigDecimal("3.14")); + agg.addValue(new BigDecimal("2.71")); + agg.addValue(new BigDecimal("1.41")); + assertEquals(agg.getCurrentAggregatedValue(), new BigDecimal("1.41")); + agg.removeValue(new BigDecimal("1.41")); + assertEquals(agg.getCurrentAggregatedValue(), new BigDecimal("2.71")); + } } diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json index 76f15a6f104f..e15346ae437f 100644 --- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json +++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json @@ -2378,6 +2378,252 @@ "outputs": [ ["a", 5, 1, 1] ] + }, + { + "description": "EXCLUDE CURRENT ROW on a bounded ROWS frame", + "sql": "SELECT string_col, int_col, double_col, SUM(int_col) OVER(PARTITION BY string_col ORDER BY int_col, double_col ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING EXCLUDE CURRENT ROW) FROM {tbl} ORDER BY string_col, int_col, double_col", + "keepOutputRowOrder": true, + "outputs": [ + ["a", 2, 300.0, 2], + ["a", 2, 400.0, 44], + ["a", 42, 42.0, 44], + ["a", 42, 50.5, 84], + ["a", 42, 75.0, 42], + ["b", 3, 100.0, 100], + ["b", 100, 1.0, 3], + ["c", 2, 400.0, 3], + ["c", 3, 100.0, 103], + ["c", 101, 1.01, 153], + ["c", 150, 1.5, 101], + ["d", 42, 42.0, null], + ["e", 42, 42.0, 42], + ["e", 42, 50.5, 42], + ["g", 3, 100.0, null], + ["h", 150, 1.53, null] + ] + }, + { + "description": "EXCLUDE GROUP on an unbounded RANGE frame", + "sql": "SELECT string_col, int_col, SUM(int_col) OVER(PARTITION BY string_col 
ORDER BY int_col RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE GROUP) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 126], + ["a", 2, 126], + ["a", 42, 4], + ["a", 42, 4], + ["a", 42, 4], + ["b", 3, 100], + ["b", 100, 3], + ["c", 2, 254], + ["c", 3, 253], + ["c", 101, 155], + ["c", 150, 106], + ["d", 42, null], + ["e", 42, null], + ["e", 42, null], + ["g", 3, null], + ["h", 150, null] + ] + }, + { + "description": "EXCLUDE TIES on an unbounded RANGE frame", + "sql": "SELECT string_col, int_col, SUM(int_col) OVER(PARTITION BY string_col ORDER BY int_col RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE TIES) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 128], + ["a", 2, 128], + ["a", 42, 46], + ["a", 42, 46], + ["a", 42, 46], + ["b", 3, 103], + ["b", 100, 103], + ["c", 2, 256], + ["c", 3, 256], + ["c", 101, 256], + ["c", 150, 256], + ["d", 42, 42], + ["e", 42, 42], + ["e", 42, 42], + ["g", 3, 3], + ["h", 150, 150] + ] + }, + { + "description": "EXCLUDE GROUP on RANGE UNBOUNDED PRECEDING to CURRENT ROW", + "sql": "SELECT string_col, int_col, COUNT(*) OVER(PARTITION BY string_col ORDER BY int_col RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW EXCLUDE GROUP) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 0], + ["a", 2, 0], + ["a", 42, 2], + ["a", 42, 2], + ["a", 42, 2], + ["b", 3, 0], + ["b", 100, 1], + ["c", 2, 0], + ["c", 3, 1], + ["c", 101, 2], + ["c", 150, 3], + ["d", 42, 0], + ["e", 42, 0], + ["e", 42, 0], + ["g", 3, 0], + ["h", 150, 0] + ] + }, + { + "description": "EXCLUDE TIES on RANGE CURRENT ROW to UNBOUNDED FOLLOWING", + "sql": "SELECT string_col, int_col, SUM(int_col) OVER(PARTITION BY string_col ORDER BY int_col RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING EXCLUDE TIES) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 128], + ["a", 2, 128], + ["a", 42, 42], + ["a", 42, 42], + ["a", 42, 42], + ["b", 3, 103], + ["b", 100, 100], + 
["c", 2, 256], + ["c", 3, 254], + ["c", 101, 251], + ["c", 150, 150], + ["d", 42, 42], + ["e", 42, 42], + ["e", 42, 42], + ["g", 3, 3], + ["h", 150, 150] + ] + }, + { + "description": "EXCLUDE CURRENT ROW on RANGE CURRENT ROW to CURRENT ROW (peer-group only)", + "sql": "SELECT string_col, int_col, COUNT(*) OVER(PARTITION BY string_col ORDER BY int_col RANGE BETWEEN CURRENT ROW AND CURRENT ROW EXCLUDE CURRENT ROW) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 1], + ["a", 2, 1], + ["a", 42, 2], + ["a", 42, 2], + ["a", 42, 2], + ["b", 3, 0], + ["b", 100, 0], + ["c", 2, 0], + ["c", 3, 0], + ["c", 101, 0], + ["c", 150, 0], + ["d", 42, 0], + ["e", 42, 1], + ["e", 42, 1], + ["g", 3, 0], + ["h", 150, 0] + ] + }, + { + "description": "AVG with EXCLUDE CURRENT ROW on an unbounded ROWS frame", + "sql": "SELECT string_col, int_col, double_col, AVG(int_col) OVER(PARTITION BY string_col ORDER BY int_col, double_col ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE CURRENT ROW) FROM {tbl} ORDER BY string_col, int_col, double_col", + "keepOutputRowOrder": true, + "outputs": [ + ["a", 2, 300.0, 32.0], + ["a", 2, 400.0, 32.0], + ["a", 42, 42.0, 22.0], + ["a", 42, 50.5, 22.0], + ["a", 42, 75.0, 22.0], + ["b", 3, 100.0, 100.0], + ["b", 100, 1.0, 3.0], + ["c", 2, 400.0, 84.66666666666667], + ["c", 3, 100.0, 84.33333333333333], + ["c", 101, 1.01, 51.666666666666664], + ["c", 150, 1.5, 35.333333333333336], + ["d", 42, 42.0, null], + ["e", 42, 42.0, 42.0], + ["e", 42, 50.5, 42.0], + ["g", 3, 100.0, null], + ["h", 150, 1.53, null] + ] + }, + { + "description": "MIN with EXCLUDE GROUP on an unbounded ROWS frame", + "sql": "SELECT string_col, int_col, MIN(int_col) OVER(PARTITION BY string_col ORDER BY int_col ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE GROUP) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 42], + ["a", 2, 42], + ["a", 42, 2], + ["a", 42, 2], + ["a", 42, 2], + ["b", 3, 100], + ["b", 100, 3], + ["c", 
2, 3], + ["c", 3, 2], + ["c", 101, 2], + ["c", 150, 2], + ["d", 42, null], + ["e", 42, null], + ["e", 42, null], + ["g", 3, null], + ["h", 150, null] + ] + }, + { + "description": "FIRST_VALUE with EXCLUDE CURRENT ROW on an unbounded ROWS frame", + "sql": "SELECT string_col, int_col, double_col, FIRST_VALUE(int_col) OVER(PARTITION BY string_col ORDER BY int_col, double_col ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE CURRENT ROW) FROM {tbl} ORDER BY string_col, int_col, double_col", + "keepOutputRowOrder": true, + "outputs": [ + ["a", 2, 300.0, 2], + ["a", 2, 400.0, 2], + ["a", 42, 42.0, 2], + ["a", 42, 50.5, 2], + ["a", 42, 75.0, 2], + ["b", 3, 100.0, 100], + ["b", 100, 1.0, 3], + ["c", 2, 400.0, 3], + ["c", 3, 100.0, 2], + ["c", 101, 1.01, 2], + ["c", 150, 1.5, 2], + ["d", 42, 42.0, null], + ["e", 42, 42.0, 42], + ["e", 42, 50.5, 42], + ["g", 3, 100.0, null], + ["h", 150, 1.53, null] + ] + }, + { + "description": "LAST_VALUE with EXCLUDE TIES on an unbounded ROWS frame", + "sql": "SELECT string_col, int_col, LAST_VALUE(int_col) OVER(PARTITION BY string_col ORDER BY int_col ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE TIES) FROM {tbl} ORDER BY string_col, int_col", + "outputs": [ + ["a", 2, 42], + ["a", 2, 42], + ["a", 42, 42], + ["a", 42, 42], + ["a", 42, 42], + ["b", 3, 100], + ["b", 100, 100], + ["c", 2, 150], + ["c", 3, 150], + ["c", 101, 150], + ["c", 150, 150], + ["d", 42, 42], + ["e", 42, 42], + ["e", 42, 42], + ["g", 3, 3], + ["h", 150, 150] + ] + }, + { + "description": "EXCLUDE CURRENT ROW with PARTITION BY only (no ORDER BY)", + "sql": "SELECT string_col, int_col, double_col, SUM(int_col) OVER(PARTITION BY string_col ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING EXCLUDE CURRENT ROW) FROM {tbl} ORDER BY string_col, int_col, double_col", + "keepOutputRowOrder": true, + "outputs": [ + ["a", 2, 300.0, 128], + ["a", 2, 400.0, 128], + ["a", 42, 42.0, 88], + ["a", 42, 50.5, 88], + ["a", 42, 75.0, 88], + ["b", 
3, 100.0, 100], + ["b", 100, 1.0, 3], + ["c", 2, 400.0, 254], + ["c", 3, 100.0, 253], + ["c", 101, 1.01, 155], + ["c", 150, 1.5, 106], + ["d", 42, 42.0, null], + ["e", 42, 42.0, 42], + ["e", 42, 50.5, 42], + ["g", 3, 100.0, null], + ["h", 150, 1.53, null] + ] } ] },