diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java index 0b1562eb83d1..ee8c469c3b86 100644 --- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java +++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java @@ -30,6 +30,7 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.query.aggregation.AggregatorFactory; @@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator private static final String STDDEV_NAME = "STDDEV"; private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(VARIANCE_NAME); + buildSqlVarianceAggFunction(VARIANCE_NAME); private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_POP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_POP.name()); private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name()); private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(STDDEV_NAME); + buildSqlVarianceAggFunction(STDDEV_NAME); private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name()); private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE = - buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name()); + buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name()); @Nullable @Override @@ -160,14 +161,15 @@ public Aggregation toDruidAggregation( } /** - * Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction} - * but with an operand type that accepts variance aggregator objects in addition to numeric inputs. + * Creates a {@link SqlAggFunction} + * + * It accepts variance aggregator objects in addition to numeric inputs. */ - private static SqlAggFunction buildSqlAvgAggFunction(String name) + private static SqlAggFunction buildSqlVarianceAggFunction(String name) { return OperatorConversions .aggregatorBuilder(name) - .returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION) + .returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE)) .operandTypeChecker( OperandTypes.or( OperandTypes.NUMERIC, diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java index fe68b2737ef3..e45a93784967 100644 --- a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java +++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java @@ -171,8 +171,8 @@ public void testVarPop() final List expectedResults = ImmutableList.of( new Object[]{ holder1.getVariance(true), - holder2.getVariance(true).doubleValue(), - holder3.getVariance(true).longValue() + holder2.getVariance(true), + holder3.getVariance(true) } ); testQuery( @@ -219,7 +219,7 @@ public void testVarSamp() new Object[] { holder1.getVariance(false), holder2.getVariance(false).doubleValue(), - holder3.getVariance(false).longValue(), + holder3.getVariance(false), } ); testQuery( @@ -266,7 +266,7 @@ public void testStdDevPop() new Object[] { Math.sqrt(holder1.getVariance(true)), Math.sqrt(holder2.getVariance(true)), - (long) Math.sqrt(holder3.getVariance(true)), + Math.sqrt(holder3.getVariance(true)), } ); @@ -321,7 +321,7 @@ public void testStdDevSamp() new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -374,7 +374,7 @@ public void testStdDevWithVirtualColumns() new Object[]{ Math.sqrt(holder1.getVariance(false)), Math.sqrt(holder2.getVariance(false)), - (long) Math.sqrt(holder3.getVariance(false)), + Math.sqrt(holder3.getVariance(false)), } ); @@ -543,7 +543,7 @@ public void testEmptyTimeseriesResults() ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{null, null, null, null, null, null, null, null} ) ); @@ -623,7 +623,7 @@ public void testGroupByAggregatorDefaultValues() ), ImmutableList.of( NullHandling.replaceWithDefault() - ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L} + ? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0} : new Object[]{"a", null, null, null, null, null, null, null, null} ) ); @@ -688,9 +688,9 @@ public void assertResultsEquals(String sql, List expectedResults, List Assert.assertEquals(expectedResult.length, result.length); for (int j = 0; j < expectedResult.length; j++) { if (expectedResult[j] instanceof Float) { - Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10); + Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5); } else if (expectedResult[j] instanceof Double) { - Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10); + Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5); } else { Assert.assertEquals(expectedResult[j], result[j]); }