Skip to content

Commit

Permalink
Fix summary row issues in case postaggregations are happening (apache…
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshSingla committed Oct 25, 2023
1 parent 701b9af commit cae81a7
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ public void testQueryWithMoreThanMaxNumericInFilter()

}

@Ignore
@Override
public void testUnSupportedNullsFirst()
{
}

@Ignore
@Override
public void testUnSupportedNullsLast()
{
}

/**
* Same query as {@link CalciteQueryTest#testArrayAggQueryOnComplexDatatypes}. ARRAY_AGG is not supported in MSQ currently.
* Once support is added, this test can be removed and msqCompatible() can be added to the one in CalciteQueryTest.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,11 +788,19 @@ private static boolean summaryRowPreconditions(GroupByQuery query)
private static Iterator<ResultRow> summaryRowIterator(GroupByQuery q)
{
List<AggregatorFactory> aggSpec = q.getAggregatorSpecs();
Object[] values = new Object[aggSpec.size()];
ResultRow resultRow = ResultRow.create(q.getResultRowSizeWithPostAggregators());
for (int i = 0; i < aggSpec.size(); i++) {
values[i] = aggSpec.get(i).factorize(new AllNullColumnSelectorFactory()).get();
resultRow.set(i, aggSpec.get(i).factorize(new AllNullColumnSelectorFactory()).get());
}
return Collections.singleton(ResultRow.of(values)).iterator();
Map<String, Object> map = resultRow.toMap(q);
for (int i = 0; i < q.getPostAggregatorSpecs().size(); i++) {
final PostAggregator postAggregator = q.getPostAggregatorSpecs().get(i);
final Object value = postAggregator.compute(map);

resultRow.set(q.getResultRowPostAggregatorStart() + i, value);
map.put(postAggregator.getName(), value);
}
return Collections.singleton(resultRow).iterator();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12962,6 +12962,9 @@ public void testSummaryrowForEmptyInput()
new FloatSumAggregatorFactory("idxFloat", "indexFloat"),
new DoubleSumAggregatorFactory("idxDouble", "index")
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ExpressionPostAggregator("post", "idx * 2", null, TestExprMacroTable.INSTANCE)))
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();

Expand All @@ -12976,7 +12979,9 @@ public void testSummaryrowForEmptyInput()
"idxFloat",
NullHandling.replaceWithDefault() ? 0.0 : null,
"idxDouble",
NullHandling.replaceWithDefault() ? 0.0 : null
NullHandling.replaceWithDefault() ? 0.0 : null,
"post",
NullHandling.replaceWithDefault() ? 0L : null
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ private SqlValidator createSqlValidator(CalciteCatalogReader catalogReader)
opTab,
catalogReader,
getTypeFactory(),
validatorConfig
validatorConfig,
context.unwrapOrThrow(PlannerContext.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,41 @@
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.run.EngineFeature;

/**
* Druid extended SQL validator. (At present, it doesn't actually
* have any extensions yet, but it will soon.)
*/
class DruidSqlValidator extends BaseDruidSqlValidator
{
private final PlannerContext plannerContext;

protected DruidSqlValidator(
SqlOperatorTable opTab,
CalciteCatalogReader catalogReader,
JavaTypeFactory typeFactory,
Config validatorConfig
Config validatorConfig,
PlannerContext plannerContext
)
{
super(opTab, catalogReader, typeFactory, validatorConfig);
this.plannerContext = plannerContext;
}

@Override
public void validateCall(SqlCall call, SqlValidatorScope scope)
{
if (call.getKind() == SqlKind.OVER) {
if (!plannerContext.featureAvailable(EngineFeature.WINDOW_FUNCTIONS)) {
throw buildCalciteContextException(
StringUtils.format(
"The query contains window functions; To run these window functions, enable [%s] in query context.",
EngineFeature.WINDOW_FUNCTIONS),
call);
}
}
if (call.getKind() == SqlKind.NULLS_FIRST) {
SqlNode op0 = call.getOperandList().get(0);
if (op0.getKind() == SqlKind.DESCENDING) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,12 @@ public SqlConformance conformance()
return DruidConformance.instance();
}
};
} else {
return null;
}
if (aClass.equals(PlannerContext.class)) {
return (C) plannerContext;
}

return null;
}
});

Expand Down
144 changes: 144 additions & 0 deletions sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14275,4 +14275,148 @@ public void testUnSupportedNullsLast()
.run());
assertThat(e, invalidSqlIs("ASCENDING ordering with NULLS LAST is not supported! (line [1], column [41])"));
}

@Test
public void testWindowingErrorWithoutFeatureFlag()
{
DruidException e = assertThrows(DruidException.class, () -> testBuilder()
.queryContext(ImmutableMap.of(PlannerContext.CTX_ENABLE_WINDOW_FNS, false))
.sql("SELECT dim1,ROW_NUMBER() OVER () from druid.foo")
.run());

assertThat(e, invalidSqlIs("The query contains window functions; To run these window functions, enable [WINDOW_FUNCTIONS] in query context. (line [1], column [13])"));
}

@Test
public void testInGroupByLimitOutGroupByOrderBy()
{
skipVectorize();
cannotVectorize();
testQuery(
"with t AS (SELECT m2, COUNT(m1) as trend_score\n"
+ "FROM \"foo\"\n"
+ "GROUP BY 1 \n"
+ "LIMIT 10\n"
+ ")\n"
+ "select m2, (MAX(trend_score)) from t\n"
+ "where m2 > 2\n"
+ "GROUP BY 1 \n"
+ "ORDER BY 2 DESC",
QUERY_CONTEXT_DEFAULT,
ImmutableList.of(
new GroupByQuery.Builder()
.setDataSource(
new TopNQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.dimension(new DefaultDimensionSpec("m2", "d0", ColumnType.DOUBLE))
.threshold(10)
.aggregators(aggregators(
useDefault
? new CountAggregatorFactory("a0")
: new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
notNull("m1")
)
))
.metric(new DimensionTopNMetricSpec(null, StringComparators.NUMERIC))
.context(OUTER_LIMIT_CONTEXT)
.build()
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
new DefaultDimensionSpec("d0", "_d0", ColumnType.DOUBLE)
)
.setDimFilter(
useDefault ?
bound("d0", "2", null, true, false, null, StringComparators.NUMERIC) :
new RangeFilter("d0", ColumnType.LONG, 2L, null, true, false, null)
)
.setAggregatorSpecs(aggregators(
new LongMaxAggregatorFactory("_a0", "a0")
))
.setLimitSpec(
DefaultLimitSpec
.builder()
.orderBy(new OrderByColumnSpec("_a0", Direction.DESCENDING, StringComparators.NUMERIC))
.build()
)
.setContext(OUTER_LIMIT_CONTEXT)
.build()
),
ImmutableList.of(
new Object[]{3.0D, 1L},
new Object[]{4.0D, 1L},
new Object[]{5.0D, 1L},
new Object[]{6.0D, 1L}
)
);
}

@Test
public void testInGroupByOrderByLimitOutGroupByOrderByLimit()
{
skipVectorize();
cannotVectorize();
testQuery(
"with t AS (SELECT m2 as mo, COUNT(m1) as trend_score\n"
+ "FROM \"foo\"\n"
+ "GROUP BY 1\n"
+ "ORDER BY trend_score DESC\n"
+ "LIMIT 10)\n"
+ "select mo, (MAX(trend_score)) from t\n"
+ "where mo > 2\n"
+ "GROUP BY 1 \n"
+ "ORDER BY 2 DESC LIMIT 2\n",
QUERY_CONTEXT_DEFAULT,
ImmutableList.of(
new GroupByQuery.Builder()
.setDataSource(
new TopNQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.dimension(new DefaultDimensionSpec("m2", "d0", ColumnType.DOUBLE))
.threshold(10)
.aggregators(aggregators(
useDefault
? new CountAggregatorFactory("a0")
: new FilteredAggregatorFactory(
new CountAggregatorFactory("a0"),
notNull("m1")
)
))
.metric(new NumericTopNMetricSpec("a0"))
.context(OUTER_LIMIT_CONTEXT)
.build()
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(
new DefaultDimensionSpec("d0", "_d0", ColumnType.DOUBLE)
)
.setDimFilter(
useDefault ?
bound("d0", "2", null, true, false, null, StringComparators.NUMERIC) :
new RangeFilter("d0", ColumnType.LONG, 2L, null, true, false, null)
)
.setAggregatorSpecs(aggregators(
new LongMaxAggregatorFactory("_a0", "a0")
))
.setLimitSpec(
DefaultLimitSpec
.builder()
.orderBy(new OrderByColumnSpec("_a0", Direction.DESCENDING, StringComparators.NUMERIC))
.limit(2)
.build()
)
.setContext(OUTER_LIMIT_CONTEXT)
.build()
),
ImmutableList.of(
new Object[]{3.0D, 1L},
new Object[]{4.0D, 1L}
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
package org.apache.druid.sql.calcite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.sql.calcite.NotYetSupported.Modes;
import org.apache.druid.sql.calcite.NotYetSupported.NotYetSupportedProcessor;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.junit.Rule;
import org.junit.Test;

Expand Down Expand Up @@ -53,6 +55,7 @@ public void testTasksSumOver()
msqIncompatible();

testBuilder()
.queryContext(ImmutableMap.of(PlannerContext.CTX_ENABLE_WINDOW_FNS, true))
.sql("select datasource, sum(duration) over () from sys.tasks group by datasource")
.expectedResults(ImmutableList.of(
new Object[]{"foo", 11L},
Expand Down

0 comments on commit cae81a7

Please sign in to comment.