
Preemptive restriction for queries with approximate count distinct on complex columns of unsupported type (#16682)

This PR checks whether the complex column being queried matches the types supported by the aggregator and aggregator factories, and throws a user-friendly error message if it does not.
Akshat-Jain authored Jul 22, 2024
1 parent 149d7c5 commit 6a2348b
Showing 15 changed files with 516 additions and 50 deletions.
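
The change follows the same pattern in every affected aggregator factory: resolve the input column's capabilities, and if the column is a complex type the aggregator cannot consume, fail fast with an error aimed at the query issuer instead of failing later with a less helpful message. The following is a minimal, self-contained sketch of that pattern; all class and type names are simplified stand-ins, not the Druid classes used in the diff below.

// Minimal sketch of the validation pattern; names are illustrative stand-ins only.
import java.util.Set;

public class PreemptiveValidationSketch
{
  // Complex type names this hypothetical HLL aggregator can consume.
  private static final Set<String> SUPPORTED = Set.of("HLLSketch", "HLLSketchBuild", "HLLSketchMerge");

  /**
   * Mirrors the shape of the new validateInputs() methods: called from every
   * factorize variant before any column selector is created.
   *
   * @param complexTypeName complex type of the input column, or null if unknown
   */
  static void validateInputs(String complexTypeName)
  {
    if (complexTypeName != null && !SUPPORTED.contains(complexTypeName)) {
      // The real code throws a DruidException for the USER persona with the
      // UNSUPPORTED category, so the message reaches the query issuer intact.
      throw new IllegalArgumentException(
          "Using aggregator [HLLSketchMerge] is not supported for complex columns with type ["
          + complexTypeName + "]."
      );
    }
  }

  public static void main(String[] args)
  {
    validateInputs("HLLSketchBuild"); // supported: returns normally
    validateInputs("thetaSketch");    // unsupported here: throws with a readable message
  }
}

The diff applies this check to the HLL merge aggregator factory and the theta sketch aggregator factory, and mirrors it at SQL validation and planning time.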
@@ -24,6 +24,7 @@
import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.hll.TgtHllType;
import org.apache.datasketches.hll.Union;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringEncoding;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
@@ -34,7 +35,9 @@
import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;

import javax.annotation.Nullable;
@@ -107,6 +110,8 @@ protected byte getCacheTypeId()
@Override
public Aggregator factorize(final ColumnSelectorFactory columnSelectorFactory)
{
validateInputs(columnSelectorFactory.getColumnCapabilities(getFieldName()));

final ColumnValueSelector<HllSketchHolder> selector = columnSelectorFactory.makeColumnValueSelector(getFieldName());
return new HllSketchMergeAggregator(selector, getLgK(), TgtHllType.valueOf(getTgtHllType()));
}
@@ -115,6 +120,8 @@ public Aggregator factorize(final ColumnSelectorFactory columnSelectorFactory)
@Override
public BufferAggregator factorizeBuffered(final ColumnSelectorFactory columnSelectorFactory)
{
validateInputs(columnSelectorFactory.getColumnCapabilities(getFieldName()));

final ColumnValueSelector<HllSketchHolder> selector = columnSelectorFactory.makeColumnValueSelector(getFieldName());
return new HllSketchMergeBufferAggregator(
selector,
@@ -133,6 +140,7 @@ public boolean canVectorize(ColumnInspector columnInspector)
@Override
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{
validateInputs(selectorFactory.getColumnCapabilities(getFieldName()));
return new HllSketchMergeVectorAggregator(
selectorFactory,
getFieldName(),
@@ -142,6 +150,34 @@ public VectorAggregator factorizeVector(VectorColumnSelectorFact
);
}

/**
* Validates whether the aggregator supports the input column type.
* Supported column types are the complex types HLLSketch, HLLSketchBuild, and HLLSketchMerge, as well as UNKNOWN_COMPLEX.
*
* @param capabilities capabilities of the input column, or null if they are unknown
*/
private void validateInputs(@Nullable ColumnCapabilities capabilities)
{
if (capabilities != null) {
final ColumnType type = capabilities.toColumnType();
boolean isSupportedComplexType = ValueType.COMPLEX.equals(type.getType()) &&
(
HllSketchModule.TYPE_NAME.equals(type.getComplexTypeName()) ||
HllSketchModule.BUILD_TYPE_NAME.equals(type.getComplexTypeName()) ||
HllSketchModule.MERGE_TYPE_NAME.equals(type.getComplexTypeName()) ||
type.getComplexTypeName() == null
);
if (!isSupportedComplexType) {
throw DruidException.forPersona(DruidException.Persona.USER)
.ofCategory(DruidException.Category.UNSUPPORTED)
.build(
"Using aggregator [%s] is not supported for complex columns with type [%s].",
getIntermediateType().getComplexTypeName(),
type
);
}
}
}

@Override
public int getMaxIntermediateSize()
{
@@ -21,28 +21,68 @@

import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringEncoding;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchBuildAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggregatorFactory;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.table.RowSignatures;

import java.util.Collections;

/**
* Approximate count distinct aggregator using HLL sketches.
* Supported column types: String, Numeric, HLLSketchMerge, HLLSketchBuild.
*/
public class HllSketchApproxCountDistinctSqlAggregator extends HllSketchBaseSqlAggregator implements SqlAggregator
{
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_HLL";

private static final SqlSingleOperandTypeChecker AGGREGATED_COLUMN_TYPE_CHECKER = OperandTypes.or(
OperandTypes.STRING,
OperandTypes.NUMERIC,
RowSignatures.complexTypeChecker(HllSketchMergeAggregatorFactory.TYPE),
RowSignatures.complexTypeChecker(HllSketchBuildAggregatorFactory.TYPE)
);

private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "lgK", "tgtHllType")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.STRING)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1, 2)
.operandTypeChecker(
OperandTypes.or(
// APPROX_COUNT_DISTINCT_DS_HLL(column)
AGGREGATED_COLUMN_TYPE_CHECKER,
// APPROX_COUNT_DISTINCT_DS_HLL(column, lgk)
OperandTypes.and(
OperandTypes.sequence(
StringUtils.format("'%s(column, lgk)'", NAME),
AGGREGATED_COLUMN_TYPE_CHECKER,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
),
// APPROX_COUNT_DISTINCT_DS_HLL(column, lgk, tgtHllType)
OperandTypes.and(
OperandTypes.sequence(
StringUtils.format("'%s(column, lgk, tgtHllType)'", NAME),
AGGREGATED_COLUMN_TYPE_CHECKER,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL,
OperandTypes.STRING
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.STRING)
)
)
)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.NUMERIC)
.build();
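
For reference, the operand type checkers defined above gate which call shapes the SQL validator accepts for APPROX_COUNT_DISTINCT_DS_HLL. The following is a hedged illustration only; the column names are made up, with user_hll assumed to be an HLLSketchBuild complex column and ip_json a complex column of some other type.

// Illustrative only: call shapes accepted or rejected by the operand type checkers above.
public class HllApproxCountDistinctCallShapes
{
  public static void main(String[] args)
  {
    String ok1 = "SELECT APPROX_COUNT_DISTINCT_DS_HLL(user_id) FROM events";                // string or numeric column
    String ok2 = "SELECT APPROX_COUNT_DISTINCT_DS_HLL(user_hll, 12) FROM events";           // lgK must be a positive integer literal
    String ok3 = "SELECT APPROX_COUNT_DISTINCT_DS_HLL(user_hll, 12, 'HLL_8') FROM events";  // tgtHllType as a string literal
    String bad = "SELECT APPROX_COUNT_DISTINCT_DS_HLL(ip_json) FROM events";                // unsupported complex column: rejected at validation
    System.out.println(String.join(System.lineSeparator(), ok1, ok2, ok3, bad));
  }
}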
@@ -31,6 +31,7 @@
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchBuildAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchMergeAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.hll.HllSketchModule;
import org.apache.druid.query.dimension.DefaultDimensionSpec;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.segment.column.ColumnType;
@@ -40,6 +41,7 @@
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
@@ -115,7 +117,7 @@ public Aggregation toDruidAggregation(
if (columnArg.isDirectColumnAccess()
&& inputAccessor.getInputRowSignature()
.getColumnType(columnArg.getDirectColumn())
.map(type -> type.is(ValueType.COMPLEX))
.map(this::isValidComplexInputType)
.orElse(false)) {
aggregatorFactory = new HllSketchMergeAggregatorFactory(
aggregatorName,
@@ -154,6 +156,15 @@ public Aggregation toDruidAggregation(
}

if (inputType.is(ValueType.COMPLEX)) {
if (!isValidComplexInputType(inputType)) {
plannerContext.setPlanningError(
"Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not supported for"
+ " column type [%s]. You can disable approximation by setting [%s: false] in the query context.",
columnArg.getDruidType(),
PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT
);
return null;
}
aggregatorFactory = new HllSketchMergeAggregatorFactory(
aggregatorName,
dimensionSpec.getOutputName(),
@@ -192,4 +203,11 @@ protected abstract Aggregation toAggregation(
boolean finalizeAggregations,
AggregatorFactory aggregatorFactory
);

private boolean isValidComplexInputType(ColumnType columnType)
{
return HllSketchMergeAggregatorFactory.TYPE.equals(columnType) ||
HllSketchModule.TYPE_NAME.equals(columnType.getComplexTypeName()) ||
HllSketchModule.BUILD_TYPE_NAME.equals(columnType.getComplexTypeName());
}
}
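
The planner-side guard added in this base class is what surfaces the friendlier message for plain COUNT(DISTINCT) and APPROX_COUNT_DISTINCT when approximation is enabled: an unsupported complex column now produces a planning error naming the column type and the relevant context flag instead of planning an HLL merge aggregator that would fail later. Below is a hedged illustration of the user-visible behavior; the query, the column name, and the context key literal are assumptions (the key is taken to be what PlannerConfig.CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT resolves to).

// Illustrative only: how the new planning error surfaces to a SQL user.
import java.util.Map;

public class CountDistinctOnComplexColumn
{
  public static void main(String[] args)
  {
    String sql = "SELECT COUNT(DISTINCT ip_json) FROM events"; // ip_json: complex column the HLL aggregator cannot consume

    // With approximate count distinct enabled, planning now fails with a message of the form:
    //   "Using APPROX_COUNT_DISTINCT() or enabling approximation with COUNT(DISTINCT) is not
    //    supported for column type [COMPLEX<json>]. You can disable approximation by setting
    //    [useApproximateCountDistinct: false] in the query context."
    Map<String, Object> approximate = Map.of("useApproximateCountDistinct", true);

    // Following that hint falls back to exact COUNT(DISTINCT), which does not use the sketch aggregator.
    Map<String, Object> exact = Map.of("useApproximateCountDistinct", false);

    System.out.println(sql + "\napproximate context: " + approximate + "\nexact context: " + exact);
  }
}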
@@ -27,7 +27,7 @@
import org.apache.datasketches.theta.SetOperation;
import org.apache.datasketches.theta.Union;
import org.apache.datasketches.thetacommon.ThetaUtil;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
@@ -41,6 +41,7 @@
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.vector.VectorColumnSelectorFactory;

import javax.annotation.Nullable;
@@ -80,21 +81,15 @@ public SketchAggregatorFactory(String name, String fieldName, Integer size, byte
@Override
public Aggregator factorize(ColumnSelectorFactory metricFactory)
{
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
if (capabilities != null && capabilities.isArray()) {
throw InvalidInput.exception("ARRAY types are not supported for theta sketch");
}
validateInputs(metricFactory.getColumnCapabilities(fieldName));
BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName);
return new SketchAggregator(selector, size);
}

@Override
public AggregatorAndSize factorizeWithSize(ColumnSelectorFactory metricFactory)
{
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
if (capabilities != null && capabilities.isArray()) {
throw InvalidInput.exception("ARRAY types are not supported for theta sketch");
}
validateInputs(metricFactory.getColumnCapabilities(fieldName));
BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName);
final SketchAggregator aggregator = new SketchAggregator(selector, size);
return new AggregatorAndSize(aggregator, aggregator.getInitialSizeBytes());
@@ -104,20 +99,49 @@ public AggregatorAndSize factorizeWithSize(ColumnSelectorFactory metricFactory)
@Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{
ColumnCapabilities capabilities = metricFactory.getColumnCapabilities(fieldName);
if (capabilities != null && capabilities.isArray()) {
throw InvalidInput.exception("ARRAY types are not supported for theta sketch");
}
validateInputs(metricFactory.getColumnCapabilities(fieldName));
BaseObjectColumnValueSelector selector = metricFactory.makeColumnValueSelector(fieldName);
return new SketchBufferAggregator(selector, size, getMaxIntermediateSizeWithNulls());
}

@Override
public VectorAggregator factorizeVector(VectorColumnSelectorFactory selectorFactory)
{
validateInputs(selectorFactory.getColumnCapabilities(fieldName));
return new SketchVectorAggregator(selectorFactory, fieldName, size, getMaxIntermediateSizeWithNulls());
}

/**
* Validates whether the aggregator supports the input column type.
* Unsupported column types are:
* <ul>
* <li>Arrays</li>
* <li>Complex types other than thetaSketch, thetaSketchMerge, and thetaSketchBuild</li>
* </ul>
*
* @param capabilities capabilities of the input column, or null if they are unknown
*/
private void validateInputs(@Nullable ColumnCapabilities capabilities)
{
if (capabilities != null) {
boolean isUnsupportedComplexType = capabilities.is(ValueType.COMPLEX) && !(
SketchModule.THETA_SKETCH_TYPE.equals(capabilities.toColumnType()) ||
SketchModule.MERGE_TYPE.equals(capabilities.toColumnType()) ||
SketchModule.BUILD_TYPE.equals(capabilities.toColumnType())
);

if (capabilities.isArray() || isUnsupportedComplexType) {
throw DruidException.forPersona(DruidException.Persona.USER)
.ofCategory(DruidException.Category.UNSUPPORTED)
.build(
"Unsupported input [%s] of type [%s] for aggregator [%s].",
getFieldName(),
capabilities.asTypeString(),
getIntermediateType()
);
}
}
}

@Override
public boolean canVectorize(ColumnInspector columnInspector)
{
@@ -21,27 +21,55 @@

import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.type.CastedLiteralOperandTypeCheckers;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.theta.SketchModule;
import org.apache.druid.query.aggregation.post.FinalizingFieldAccessPostAggregator;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.table.RowSignatures;

import java.util.Collections;

/**
* Approximate count distinct aggregator using theta sketches.
* Supported column types: String, Numeric, Theta Sketch.
*/
public class ThetaSketchApproxCountDistinctSqlAggregator extends ThetaSketchBaseSqlAggregator implements SqlAggregator
{
public static final String NAME = "APPROX_COUNT_DISTINCT_DS_THETA";

private static final SqlSingleOperandTypeChecker AGGREGATED_COLUMN_TYPE_CHECKER = OperandTypes.or(
OperandTypes.STRING,
OperandTypes.NUMERIC,
RowSignatures.complexTypeChecker(SketchModule.THETA_SKETCH_TYPE)
);

private static final SqlAggFunction FUNCTION_INSTANCE =
OperatorConversions.aggregatorBuilder(NAME)
.operandNames("column", "size")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.operandTypeInference(InferTypes.VARCHAR_1024)
.requiredOperandCount(1)
.literalOperands(1)
.operandTypeChecker(
OperandTypes.or(
// APPROX_COUNT_DISTINCT_DS_THETA(expr)
AGGREGATED_COLUMN_TYPE_CHECKER,
// APPROX_COUNT_DISTINCT_DS_THETA(expr, size)
OperandTypes.and(
OperandTypes.sequence(
StringUtils.format("'%s(expr, size)'", NAME),
AGGREGATED_COLUMN_TYPE_CHECKER,
CastedLiteralOperandTypeCheckers.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
)
)
)
.returnTypeNonNull(SqlTypeName.BIGINT)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
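
The theta variant mirrors the HLL operand checks above with a single size parameter. A short hedged illustration follows, analogous to the HLL example earlier; the column names are made up, with user_theta assumed to be a thetaSketch complex column.

// Illustrative only: APPROX_COUNT_DISTINCT_DS_THETA call shapes under the checkers above.
public class ThetaApproxCountDistinctCallShapes
{
  public static void main(String[] args)
  {
    String ok1 = "SELECT APPROX_COUNT_DISTINCT_DS_THETA(user_id) FROM events";           // string or numeric column
    String ok2 = "SELECT APPROX_COUNT_DISTINCT_DS_THETA(user_theta, 4096) FROM events";  // size must be a positive integer literal
    String bad = "SELECT APPROX_COUNT_DISTINCT_DS_THETA(ip_json) FROM events";           // unsupported complex column: rejected at validation
    System.out.println(String.join(System.lineSeparator(), ok1, ok2, bad));
  }
}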

