diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java index 7285ba0b5c71..8bbdda5fcec2 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import javax.annotation.Nullable; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.BlockValSet; @@ -204,8 +205,8 @@ public ColumnDataType getFinalResultColumnType() { } @Override - public Long extractFinalResult(Long intermediateResult) { - return intermediateResult; + public Long extractFinalResult(@Nullable Long intermediateResult) { + return intermediateResult != null ? intermediateResult : 0L; } @Override diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java index 360233b0a204..cffce5dce688 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunctionTest.java @@ -19,12 +19,16 @@ package org.apache.pinot.core.query.aggregation.function; +import java.util.List; +import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.queries.FluentQueryTest; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import static org.testng.Assert.assertEquals; + public class CountAggregationFunctionTest extends AbstractAggregationFunctionTest { @@ -234,6 +238,14 @@ public void countGroupByMV() { ); } + @Test + public void testExtractFinalResultReturnsZeroForNull() { + CountAggregationFunction function = + new CountAggregationFunction(List.of(ExpressionContext.forIdentifier("col")), false); + assertEquals(function.extractFinalResult(null), Long.valueOf(0L)); + assertEquals(function.extractFinalResult(5L), Long.valueOf(5L)); + } + @DataProvider(name = "nullHandlingEnabled") public Object[][] nullHandlingEnabled() { return new Object[][]{