Skip to content

Commit

Permalink
Allow aliasing of SQL operators and macros
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavbhole committed Sep 25, 2023
1 parent 75af741 commit 5dcbc0e
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,28 @@
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

public class BuiltInExprMacros
{
public static class ComplexDecodeBase64ExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "complex_decode_base64";
public static final String ALIAS_NAME = "decode_base64_complex";

/**
 * Primary (canonical) name of this macro: {@code complex_decode_base64}.
 * This is the key under which the macro is registered in {@code ExprMacroTable}.
 */
@Override
public String name()
{
return NAME;
}

/**
 * Alias under which this macro is additionally registered:
 * {@code decode_base64_complex}. Both names resolve to the same macro instance.
 */
@Override
public Optional<String> alias()
{
return Optional.of(ALIAS_NAME);
}

@Override
public Expr apply(List<Expr> args)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Mechanism by which Druid expressions can define new functions for the Druid expression language. When
Expand All @@ -53,9 +52,19 @@ public class ExprMacroTable

/**
 * Builds the macro lookup table. Each macro is registered under its lowercase
 * name and, when it declares one, its lowercase alias. Built-ins are registered
 * first, so user-supplied macros with the same (lowercase) key override them.
 *
 * @param macros additional macros to register on top of the built-ins
 */
public ExprMacroTable(final List<ExprMacro> macros)
{
  // "+ 1" leaves headroom for the single built-in alias entry.
  this.macroMap = Maps.newHashMapWithExpectedSize(BUILT_IN.size() + 1 + macros.size());
  registerMacros(BUILT_IN);
  registerMacros(macros);
}

/**
 * Registers each macro under its lowercase name and, when present, its
 * lowercase alias. Later registrations overwrite earlier ones with the same key.
 */
private void registerMacros(final List<ExprMacro> macros)
{
  for (ExprMacro macro : macros) {
    macroMap.put(StringUtils.toLowerCase(macro.name()), macro);
    macro.alias().ifPresent(alias -> macroMap.put(StringUtils.toLowerCase(alias), macro));
  }
}

public static ExprMacroTable nil()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/**
* Common stuff for "named" functions of "functional" expressions, such as {@link FunctionExpr},
Expand All @@ -38,6 +39,14 @@ public interface NamedFunction
*/
String name();

/**
 * Optional alias for this function. When present, registries such as
 * {@code ExprMacroTable} make the function reachable under both its
 * {@link #name()} and this alias. Defaults to empty (no alias).
 */
default Optional<String> alias()
{
return Optional.empty();
}

/**
* Helper method for creating a {@link ExpressionValidationException} with the specified reason
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,15 @@ public void testComplexDecode()
),
expected
);
// test with alias
assertExpr(
StringUtils.format(
"decode_base64_complex('%s', '%s')",
TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
StringUtils.encodeBase64String(bytes)
),
expected
);
}

@Test
Expand All @@ -964,6 +973,13 @@ public void testComplexDecodeNull()
),
null
);
assertExpr(
StringUtils.format(
"decode_base64_complex('%s', null)",
TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName()
),
null
);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.apache.calcite.util.Static;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.PostAggregator;
import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
Expand Down Expand Up @@ -565,6 +566,38 @@ public T build()
);
}

/**
 * Builds two {@link SqlFunction}s from this builder: the primary one under
 * {@code name} and an otherwise-identical one under {@code alias}.
 *
 * @param alias alternative function name; must be non-null and differ from {@code name}
 * @return pair of (primary function, alias function)
 */
@SuppressWarnings("unchecked")
public Pair<T, T> buildWithAlias(String alias)
{
  Preconditions.checkNotNull(alias, "Function Alias");
  Preconditions.checkArgument(!alias.equals(name), "Function alias can not equal to name");
  final IntSet nullableOperands = buildNullableOperands();
  final SqlOperandTypeInference operandTypeInference = buildOperandTypeInference(nullableOperands);
  final SqlOperandTypeChecker sqlOperandTypeChecker = buildOperandTypeChecker(nullableOperands);
  Preconditions.checkNotNull(returnTypeInference, "returnTypeInference");
  // The two functions share every attribute except the name they answer to.
  final java.util.function.Function<String, SqlFunction> makeFunction = functionName ->
      new SqlFunction(
          functionName,
          kind,
          returnTypeInference,
          operandTypeInference,
          sqlOperandTypeChecker,
          functionCategory
      );
  return Pair.of((T) makeFunction.apply(name), (T) makeFunction.apply(alias));
}

protected IntSet buildNullableOperands()
{
// Create "nullableOperands" set including all optional arguments.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;

import javax.annotation.Nullable;
import java.util.Optional;

public interface SqlOperatorConversion
{
Expand All @@ -38,6 +39,16 @@ public interface SqlOperatorConversion
*/
SqlOperator calciteOperator();

/**
 * Returns the alias SQL operator corresponding to this function, if one exists.
 * Like {@link #calciteOperator()}, the returned operator should be a singleton.
 * Defaults to empty, i.e. this conversion exposes no alias.
 *
 * @return optional alias operator
 */
default Optional<SqlOperator> aliasCalciteOperator()
{
// Most conversions expose a single operator; aliased ones override this.
return Optional.empty();
}

/**
* Translate a Calcite {@code RexNode} to a Druid expression.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.BuiltInExprMacros;
import org.apache.druid.segment.column.ColumnType;
Expand All @@ -37,6 +38,7 @@
import org.apache.druid.sql.calcite.table.RowSignatures;

import javax.annotation.Nullable;
import java.util.Optional;

public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConversion
{
Expand All @@ -50,7 +52,7 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers
);
};

private static final SqlFunction SQL_FUNCTION = OperatorConversions
private static final Pair<SqlFunction, SqlFunction> SQL_FUNCTION_PAIR = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME))
.operandTypeChecker(
OperandTypes.sequence(
Expand All @@ -61,13 +63,19 @@ public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConvers
)
.returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
.buildWithAlias(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.ALIAS_NAME));


@Override
public SqlOperator calciteOperator()
{
return SQL_FUNCTION;
return SQL_FUNCTION_PAIR.lhs;
}

/**
 * Returns the alias operator (DECODE_BASE64_COMPLEX); {@code rhs} of the pair
 * is the function built under the alias name by {@code buildWithAlias}.
 */
@Override
public Optional<SqlOperator> aliasCalciteOperator()
{
return Optional.of(SQL_FUNCTION_PAIR.rhs);
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,13 @@ public DruidOperatorTable(
|| this.operatorConversions.put(operatorKey, operatorConversion) != null) {
throw new ISE("Cannot have two operators with key [%s]", operatorKey);
}
if (operatorConversion.aliasCalciteOperator().isPresent()) {
final OperatorKey aliasOperatorKey = OperatorKey.of(operatorConversion.aliasCalciteOperator().get());
if (this.aggregators.containsKey(aliasOperatorKey)
|| this.operatorConversions.put(aliasOperatorKey, operatorConversion) != null) {
throw new ISE("Cannot have two operators with alias key [%s]", aliasOperatorKey);
}
}
}

for (SqlOperatorConversion operatorConversion : STANDARD_OPERATOR_CONVERSIONS) {
Expand All @@ -456,6 +463,15 @@ public DruidOperatorTable(
}

this.operatorConversions.putIfAbsent(operatorKey, operatorConversion);

if (operatorConversion.aliasCalciteOperator().isPresent()) {
final OperatorKey aliasOperatorKey = OperatorKey.of(operatorConversion.aliasCalciteOperator().get());
// Don't complain if the alias already exists; we allow standard operators alias to be overridden as well.
if (this.aggregators.containsKey(aliasOperatorKey)) {
continue;
}
this.operatorConversions.put(aliasOperatorKey, operatorConversion);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.druid.java.util.common.HumanReadableBytes;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.JodaUtils;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.PeriodGranularity;
Expand Down Expand Up @@ -14368,36 +14369,40 @@ public void testTimeseriesQueryWithEmptyInlineDatasourceAndGranularity()
public void testComplexDecode()
{
cannotVectorize();
testQuery(
"SELECT COMPLEX_DECODE_BASE64('hyperUnique',PARSE_JSON(TO_JSON_STRING(unique_dim1))) from druid.foo LIMIT 10",
ImmutableList.of(
Druids.newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("v0")
.virtualColumns(
expressionVirtualColumn(
"v0",
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
ColumnType.ofComplex("hyperUnique")
)
)
.resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.legacy(false)
.limit(10)
.build()
),
ImmutableList.of(
new Object[]{"\"AQAAAEAAAA==\""},
new Object[]{"\"AQAAAQAAAAHNBA==\""},
new Object[]{"\"AQAAAQAAAAOzAg==\""},
new Object[]{"\"AQAAAQAAAAFREA==\""},
new Object[]{"\"AQAAAQAAAACyEA==\""},
new Object[]{"\"AQAAAQAAAAEkAQ==\""}
)
);
for (String complexDecode : Arrays.asList("COMPLEX_DECODE_BASE64", "DECODE_BASE64_COMPLEX")) {
testQuery(
StringUtils.format(
"SELECT %s('hyperUnique',PARSE_JSON(TO_JSON_STRING(unique_dim1))) from druid.foo LIMIT 10",
complexDecode
),
ImmutableList.of(
Druids.newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("v0")
.virtualColumns(
expressionVirtualColumn(
"v0",
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
ColumnType.ofComplex("hyperUnique")
)
)
.resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.legacy(false)
.limit(10)
.build()
),
ImmutableList.of(
new Object[]{"\"AQAAAEAAAA==\""},
new Object[]{"\"AQAAAQAAAAHNBA==\""},
new Object[]{"\"AQAAAQAAAAOzAg==\""},
new Object[]{"\"AQAAAQAAAAFREA==\""},
new Object[]{"\"AQAAAQAAAACyEA==\""},
new Object[]{"\"AQAAAQAAAAEkAQ==\""}
)
);
}
}

@Test
public void testComplexDecodeAgg()
{
Expand Down

0 comments on commit 5dcbc0e

Please sign in to comment.