Skip to content

Commit

Permalink
p2
Browse files Browse the repository at this point in the history
  • Loading branch information
kgyrtkirk committed Oct 4, 2023
1 parent fc7e491 commit 5cbfc20
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
Expand Down Expand Up @@ -60,17 +61,17 @@ public abstract class BaseVarianceSqlAggregator implements SqlAggregator
private static final String STDDEV_NAME = "STDDEV";

private static final SqlAggFunction VARIANCE_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(VARIANCE_NAME);
buildSqlVarianceAggFunction(VARIANCE_NAME);
private static final SqlAggFunction VARIANCE_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_POP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_POP.name());
private static final SqlAggFunction VARIANCE_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.VAR_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.VAR_SAMP.name());
private static final SqlAggFunction STDDEV_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(STDDEV_NAME);
buildSqlVarianceAggFunction(STDDEV_NAME);
private static final SqlAggFunction STDDEV_POP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_POP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_POP.name());
private static final SqlAggFunction STDDEV_SAMP_SQL_AGG_FUNC_INSTANCE =
buildSqlAvgAggFunction(SqlKind.STDDEV_SAMP.name());
buildSqlVarianceAggFunction(SqlKind.STDDEV_SAMP.name());

@Nullable
@Override
Expand Down Expand Up @@ -160,14 +161,15 @@ public Aggregation toDruidAggregation(
}

/**
* Creates a {@link SqlAggFunction} that is the same as {@link org.apache.calcite.sql.fun.SqlAvgAggFunction}
* but with an operand type that accepts variance aggregator objects in addition to numeric inputs.
* Creates a {@link SqlAggFunction}
*
* It accepts variance aggregator objects in addition to numeric inputs.
*/
private static SqlAggFunction buildSqlAvgAggFunction(String name)
private static SqlAggFunction buildSqlVarianceAggFunction(String name)
{
return OperatorConversions
.aggregatorBuilder(name)
.returnTypeInference(ReturnTypes.AVG_AGG_FUNCTION)
.returnTypeInference(ReturnTypes.explicit(SqlTypeName.DOUBLE))
.operandTypeChecker(
OperandTypes.or(
OperandTypes.NUMERIC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ public void testVarPop()
final List<Object[]> expectedResults = ImmutableList.of(
new Object[]{
holder1.getVariance(true),
holder2.getVariance(true).doubleValue(),
holder3.getVariance(true).longValue()
holder2.getVariance(true),
holder3.getVariance(true)
}
);
testQuery(
Expand Down Expand Up @@ -219,7 +219,7 @@ public void testVarSamp()
new Object[] {
holder1.getVariance(false),
holder2.getVariance(false).doubleValue(),
holder3.getVariance(false).longValue(),
holder3.getVariance(false),
}
);
testQuery(
Expand Down Expand Up @@ -266,7 +266,7 @@ public void testStdDevPop()
new Object[] {
Math.sqrt(holder1.getVariance(true)),
Math.sqrt(holder2.getVariance(true)),
(long) Math.sqrt(holder3.getVariance(true)),
Math.sqrt(holder3.getVariance(true)),
}
);

Expand Down Expand Up @@ -321,7 +321,7 @@ public void testStdDevSamp()
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);

Expand Down Expand Up @@ -374,7 +374,7 @@ public void testStdDevWithVirtualColumns()
new Object[]{
Math.sqrt(holder1.getVariance(false)),
Math.sqrt(holder2.getVariance(false)),
(long) Math.sqrt(holder3.getVariance(false)),
Math.sqrt(holder3.getVariance(false)),
}
);

Expand Down Expand Up @@ -543,7 +543,7 @@ public void testEmptyTimeseriesResults()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{null, null, null, null, null, null, null, null}
)
);
Expand Down Expand Up @@ -623,7 +623,7 @@ public void testGroupByAggregatorDefaultValues()
),
ImmutableList.of(
NullHandling.replaceWithDefault()
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0L, 0L, 0L, 0L}
? new Object[]{"a", 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}
: new Object[]{"a", null, null, null, null, null, null, null, null}
)
);
Expand Down Expand Up @@ -688,9 +688,9 @@ public void assertResultsEquals(String sql, List<Object[]> expectedResults, List
Assert.assertEquals(expectedResult.length, result.length);
for (int j = 0; j < expectedResult.length; j++) {
if (expectedResult[j] instanceof Float) {
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-10);
Assert.assertEquals((Float) expectedResult[j], (Float) result[j], 1e-5);
} else if (expectedResult[j] instanceof Double) {
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-10);
Assert.assertEquals((Double) expectedResult[j], (Double) result[j], 1e-5);
} else {
Assert.assertEquals(expectedResult[j], result[j]);
}
Expand Down

0 comments on commit 5cbfc20

Please sign in to comment.