Skip to content

Commit

Permalink
Allow casted literal values in SQL functions accepting literals (apache#15282)
Browse files Browse the repository at this point in the history

Functions that accept literals also allow casted literals. This shouldn't have an impact on the queries that the user writes. It enables the SQL functions to accept explicit cast, which is required with JDBC.
  • Loading branch information
LakshSingla authored Nov 1, 2023
1 parent 49e0cba commit 2ea7177
Show file tree
Hide file tree
Showing 23 changed files with 540 additions and 448 deletions.
1 change: 1 addition & 0 deletions codestyle/druid-forbidden-apis.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ java.util.LinkedList @ Use ArrayList or ArrayDeque instead
java.util.Random#<init>() @ Use ThreadLocalRandom.current() or the constructor with a seed (the latter in tests only!)
java.lang.Math#random() @ Use ThreadLocalRandom.current()
java.util.regex.Pattern#matches(java.lang.String,java.lang.CharSequence) @ Use String.startsWith(), endsWith(), contains(), or compile and cache a Pattern explicitly
org.apache.calcite.sql.type.OperandTypes#LITERAL @ LITERAL type checker throws when literals with CAST are passed. Use org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker instead.
org.apache.commons.io.FileUtils#getTempDirectory() @ Use org.junit.rules.TemporaryFolder for tests instead
org.apache.commons.io.FileUtils#deleteDirectory(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils#deleteDirectory()
org.apache.commons.io.FileUtils#forceMkdir(java.io.File) @ Use org.apache.druid.java.util.common.FileUtils.mkdirp instead
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.tdigestsketch.TDigestSketchAggregatorFactory;
Expand All @@ -37,6 +37,7 @@
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
Expand Down Expand Up @@ -133,8 +134,6 @@ public Aggregation toDruidAggregation(

private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE_WITH_COMPRESSION = "'" + NAME + "(column, compression)'";

TDigestGenerateSketchSqlAggFunction()
{
super(
Expand All @@ -143,16 +142,19 @@ private static class TDigestGenerateSketchSqlAggFunction extends SqlAggFunction
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.or(
OperandTypes.ANY,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE_WITH_COMPRESSION, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
)
),
// Validation for signatures like 'TDIGEST_GENERATE_SKETCH(column)' and
// 'TDIGEST_GENERATE_SKETCH(column, compression)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "compression")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.requiredOperandCount(1)
.literalOperands(1)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
Expand All @@ -40,6 +40,7 @@
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.Aggregations;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.InputAccessor;
Expand Down Expand Up @@ -158,9 +159,6 @@ public Aggregation toDruidAggregation(

private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, quantile)'";
private static final String SIGNATURE2 = "'" + NAME + "(column, quantile, compression)'";

TDigestSketchQuantileSqlAggFunction()
{
super(
Expand All @@ -169,19 +167,18 @@ private static class TDigestSketchQuantileSqlAggFunction extends SqlAggFunction
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.DOUBLE),
null,
OperandTypes.or(
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.LITERAL, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
)
),
// Accounts for both 'TDIGEST_QUANTILE(column, quantile)' and 'TDIGEST_QUANTILE(column, quantile, compression)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "quantile", "compression")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)
.literalOperands(1, 2)
.requiredOperandCount(2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,76 @@ public void testComputingSketchOnNumericValues()
);
}

@Test
public void testCastedQuantileAndCompressionParamForTDigestQuantileAgg()
{
cannotVectorize();
testQuery(
"SELECT\n"
+ "TDIGEST_QUANTILE(m1, CAST(0.0 AS DOUBLE)), "
+ "TDIGEST_QUANTILE(m1, CAST(0.5 AS FLOAT), CAST(200 AS INTEGER)), "
+ "TDIGEST_QUANTILE(m1, CAST(1.0 AS DOUBLE), 300)\n"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1",
TDigestSketchAggregatorFactory.DEFAULT_COMPRESSION
),
new TDigestSketchAggregatorFactory("a1:agg", "m1",
200
),
new TDigestSketchAggregatorFactory("a2:agg", "m1",
300
)
))
.postAggregators(
new TDigestSketchToQuantilePostAggregator("a0", makeFieldAccessPostAgg("a0:agg"), 0.0f),
new TDigestSketchToQuantilePostAggregator("a1", makeFieldAccessPostAgg("a1:agg"), 0.5f),
new TDigestSketchToQuantilePostAggregator("a2", makeFieldAccessPostAgg("a2:agg"), 1.0f)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ResultMatchMode.EQUALS_EPS,
ImmutableList.of(
new Object[]{1.0, 3.5, 6.0}
)
);
}

@Test
public void testComputingSketchOnNumericValuesWithCastedCompressionParameter()
{
cannotVectorize();

testQuery(
"SELECT\n"
+ "TDIGEST_GENERATE_SKETCH(m1, CAST(200 AS INTEGER))"
+ "FROM foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(ImmutableList.of(
new TDigestSketchAggregatorFactory("a0:agg", "m1", 200)
))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ResultMatchMode.EQUALS_EPS,
ImmutableList.of(
new String[]{
"\"AAAAAT/wAAAAAAAAQBgAAAAAAABAaQAAAAAAAAAAAAY/8AAAAAAAAD/wAAAAAAAAP/AAAAAAAABAAAAAAAAAAD/wAAAAAAAAQAgAAAAAAAA/8AAAAAAAAEAQAAAAAAAAP/AAAAAAAABAFAAAAAAAAD/wAAAAAAAAQBgAAAAAAAA=\""
}
)
);
}

@Test
public void testComputingSketchOnCastedString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.BasicOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.PostAggregatorVisitor;
Expand Down Expand Up @@ -143,8 +142,8 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
final RelDataType operandType = callBinding.getValidator().deriveType(callBinding.getScope(), operand);

// Verify that 'operand' is a literal number.
if (!SqlUtil.isLiteral(operand)) {
return BasicOperandTypeChecker.throwOrReturn(
if (!SqlUtil.isLiteral(operand, true)) {
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
cb -> cb.getValidator()
Expand All @@ -156,7 +155,7 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
}

if (!SqlTypeFamily.NUMERIC.contains(operandType)) {
return BasicOperandTypeChecker.throwOrReturn(
return OperatorConversions.throwOrReturn(
throwOnFailure,
callBinding,
SqlCallBinding::newValidationSignatureError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ public void testDoublesSketchPostAggs()
+ " DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt + 123), 0.5) + 1000,\n"
+ " ABS(DS_GET_QUANTILE(DS_QUANTILES_SKETCH(cnt), 0.5)),\n"
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), 0.5, 0.8),\n"
+ " DS_GET_QUANTILES(DS_QUANTILES_SKETCH(cnt), CAST(0.5 AS DOUBLE), CAST(0.8 AS DOUBLE)),\n"
+ " DS_HISTOGRAM(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
+ " DS_RANK(DS_QUANTILES_SKETCH(cnt), 3),\n"
+ " DS_CDF(DS_QUANTILES_SKETCH(cnt), 0.2, 0.6),\n"
Expand Down Expand Up @@ -588,41 +589,49 @@ public void testDoublesSketchPostAggs()
),
new double[]{0.5d, 0.8d}
),
new DoublesSketchToHistogramPostAggregator(
new DoublesSketchToQuantilesPostAggregator(
"p13",
new FieldAccessPostAggregator(
"p12",
"a2:agg"
),
new double[]{0.5d, 0.8d}
),
new DoublesSketchToHistogramPostAggregator(
"p15",
new FieldAccessPostAggregator(
"p14",
"a2:agg"
),
new double[]{0.2d, 0.6d},
null
),
new DoublesSketchToRankPostAggregator(
"p15",
"p17",
new FieldAccessPostAggregator(
"p14",
"p16",
"a2:agg"
),
3.0d
),
new DoublesSketchToCDFPostAggregator(
"p17",
"p19",
new FieldAccessPostAggregator(
"p16",
"p18",
"a2:agg"
),
new double[]{0.2d, 0.6d}
),
new DoublesSketchToStringPostAggregator(
"p19",
"p21",
new FieldAccessPostAggregator(
"p18",
"p20",
"a2:agg"
)
),
new ExpressionPostAggregator(
"p20",
"replace(replace(\"p19\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
"p22",
"replace(replace(\"p21\",'HeapCompactDoublesSketch','HeapUpdateDoublesSketch'),"
+ "'Combined Buffer Capacity : 6',"
+ "'Combined Buffer Capacity : 8')",
null,
Expand All @@ -640,6 +649,7 @@ public void testDoublesSketchPostAggs()
1124.0d,
1.0d,
"[1.0,1.0]",
"[1.0,1.0]",
"[0.0,0.0,6.0]",
1.0d,
"[0.0,0.0,1.0]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.bloom.BloomFilterAggregatorFactory;
Expand All @@ -38,6 +38,7 @@
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DefaultOperandTypeChecker;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.Calcites;
Expand Down Expand Up @@ -168,8 +169,6 @@ public Aggregation toDruidAggregation(

private static class BloomFilterSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE1 = "'" + NAME + "(column, maxNumEntries)'";

BloomFilterSqlAggFunction()
{
super(
Expand All @@ -178,13 +177,18 @@ private static class BloomFilterSqlAggFunction extends SqlAggFunction
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.and(
OperandTypes.sequence(SIGNATURE1, OperandTypes.ANY, OperandTypes.LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
),
// Allow signatures like 'BLOOM_FILTER(column, maxNumEntries)'
DefaultOperandTypeChecker
.builder()
.operandNames("column", "maxNumEntries")
.operandTypes(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC)
.literalOperands(1)
.requiredOperandCount(2)
.build(),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false
false,
Optionality.FORBIDDEN
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ public void testBloomFilterAgg() throws Exception

testQuery(
"SELECT\n"
+ "BLOOM_FILTER(dim1, 1000)\n"
+ "BLOOM_FILTER(dim1, 1000),\n"
+ "BLOOM_FILTER(dim1, CAST(1000 AS INTEGER))\n"
+ "FROM numfoo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
Expand All @@ -145,7 +146,10 @@ public void testBloomFilterAgg() throws Exception
.build()
),
ImmutableList.of(
new Object[]{queryFramework().queryJsonMapper().writeValueAsString(expected1)}
new Object[]{
queryFramework().queryJsonMapper().writeValueAsString(expected1),
queryFramework().queryJsonMapper().writeValueAsString(expected1)
}
)
);
}
Expand Down
Loading

0 comments on commit 2ea7177

Please sign in to comment.