Skip to content

Commit

Permalink
[CALCITE-4726] Support aggregate calls with a FILTER clause in Aggreg…
Browse files Browse the repository at this point in the history
…ateExpandWithinDistinctRule (Will Noble)

Close apache#2483
  • Loading branch information
wnob authored and julianhyde committed Sep 3, 2021
1 parent 51c0d92 commit 8c46299
Show file tree
Hide file tree
Showing 8 changed files with 523 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private SimpleCalciteSchema(@Nullable CalciteSchema parent,
return calciteSchema;
}

private @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
private static @Nullable String caseInsensitiveLookup(Set<String> candidates, String name) {
// Exact string lookup
if (candidates.contains(name)) {
return name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.IntPair;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -113,8 +114,6 @@ private static boolean hasWithinDistinct(Aggregate aggregate) {
// Wait until AggregateReduceFunctionsRule has dealt with AVG etc.
&& aggregate.getAggCallList().stream()
.noneMatch(CoreRules.AGGREGATE_REDUCE_FUNCTIONS::canReduce)
// Don't know that we can handle FILTER yet
&& aggregate.getAggCallList().stream().noneMatch(c -> c.filterArg >= 0)
// Don't think we can handle GROUPING SETS yet
&& aggregate.getGroupType() == Aggregate.Group.SIMPLE;
}
Expand All @@ -132,7 +131,7 @@ private static boolean hasWithinDistinct(Aggregate aggregate) {
//
// or in algebra,
//
// Aggregate($0, SUM($2), SUM($3) WITHIN DISTINCT ($4))
// Aggregate($0, SUM($2), SUM($2) WITHIN DISTINCT ($4))
// Scan(emp)
//
// We plan to generate the following:
Expand All @@ -154,8 +153,6 @@ private static boolean hasWithinDistinct(Aggregate aggregate) {
// SUM(sal) WITHIN DISTINCT (sal)
//

// TODO: handle "agg(x) filter (b)"

final List<AggregateCall> aggCallList =
aggregate.getAggCallList()
.stream()
Expand All @@ -179,27 +176,31 @@ private static boolean hasWithinDistinct(Aggregate aggregate) {
// sum(x) within distinct (y, z) ... group by y
// can be simplified to
// sum(x) within distinct (z) ... group by y
// Note that this assumes a single grouping set for the original agg.
distinctKeys = distinctKeys.rebuild()
.removeAll(aggregate.getGroupSet()).build();
}
}
argLists.put(distinctKeys, aggCall);
assert aggCall.filterArg < 0;
}

// Compute the set of all grouping sets that will be used in the output
// query. For each WITHIN DISTINCT aggregate call, we will need a grouping
// set that is the union of the aggregate call's unique keys and the input
// query's overall grouping. Redundant grouping sets can be reused for
// multiple aggregate calls.
final Set<ImmutableBitSet> groupSetTreeSet =
new TreeSet<>(ImmutableBitSet.ORDERING);
groupSetTreeSet.add(aggregate.getGroupSet());
for (ImmutableBitSet key : argLists.keySet()) {
if (key == notDistinct) {
continue;
}
groupSetTreeSet.add(
ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
(key == notDistinct)
? aggregate.getGroupSet()
: ImmutableBitSet.of(key).union(aggregate.getGroupSet()));
}

final ImmutableList<ImmutableBitSet> groupSets =
ImmutableList.copyOf(groupSetTreeSet);
final boolean hasMultipleGroupSets = groupSets.size() > 1;
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
final Set<Integer> fullGroupOrderedSet = new LinkedHashSet<>();
fullGroupOrderedSet.addAll(aggregate.getGroupSet().asSet());
Expand All @@ -216,40 +217,89 @@ private static boolean hasWithinDistinct(Aggregate aggregate) {
//
// or in algebra,
//
// Aggregate([($0), ($0, $2)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
// Aggregate([($0), ($0, $4)], SUM($2), MIN($2), MAX($2), GROUPING($0, $4))
// Scan(emp)

final RelBuilder b = call.builder();
b.push(aggregate.getInput());
final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();

// Helper class for building the inner query.
// CHECKSTYLE: IGNORE 1
class Registrar {
final int g = fullGroupSet.cardinality();
final Map<Integer, Integer> args = new HashMap<>();
/** Map of input fields (below the original aggregation) and filter args
* to inner query aggregate calls. */
final Map<IntPair, Integer> args = new HashMap<>();
/** Map of aggregate calls from the original aggregation to inner query
* aggregate calls. */
final Map<Integer, Integer> aggs = new HashMap<>();

List<Integer> fields(List<Integer> fields) {
return Util.transform(fields, this::field);
/** Map of filter arguments (as used by the original aggregate calls) to
* inner-query {@code COUNT(*)} calls, which are only needed for filters in
* the outer aggregate when the original aggregate call does not ignore null
* inputs. */
final Map<Integer, Integer> counts = new HashMap<>();

List<Integer> fields(List<Integer> fields, int filterArg) {
return Util.transform(fields, f -> this.field(f, filterArg));
}

int field(int field) {
return Objects.requireNonNull(args.get(field));
int field(int field, int filterArg) {
return Objects.requireNonNull(args.get(IntPair.of(field, filterArg)));
}

int register(int field) {
return args.computeIfAbsent(field, j -> {
/** Computes an aggregate call argument's values for a
* {@code WITHIN DISTINCT} aggregate call.
*
* <p>For example, to compute
* {@code SUM(x) WITHIN DISTINCT (y) GROUP BY (z)},
* the inner aggregate must first group {@code x} by {@code (y, z)}
* &mdash; using {@code MIN} to select the (hopefully) unique value of
* {@code x} for each {@code (y, z)} group. Actually summing over the
* grouped {@code x} values must occur in an outer aggregate.
*
* @param field Index of an input field that's used in a
* {@code WITHIN DISTINCT} aggregate call
* @param filterArg Filter arg used in the original aggregate call, or
* {@code -1} if there is no filter. We use the same filter in
* the inner query.
* @return Index of the inner query aggregate call representing the
* grouped field, which can be referenced in the outer query
* aggregate call
*/
int register(int field, int filterArg) {
return args.computeIfAbsent(IntPair.of(field, filterArg), j -> {
final int ordinal = g + aggCalls.size();
RelBuilder.AggCall groupedField =
b.aggregateCall(SqlStdOperatorTable.MIN, b.field(field));
aggCalls.add(
b.aggregateCall(SqlStdOperatorTable.MIN, b.field(j)));
filterArg < 0
? groupedField
: groupedField.filter(b.field(filterArg)));
if (config.throwIfNotUnique()) {
groupedField =
b.aggregateCall(SqlStdOperatorTable.MAX, b.field(field));
aggCalls.add(
b.aggregateCall(SqlStdOperatorTable.MAX, b.field(j)));
filterArg < 0
? groupedField
: groupedField.filter(b.field(filterArg)));
}
return ordinal;
});
}

/** Registers an aggregate call that is <em>not</em> a
* {@code WITHIN DISTINCT} call.
*
* <p>Unlike the case handled by {@link #register(int, int)} above,
* aggregate calls without any distinct keys do not need a second round
* of aggregation in the outer query, so they can be computed "as-is" in
* the inner query.
*
* @param i Index of the aggregate call in the original aggregation
* @param aggregateCall Original aggregate call
* @return Index of the aggregate call in the computed inner query
*/
int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
final int ordinal = g + aggCalls.size();
aggs.put(i, ordinal);
Expand All @@ -260,6 +310,33 @@ int registerAgg(int i, RelBuilder.AggCall aggregateCall) {
/** Looks up the inner-query ordinal that {@code registerAgg} recorded for
* an aggregate call.
*
* @param i Key under which the call was registered, i.e. the index of the
*     aggregate call in the original aggregation
* @return Index of the corresponding aggregate call in the inner query
* @throws NullPointerException if nothing was registered under {@code i}
*/
int getAgg(int i) {
return Objects.requireNonNull(aggs.get(i));
}

/** Registers an extra {@code COUNT} aggregate call when it's needed to
* filter out null inputs in the outer aggregate.
*
* <p>This should only be called for aggregate calls with filters. It's
* possible that the filter would eliminate all input rows to the
* {@code MIN} call in the inner query, so calls in the outer
* aggregate may need to be aware of this. See usage of
* {@link AggregateExpandWithinDistinctRule#mustBeCounted(AggregateCall)}.
*
* <p>Registration is memoized per filter: calling this method again with
* the same {@code filterArg} returns the ordinal assigned the first time
* and does not add another {@code COUNT} call.
*
* @param filterArg The original aggregate call's filter; must be
*     non-negative
* @return Index of the {@code COUNT} call in the computed inner query
*/
int registerCount(int filterArg) {
assert filterArg >= 0;
// The lambda parameter "i" is the map key, i.e. the same value as
// filterArg; the body uses the captured filterArg instead.
return counts.computeIfAbsent(filterArg, i -> {
// Compute the ordinal before appending: the new call will sit at position
// aggCalls.size() in the list, offset by g (the cardinality of the full
// group set).
final int ordinal = g + aggCalls.size();
// An operand-less COUNT carrying the same filter as the original call.
aggCalls.add(b.aggregateCall(SqlStdOperatorTable.COUNT)
.filter(b.field(filterArg)));
return ordinal;
});
}

/** Looks up the inner-query ordinal of the {@code COUNT} call that
* {@code registerCount} recorded for a filter.
*
* @param filterArg Filter arg that was passed to {@code registerCount}
* @return Index of the {@code COUNT} call in the inner query
* @throws NullPointerException if no {@code COUNT} was registered for
*     {@code filterArg}
*/
int getCount(int filterArg) {
return Objects.requireNonNull(counts.get(filterArg));
}
}

final Registrar registrar = new Registrar();
Expand All @@ -269,13 +346,25 @@ int getAgg(int i) {
b.aggregateCall(c.getAggregation(),
b.fields(c.getArgList())));
} else {
c.getArgList().forEach(registrar::register);
for (int inputIdx : c.getArgList()) {
registrar.register(inputIdx, c.filterArg);
}
if (mustBeCounted(c)) {
registrar.registerCount(c.filterArg);
}
}
});
// Add an additional GROUPING() aggregate call so we can select only the
// relevant inner-aggregate rows from the outer aggregate. If there is only
// 1 grouping set (i.e. every aggregate call has the same distinct keys),
// no GROUPING() call is necessary.
final int grouping =
registrar.registerAgg(-1,
b.aggregateCall(SqlStdOperatorTable.GROUPING,
b.fields(fullGroupList)));
hasMultipleGroupSets
? registrar.registerAgg(-1,
b.aggregateCall(
SqlStdOperatorTable.GROUPING,
b.fields(fullGroupList)))
: -1;
b.aggregate(
b.groupKey(fullGroupSet,
(Iterable<ImmutableBitSet>) groupSets), aggCalls);
Expand Down Expand Up @@ -304,32 +393,56 @@ int getAgg(int i) {
aggCalls.clear();
Ord.forEach(aggCallList, (c, i) -> {
final List<RexNode> filters = new ArrayList<>();
final RexNode groupFilter = b.equals(b.field(grouping),
b.literal(
groupValue(fullGroupList,
union(aggregate.getGroupSet(), c.distinctKeys))));
filters.add(groupFilter);
final RelBuilder.AggCall aggCall;
RexNode groupFilter = null;
if (hasMultipleGroupSets) {
groupFilter =
b.equals(
b.field(grouping),
b.literal(
groupValue(fullGroupList, union(aggregate.getGroupSet(), c.distinctKeys))));
filters.add(groupFilter);
}
RelBuilder.AggCall aggCall;
if (c.distinctKeys == null) {
aggCall = b.aggregateCall(SqlStdOperatorTable.MIN,
b.field(registrar.getAgg(i)));
} else {
aggCall = b.aggregateCall(c.getAggregation(),
b.fields(registrar.fields(c.getArgList())));
// The inputs to this aggregate are outputs from MIN() calls from the
// inner agg, and MIN() returns null iff it has no non-null inputs,
// which can only happen if an original aggregate's filter causes all
// non-null input rows to be discarded for a particular group in the
// inner aggregate. In this case, it should be ignored by the outer
// aggregate as well. In case the aggregate call does not naturally
// ignore null inputs, we add a filter based on a COUNT() in the inner
// aggregate.
aggCall =
b.aggregateCall(
c.getAggregation(),
b.fields(registrar.fields(c.getArgList(), c.filterArg)));

if (mustBeCounted(c)) {
filters.add(b.greaterThan(b.field(registrar.getCount(c.filterArg)), b.literal(0)));
}

if (config.throwIfNotUnique()) {
for (int j : c.getArgList()) {
RexNode isUniqueCondition =
b.isNotDistinctFrom(
b.field(registrar.field(j, c.filterArg)),
b.field(registrar.field(j, c.filterArg) + 1));
if (groupFilter != null) {
isUniqueCondition = b.or(b.not(groupFilter), isUniqueCondition);
}
String message = "more than one distinct value in agg UNIQUE_VALUE";
filters.add(
b.call(SqlInternalOperators.THROW_UNLESS,
b.or(b.not(groupFilter),
b.isNotDistinctFrom(b.field(registrar.field(j)),
b.field(registrar.field(j) + 1))),
b.literal(message)));
b.call(SqlInternalOperators.THROW_UNLESS, isUniqueCondition, b.literal(message)));
}
}
}
aggCalls.add(aggCall.filter(b.and(filters)));
if (filters.size() > 0) {
aggCall = aggCall.filter(b.and(filters));
}
aggCalls.add(aggCall);
});

b.aggregate(
Expand All @@ -342,6 +455,22 @@ int getAgg(int i) {
call.transformTo(b.build());
}

/** Returns whether the inner query needs an extra filtered {@code COUNT}
* for an aggregate call, so that the outer aggregate can add a
* "count &gt; 0" filter and skip groups whose rows were all discarded by the
* call's filter.
*
* <p>The current implementation is deliberately conservative: it returns
* {@code true} for every aggregate call that has a filter.
*
* @param aggCall Aggregate call from the original aggregation
* @return Whether {@code aggCall} has a filter
*/
private static boolean mustBeCounted(AggregateCall aggCall) {
// Always count filtered inner aggregates to be safe.
//
// For some aggregate calls (namely, those that completely ignore null
// inputs) it would be possible to skip both steps: counting the
// grouped-and-filtered rows in the inner aggregate, and filtering the empty
// groups out of the outer aggregate. Those empty groups would produce null
// values as the result of MIN and thus be ignored by the outer aggregate
// anyway.
//
// Note that using "aggCall.ignoreNulls()" is not sufficient to determine
// when it's safe to do this, since for COUNT the value of ignoreNulls()
// should generally be true even though COUNT(*) will never ignore anything.
return aggCall.hasFilter();
}

/** Converts a {@code DISTINCT} aggregate call into an equivalent one with
* {@code WITHIN DISTINCT}.
*
Expand All @@ -362,6 +491,7 @@ private static AggregateCall unDistinct(AggregateCall aggregateCall,
.stream()
.filter(i ->
aggregateCall.getAggregation().getKind() != SqlKind.COUNT
|| aggregateCall.hasFilter()
|| isNullable.test(i))
.collect(Collectors.toList());
return aggregateCall.withDistinct(false)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/rex/RexSimplify.java
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ private RexNode simplifyGenericNode(RexCall e) {
* Try to find a literal with the given value in the input list.
* The type of the literal must be one of the numeric types.
*/
private int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
private static int findLiteralIndex(List<RexNode> operands, BigDecimal value) {
for (int i = 0; i < operands.size(); i++) {
if (operands.get(i).isA(SqlKind.LITERAL)) {
Comparable comparable = ((RexLiteral) operands.get(i)).getValue();
Expand Down
16 changes: 10 additions & 6 deletions core/src/main/java/org/apache/calcite/tools/RelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,11 @@ public RexNode equals(RexNode operand0, RexNode operand1) {
return call(SqlStdOperatorTable.EQUALS, operand0, operand1);
}

/** Creates a {@code >}, that is, a call to
* {@code SqlStdOperatorTable.GREATER_THAN} with {@code operand0} as the
* left operand and {@code operand1} as the right operand. */
public RexNode greaterThan(RexNode operand0, RexNode operand1) {
return call(SqlStdOperatorTable.GREATER_THAN, operand0, operand1);
}

/** Creates a {@code <>}. */
public RexNode notEquals(RexNode operand0, RexNode operand1) {
return call(SqlStdOperatorTable.NOT_EQUALS, operand0, operand1);
Expand Down Expand Up @@ -3543,13 +3548,12 @@ private class AggCallImpl implements AggCallPlus {
if (distinct) {
b.append("DISTINCT ");
}
final int iMax = operands.size() - 1;
for (int i = 0; ; i++) {
b.append(operands.get(i));
if (i == iMax) {
break;
if (operands.size() > 0) {
b.append(operands.get(0));
for (int i = 1; i < operands.size(); i++) {
b.append(", ");
b.append(operands.get(i));
}
b.append(", ");
}
b.append(')');
if (filter != null) {
Expand Down
Loading

0 comments on commit 8c46299

Please sign in to comment.