Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.common.BlockValSet;
Expand Down Expand Up @@ -204,8 +205,8 @@ public ColumnDataType getFinalResultColumnType() {
}

@Override
public Long extractFinalResult(Long intermediateResult) {
return intermediateResult;
public Long extractFinalResult(@Nullable Long intermediateResult) {
return intermediateResult != null ? intermediateResult : 0L;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

package org.apache.pinot.core.query.aggregation.function;

import java.util.List;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.queries.FluentQueryTest;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import static org.testng.Assert.assertEquals;


public class CountAggregationFunctionTest extends AbstractAggregationFunctionTest {

Expand Down Expand Up @@ -234,6 +238,14 @@ public void countGroupByMV() {
);
}

@Test
public void testExtractFinalResultReturnsZeroForNull() {
CountAggregationFunction function =
new CountAggregationFunction(List.of(ExpressionContext.forIdentifier("col")), false);
assertEquals(function.extractFinalResult(null), Long.valueOf(0L));
assertEquals(function.extractFinalResult(5L), Long.valueOf(5L));
}

@DataProvider(name = "nullHandlingEnabled")
public Object[][] nullHandlingEnabled() {
return new Object[][]{
Expand Down
Loading