diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java index 97fef114b086..27e5c19cacf2 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseSingleStageBrokerRequestHandler.java @@ -52,7 +52,6 @@ import org.apache.pinot.broker.querylog.QueryLogger; import org.apache.pinot.broker.queryquota.QueryQuotaManager; import org.apache.pinot.broker.routing.BrokerRoutingManager; -import org.apache.pinot.calcite.jdbc.CalciteSchemaBuilder; import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.http.MultiHttpRequest; @@ -84,7 +83,6 @@ import org.apache.pinot.core.routing.TimeBoundaryInfo; import org.apache.pinot.core.transport.ServerInstance; import org.apache.pinot.core.util.GapfillUtils; -import org.apache.pinot.query.catalog.PinotCatalog; import org.apache.pinot.query.parser.utils.ParserUtils; import org.apache.pinot.spi.auth.AuthorizationResult; import org.apache.pinot.spi.config.table.FieldConfig; @@ -264,8 +262,7 @@ protected BrokerResponse handleRequest(long requestId, String query, @Nullable S // Check if the query is a v2 supported query Map queryOptions = sqlNodeAndOptions.getOptions(); String database = DatabaseUtils.extractDatabaseFromQueryRequest(queryOptions, httpHeaders); - if (ParserUtils.canCompileQueryUsingV2Engine(query, CalciteSchemaBuilder.asRootSchema( - new PinotCatalog(database, _tableCache), database))) { + if (ParserUtils.canCompileWithMultiStageEngine(query, database, _tableCache)) { return new BrokerResponseNative(QueryException.getException(QueryException.SQL_PARSING_ERROR, new Exception( "It seems that the query is only supported by the multi-stage query engine, please retry the query using " + "the multi-stage query engine " @@ -398,8 +395,7 @@ protected BrokerResponse handleRequest(long requestId, String query, @Nullable S if (StringUtils.isNotBlank(failureMessage)) { failureMessage = "Reason: " + failureMessage; } - throw new WebApplicationException("Permission denied." + failureMessage, - Response.Status.FORBIDDEN); + throw new WebApplicationException("Permission denied." 
+ failureMessage, Response.Status.FORBIDDEN); } // Get the tables hit by the request diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java index 3aab671c075b..c8f7c4c2f601 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/MultiStageBrokerRequestHandler.java @@ -39,7 +39,6 @@ import org.apache.pinot.broker.querylog.QueryLogger; import org.apache.pinot.broker.queryquota.QueryQuotaManager; import org.apache.pinot.broker.routing.BrokerRoutingManager; -import org.apache.pinot.calcite.jdbc.CalciteSchemaBuilder; import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.metrics.BrokerMeter; @@ -65,8 +64,6 @@ import org.apache.pinot.query.runtime.MultiStageStatsTreeBuilder; import org.apache.pinot.query.runtime.plan.MultiStageQueryStats; import org.apache.pinot.query.service.dispatch.QueryDispatcher; -import org.apache.pinot.query.type.TypeFactory; -import org.apache.pinot.query.type.TypeSystem; import org.apache.pinot.spi.auth.TableAuthorizationResult; import org.apache.pinot.spi.env.PinotConfiguration; import org.apache.pinot.spi.exception.DatabaseConflictException; @@ -85,6 +82,7 @@ public class MultiStageBrokerRequestHandler extends BaseBrokerRequestHandler { private final WorkerManager _workerManager; private final QueryDispatcher _queryDispatcher; + private final PinotCatalog _catalog; public MultiStageBrokerRequestHandler(PinotConfiguration config, String brokerId, BrokerRoutingManager routingManager, AccessControlFactory accessControlFactory, QueryQuotaManager queryQuotaManager, TableCache tableCache) { @@ -93,6 +91,7 @@ public MultiStageBrokerRequestHandler(PinotConfiguration config, String brokerId int port = Integer.parseInt(config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT)); _workerManager = new WorkerManager(hostname, port, _routingManager); _queryDispatcher = new QueryDispatcher(new MailboxService(hostname, port, config)); + _catalog = new PinotCatalog(tableCache); LOGGER.info("Initialized MultiStageBrokerRequestHandler on host: {}, port: {} with broker id: {}, timeout: {}ms, " + "query log max length: {}, query log max rate: {}", hostname, port, _brokerId, _brokerTimeoutMs, _queryLogger.getMaxQueryLengthToLog(), _queryLogger.getLogRateLimit()); @@ -134,9 +133,7 @@ protected BrokerResponse handleRequest(long requestId, String query, @Nullable S Long timeoutMsFromQueryOption = QueryOptionsUtils.getTimeoutMs(queryOptions); queryTimeoutMs = timeoutMsFromQueryOption != null ? 
timeoutMsFromQueryOption : _brokerTimeoutMs; String database = DatabaseUtils.extractDatabaseFromQueryRequest(queryOptions, httpHeaders); - QueryEnvironment queryEnvironment = new QueryEnvironment(new TypeFactory(new TypeSystem()), - CalciteSchemaBuilder.asRootSchema(new PinotCatalog(database, _tableCache), database), _workerManager, - _tableCache); + QueryEnvironment queryEnvironment = new QueryEnvironment(database, _tableCache, _workerManager); switch (sqlNodeAndOptions.getSqlNode().getKind()) { case EXPLAIN: queryPlanResult = queryEnvironment.explainQuery(query, sqlNodeAndOptions, requestId); diff --git a/pinot-common/pom.xml b/pinot-common/pom.xml index 57959b531c89..2d93549d405a 100644 --- a/pinot-common/pom.xml +++ b/pinot-common/pom.xml @@ -169,6 +169,10 @@ org.apache.calcite calcite-babel + + org.immutables + value-annotations + diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java index 00df9498ddac..a9d6f639b236 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionRegistry.java @@ -21,15 +21,21 @@ import com.google.common.base.Preconditions; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.Arrays; import java.util.HashMap; -import java.util.List; +import java.util.HashSet; import java.util.Map; import java.util.Set; import javax.annotation.Nullable; -import org.apache.calcite.schema.Function; -import org.apache.calcite.schema.impl.ScalarFunctionImpl; -import org.apache.calcite.util.NameMultimap; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.commons.lang3.StringUtils; +import org.apache.pinot.common.function.sql.PinotSqlFunction; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.PinotReflectionUtils; import org.slf4j.Logger; @@ -38,7 +44,28 @@ /** * Registry for scalar functions. - *
TODO: Merge FunctionRegistry and FunctionDefinitionRegistry to provide one single registry for all functions.
+ *
+ * <p>To plug in a class: annotate a public class implementing {@link PinotScalarFunction} with
+ * {@link ScalarFunction}, and put it under a package containing ".function." so that it is picked up by the
+ * reflection-based class scanning.
+ *
+ * <p>To plug in a method: annotate a public method with {@link ScalarFunction}, and put its declaring class under a
+ * package containing ".function.".
+ *
+ * <p>Multiple methods with different numbers of arguments can be registered under the same canonical name. Otherwise,
+ * each canonical name can only be registered once.
+ *
+ * <p>A class implementing {@link PinotScalarFunction} gives finer control over return type inference and operand type
+ * check, and allows polymorphism based on the argument types.
+ *
+ * <p>A method is easier to implement but gives less control. If a different return type inference or operand type
+ * check is desired over the default Java class inference, they can be directly registered into
+ * {@code PinotOperatorTable}.
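+ *
+ * <p>For illustration only (a hypothetical function, not part of this change), a method can be plugged in as:
+ * <pre>{@code
+ * @ScalarFunction
+ * public static String concatWithDash(String a, String b) {
+ *   return a + "-" + b;
+ * }
+ * }</pre>
+ *
+ * <p>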
The package name convention is used to reduce the time of class scanning. */ public class FunctionRegistry { private FunctionRegistry() { @@ -46,21 +73,49 @@ private FunctionRegistry() { private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class); - // TODO: consolidate the following 2 - // This FUNCTION_INFO_MAP is used by Pinot server to look up function by # of arguments - private static final Map> FUNCTION_INFO_MAP = new HashMap<>(); - // This FUNCTION_MAP is used by Calcite function catalog to look up function by function signature. - private static final NameMultimap FUNCTION_MAP = new NameMultimap<>(); + // Key is canonical name + public static final Map FUNCTION_MAP; private static final int VAR_ARG_KEY = -1; - /** - * Registers the scalar functions via reflection. - * NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function." - * in its class path. This convention can significantly reduce the time of class scanning. - */ static { long startTimeMs = System.currentTimeMillis(); + + // Register ScalarFunction classes + Map functionMap = new HashMap<>(); + Set> classes = + PinotReflectionUtils.getClassesThroughReflection(".*\\.function\\..*", ScalarFunction.class); + for (Class clazz : classes) { + if (!Modifier.isPublic(clazz.getModifiers())) { + continue; + } + ScalarFunction scalarFunction = clazz.getAnnotation(ScalarFunction.class); + if (scalarFunction.enabled()) { + PinotScalarFunction function; + try { + function = (PinotScalarFunction) clazz.getConstructor().newInstance(); + } catch (Exception e) { + throw new IllegalStateException("Failed to instantiate PinotScalarFunction with class: " + clazz); + } + String[] names = scalarFunction.names(); + if (names.length == 0) { + register(canonicalize(function.getName()), function, functionMap); + } else { + Set canonicalNames = new HashSet<>(); + for (String name : names) { + if (!canonicalNames.add(canonicalize(name))) { + LOGGER.warn("Duplicate names: {} in class: {}", Arrays.toString(names), clazz); + } + } + for (String canonicalName : canonicalNames) { + register(canonicalName, function, functionMap); + } + } + } + } + + // Register ScalarFunction methods + Map> functionInfoMap = new HashMap<>(); Set methods = PinotReflectionUtils.getMethodsThroughReflection(".*\\.function\\..*", ScalarFunction.class); for (Method method : methods) { if (!Modifier.isPublic(method.getModifiers())) { @@ -68,22 +123,37 @@ private FunctionRegistry() { } ScalarFunction scalarFunction = method.getAnnotation(ScalarFunction.class); if (scalarFunction.enabled()) { - // Annotated function names - String[] scalarFunctionNames = scalarFunction.names(); - boolean nullableParameters = scalarFunction.nullableParameters(); - boolean isPlaceholder = scalarFunction.isPlaceholder(); - boolean isVarArg = scalarFunction.isVarArg(); - if (scalarFunctionNames.length > 0) { - for (String name : scalarFunctionNames) { - FunctionRegistry.registerFunction(name, method, nullableParameters, isPlaceholder, isVarArg); - } + FunctionInfo functionInfo = + new FunctionInfo(method, method.getDeclaringClass(), scalarFunction.nullableParameters()); + int numArguments = scalarFunction.isVarArg() ? 
VAR_ARG_KEY : method.getParameterCount(); + String[] names = scalarFunction.names(); + if (names.length == 0) { + register(canonicalize(method.getName()), functionInfo, numArguments, functionInfoMap); } else { - FunctionRegistry.registerFunction(method, nullableParameters, isPlaceholder, isVarArg); + Set canonicalNames = new HashSet<>(); + for (String name : names) { + if (!canonicalNames.add(canonicalize(name))) { + LOGGER.warn("Duplicate names: {} in method: {}", Arrays.toString(names), method); + } + } + for (String canonicalName : canonicalNames) { + register(canonicalName, functionInfo, numArguments, functionInfoMap); + } } } } - LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_INFO_MAP.size(), - FUNCTION_INFO_MAP.keySet(), System.currentTimeMillis() - startTimeMs); + + // Create PinotScalarFunction for registered methods + for (Map.Entry> entry : functionInfoMap.entrySet()) { + String canonicalName = entry.getKey(); + Preconditions.checkState( + functionMap.put(canonicalName, new ArgumentCountBasedScalarFunction(canonicalName, entry.getValue())) == null, + "Function: %s is already registered", canonicalName); + } + + FUNCTION_MAP = Map.copyOf(functionMap); + LOGGER.info("Initialized FunctionRegistry with {} functions: {} in {}ms", FUNCTION_MAP.size(), + FUNCTION_MAP.keySet(), System.currentTimeMillis() - startTimeMs); } /** @@ -95,108 +165,165 @@ public static void init() { } /** - * Registers a method with the name of the method. + * Registers a {@link PinotScalarFunction} under the given canonical name. */ - public static void registerFunction(Method method, boolean nullableParameters, boolean isPlaceholder, - boolean isVarArg) { - registerFunction(method.getName(), method, nullableParameters, isPlaceholder, isVarArg); + private static void register(String canonicalName, PinotScalarFunction function, + Map functionMap) { + Preconditions.checkState(functionMap.put(canonicalName, function) == null, "Function: %s is already registered", + canonicalName); } /** - * Registers a method with the given function name. + * Registers a {@link FunctionInfo} under the given canonical name. */ - public static void registerFunction(String functionName, Method method, boolean nullableParameters, - boolean isPlaceholder, boolean isVarArg) { - if (!isPlaceholder) { - registerFunctionInfoMap(functionName, method, nullableParameters, isVarArg); - } - registerCalciteNamedFunctionMap(functionName, method, nullableParameters, isVarArg); + private static void register(String canonicalName, FunctionInfo functionInfo, int numArguments, + Map> functionInfoMap) { + Preconditions.checkState( + functionInfoMap.computeIfAbsent(canonicalName, k -> new HashMap<>()).put(numArguments, functionInfo) == null, + "Function: %s with %s arguments is already registered", canonicalName, + numArguments == VAR_ARG_KEY ? 
"variable" : numArguments); } - private static void registerFunctionInfoMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters); - String canonicalName = canonicalize(functionName); - Map functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>()); - if (isVarArg) { - FunctionInfo existFunctionInfo = functionInfoMap.put(VAR_ARG_KEY, functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with variable number of parameters is already registered", functionName); - } else { - FunctionInfo existFunctionInfo = functionInfoMap.put(method.getParameterCount(), functionInfo); - Preconditions.checkState(existFunctionInfo == null || existFunctionInfo.getMethod() == functionInfo.getMethod(), - "Function: %s with %s parameters is already registered", functionName, method.getParameterCount()); - } - } - - private static void registerCalciteNamedFunctionMap(String functionName, Method method, boolean nullableParameters, - boolean isVarArg) { - if (method.getAnnotation(Deprecated.class) == null) { - FUNCTION_MAP.put(functionName, ScalarFunctionImpl.create(method)); - } - } - - public static Map> getRegisteredCalciteFunctionMap() { - return FUNCTION_MAP.map(); + /** + * Returns {@code true} if the given canonical name is registered, {@code false} otherwise. + * + * TODO: Consider adding a way to look up the usage of a function for better error message when there is no matching + * FunctionInfo. + */ + public static boolean contains(String canonicalName) { + return FUNCTION_MAP.containsKey(canonicalName); } - public static Set getRegisteredCalciteFunctionNames() { - return FUNCTION_MAP.map().keySet(); + /** + * @deprecated For performance concern, use {@link #contains(String)} instead to avoid invoking + * {@link #canonicalize(String)} multiple times. + */ + @Deprecated + public static boolean containsFunction(String name) { + return contains(canonicalize(name)); } /** - * Returns {@code true} if the given function name is registered, {@code false} otherwise. + * Returns the {@link FunctionInfo} associated with the given canonical name and argument types, or {@code null} if + * there is no matching method. This method should be called after the FunctionRegistry is initialized and all methods + * are already registered. */ - public static boolean containsFunction(String functionName) { - return FUNCTION_INFO_MAP.containsKey(canonicalize(functionName)); + @Nullable + public static FunctionInfo lookupFunctionInfo(String canonicalName, ColumnDataType[] argumentTypes) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(argumentTypes) : null; } /** - * Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null} + * Returns the {@link FunctionInfo} associated with the given canonical name and number of arguments, or {@code null} * if there is no matching method. This method should be called after the FunctionRegistry is initialized and all * methods are already registered. + * TODO: Move all usages to {@link #lookupFunctionInfo(String, ColumnDataType[])}. 
*/ @Nullable - public static FunctionInfo getFunctionInfo(String functionName, int numParameters) { - Map functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName)); - if (functionInfoMap != null) { - FunctionInfo functionInfo = functionInfoMap.get(numParameters); - if (functionInfo != null) { - return functionInfo; - } - return functionInfoMap.get(VAR_ARG_KEY); - } - return null; - } - - private static String canonicalize(String functionName) { - return StringUtils.remove(functionName, '_').toLowerCase(); + public static FunctionInfo lookupFunctionInfo(String canonicalName, int numArguments) { + PinotScalarFunction function = FUNCTION_MAP.get(canonicalName); + return function != null ? function.getFunctionInfo(numArguments) : null; } /** - * Placeholders for scalar function, they register and represents the signature for transform and filter predicate - * so that v2 engine can understand and plan them correctly. + * @deprecated For performance concern, use {@link #lookupFunctionInfo(String, int)} instead to avoid invoking + * {@link #canonicalize(String)} multiple times. */ - private static class PlaceholderScalarFunctions { + @Deprecated + @Nullable + public static FunctionInfo getFunctionInfo(String name, int numArguments) { + return lookupFunctionInfo(canonicalize(name), numArguments); + } + + public static String canonicalize(String name) { + return StringUtils.remove(name, '_').toLowerCase(); + } + + public static class ArgumentCountBasedScalarFunction implements PinotScalarFunction { + private final String _name; + private final Map _functionInfoMap; + + private ArgumentCountBasedScalarFunction(String name, Map functionInfoMap) { + _name = name; + _functionInfoMap = functionInfoMap; + } - @ScalarFunction(names = {"textContains", "text_contains"}, isPlaceholder = true) - public static boolean textContains(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Override + public String getName() { + return _name; } - @ScalarFunction(names = {"textMatch", "text_match"}, isPlaceholder = true) - public static boolean textMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Override + public PinotSqlFunction toPinotSqlFunction() { + return new PinotSqlFunction(_name, getReturnTypeInference(), getOperandTypeChecker()); + } + + private SqlReturnTypeInference getReturnTypeInference() { + return opBinding -> { + int numArguments = opBinding.getOperandCount(); + FunctionInfo functionInfo = getFunctionInfo(numArguments); + Preconditions.checkState(functionInfo != null, "Failed to find function: %s with %s arguments", _name, + numArguments); + RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); + Method method = functionInfo.getMethod(); + RelDataType returnType = FunctionUtils.getRelDataType(opBinding.getTypeFactory(), method.getReturnType()); + + if (!functionInfo.hasNullableParameters()) { + // When any parameter is null, return is null + for (RelDataType type : opBinding.collectOperandTypes()) { + if (type.isNullable()) { + return typeFactory.createTypeWithNullability(returnType, true); + } + } + } + + return method.isAnnotationPresent(Nullable.class) ? 
typeFactory.createTypeWithNullability(returnType, true) + : returnType; + }; + } + + private SqlOperandTypeChecker getOperandTypeChecker() { + if (_functionInfoMap.containsKey(VAR_ARG_KEY)) { + return OperandTypes.VARIADIC; + } + int numCheckers = _functionInfoMap.size(); + if (numCheckers == 1) { + return getOperandTypeChecker(_functionInfoMap.values().iterator().next().getMethod()); + } + SqlOperandTypeChecker[] operandTypeCheckers = new SqlOperandTypeChecker[numCheckers]; + int index = 0; + for (FunctionInfo functionInfo : _functionInfoMap.values()) { + operandTypeCheckers[index++] = getOperandTypeChecker(functionInfo.getMethod()); + } + return OperandTypes.or(operandTypeCheckers); + } + + private static SqlOperandTypeChecker getOperandTypeChecker(Method method) { + Class[] parameterTypes = method.getParameterTypes(); + int length = parameterTypes.length; + SqlTypeFamily[] typeFamilies = new SqlTypeFamily[length]; + for (int i = 0; i < length; i++) { + typeFamilies[i] = getSqlTypeFamily(parameterTypes[i]); + } + return OperandTypes.family(typeFamilies); } - @ScalarFunction(names = {"jsonMatch", "json_match"}, isPlaceholder = true) - public static boolean jsonMatch(String text, String pattern) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + private static SqlTypeFamily getSqlTypeFamily(Class clazz) { + // NOTE: Pinot allows some non-standard type conversions such as Timestamp <-> long, boolean <-> int etc. Do not + // restrict the type family for now. We only restrict the type family for String so that cast can be added. + // Explicit cast is required to correctly convert boolean and Timestamp to String. Without explicit case, + // BOOLEAN and TIMESTAMP type will be converted with their internal stored format which is INT and LONG + // respectively. E.g. true will be converted to "1", timestamp will be converted to long value string. + // TODO: Revisit this. + return clazz == String.class ? SqlTypeFamily.CHARACTER : SqlTypeFamily.ANY; } - @ScalarFunction(names = {"vectorSimilarity", "vector_similarity"}, isPlaceholder = true) - public static boolean vectorSimilarity(float[] vector1, float[] vector2, int topk) { - throw new UnsupportedOperationException("Placeholder scalar function, should not reach here"); + @Nullable + @Override + public FunctionInfo getFunctionInfo(int numArguments) { + FunctionInfo functionInfo = _functionInfoMap.get(numArguments); + return functionInfo != null ? 
functionInfo : _functionInfoMap.get(VAR_ARG_KEY); } } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java index 689922c40e98..1b805e9e525f 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java @@ -24,6 +24,9 @@ import java.util.HashMap; import java.util.Map; import javax.annotation.Nullable; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.PinotDataType; import org.apache.pinot.spi.data.FieldSpec.DataType; @@ -34,7 +37,7 @@ private FunctionUtils() { } // Types allowed as the function parameter (in the function signature) for type conversion - private static final Map, PinotDataType> PARAMETER_TYPE_MAP = new HashMap, PinotDataType>() {{ + private static final Map, PinotDataType> PARAMETER_TYPE_MAP = new HashMap<>() {{ put(int.class, PinotDataType.INTEGER); put(Integer.class, PinotDataType.INTEGER); put(long.class, PinotDataType.LONG); @@ -58,7 +61,7 @@ private FunctionUtils() { }}; // Types allowed as the function argument (actual value passed into the function) for type conversion - private static final Map, PinotDataType> ARGUMENT_TYPE_MAP = new HashMap, PinotDataType>() {{ + private static final Map, PinotDataType> ARGUMENT_TYPE_MAP = new HashMap<>() {{ put(Byte.class, PinotDataType.BYTE); put(Boolean.class, PinotDataType.BOOLEAN); put(Character.class, PinotDataType.CHARACTER); @@ -84,7 +87,7 @@ private FunctionUtils() { put(Object[].class, PinotDataType.OBJECT_ARRAY); }}; - private static final Map, DataType> DATA_TYPE_MAP = new HashMap, DataType>() {{ + private static final Map, DataType> DATA_TYPE_MAP = new HashMap<>() {{ put(int.class, DataType.INT); put(Integer.class, DataType.INT); put(long.class, DataType.LONG); @@ -106,7 +109,7 @@ private FunctionUtils() { put(String[].class, DataType.STRING); }}; - private static final Map, ColumnDataType> COLUMN_DATA_TYPE_MAP = new HashMap, ColumnDataType>() {{ + private static final Map, ColumnDataType> COLUMN_DATA_TYPE_MAP = new HashMap<>() {{ put(int.class, ColumnDataType.INT); put(Integer.class, ColumnDataType.INT); put(long.class, ColumnDataType.LONG); @@ -163,4 +166,53 @@ public static DataType getDataType(Class clazz) { public static ColumnDataType getColumnDataType(Class clazz) { return COLUMN_DATA_TYPE_MAP.get(clazz); } + + /** + * Returns the corresponding RelDataType for the given class, or OTHER if there is no one matching. 
+ */ + public static RelDataType getRelDataType(RelDataTypeFactory typeFactory, Class clazz) { + ColumnDataType columnDataType = getColumnDataType(clazz); + if (columnDataType == null) { + return typeFactory.createSqlType(SqlTypeName.OTHER); + } + switch (columnDataType) { + case INT: + return typeFactory.createSqlType(SqlTypeName.INTEGER); + case LONG: + return typeFactory.createSqlType(SqlTypeName.BIGINT); + case FLOAT: + return typeFactory.createSqlType(SqlTypeName.FLOAT); + case DOUBLE: + return typeFactory.createSqlType(SqlTypeName.DOUBLE); + case BIG_DECIMAL: + return typeFactory.createSqlType(SqlTypeName.DECIMAL); + case BOOLEAN: + return typeFactory.createSqlType(SqlTypeName.BOOLEAN); + case TIMESTAMP: + return typeFactory.createSqlType(SqlTypeName.TIMESTAMP); + case STRING: + case JSON: + return typeFactory.createSqlType(SqlTypeName.VARCHAR); + case BYTES: + return typeFactory.createSqlType(SqlTypeName.VARBINARY); + case INT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.INTEGER), -1); + case LONG_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BIGINT), -1); + case FLOAT_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.FLOAT), -1); + case DOUBLE_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.DOUBLE), -1); + case BOOLEAN_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.BOOLEAN), -1); + case TIMESTAMP_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.TIMESTAMP), -1); + case STRING_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARCHAR), -1); + case BYTES_ARRAY: + return typeFactory.createArrayType(typeFactory.createSqlType(SqlTypeName.VARBINARY), -1); + default: + return typeFactory.createSqlType(SqlTypeName.OTHER); + } + } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java new file mode 100644 index 000000000000..7a87935bd998 --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/PinotScalarFunction.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function; + +import javax.annotation.Nullable; +import org.apache.pinot.common.function.sql.PinotSqlFunction; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.spi.annotations.ScalarFunction; + + +/** + * Provides finer control to the scalar functions annotated with {@link ScalarFunction} + *

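+ * <p>A minimal sketch of a class-based function (hypothetical, for illustration only):
+ * <pre>{@code
+ * @ScalarFunction
+ * public class IsPositiveFunction implements PinotScalarFunction {
+ *   public static boolean isPositive(double value) {
+ *     return value > 0;
+ *   }
+ *
+ *   @Override
+ *   public String getName() {
+ *     return "isPositive";
+ *   }
+ *
+ *   @Nullable
+ *   @Override
+ *   public PinotSqlFunction toPinotSqlFunction() {
+ *     // No custom SqlFunction to register; rely on the default behavior
+ *     return null;
+ *   }
+ *
+ *   @Nullable
+ *   @Override
+ *   public FunctionInfo getFunctionInfo(int numArguments) {
+ *     if (numArguments != 1) {
+ *       return null;
+ *     }
+ *     try {
+ *       return new FunctionInfo(IsPositiveFunction.class.getMethod("isPositive", double.class),
+ *           IsPositiveFunction.class, false);
+ *     } catch (NoSuchMethodException e) {
+ *       throw new IllegalStateException(e);
+ *     }
+ *   }
+ * }
+ * }</pre>
+ *
+ * <p>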
See more details in {@link FunctionRegistry}. + */ +public interface PinotScalarFunction { + + /** + * Returns the name of the function. + */ + String getName(); + + /** + * Returns the corresponding {@link PinotSqlFunction} to be registered into the OperatorTable, or {@code null} if it + * doesn't need to be registered (e.g. standard SqlFunction). + */ + @Nullable + PinotSqlFunction toPinotSqlFunction(); + + /** + * Returns the {@link FunctionInfo} for the given argument types, or {@code null} if there is no matching. + */ + @Nullable + default FunctionInfo getFunctionInfo(ColumnDataType[] argumentTypes) { + return getFunctionInfo(argumentTypes.length); + } + + /** + * Returns the {@link FunctionInfo} for the given number of arguments, or {@code null} if there is no matching. + */ + @Nullable + FunctionInfo getFunctionInfo(int numArguments); +} diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java index cb0683fd8bfc..27299ac290e1 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/TransformFunctionType.java @@ -18,15 +18,12 @@ */ package org.apache.pinot.common.function; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; +import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; @@ -70,7 +67,6 @@ public enum TransformFunctionType { GREATEST("greatest"), // predicate functions - // there's no need to register these functions b/c Calcite parser doesn't allow explicit function parsing EQUALS("equals"), NOT_EQUALS("not_equals"), GREATER_THAN("greater_than"), @@ -81,7 +77,6 @@ public enum TransformFunctionType { NOT_IN("not_in"), // null handling functions, they never return null - // there's no need for alternative name b/c Calcite parser doesn't allow non-parentese representation IS_TRUE("is_true"), IS_NOT_TRUE("is_not_true"), IS_FALSE("is_false"), @@ -103,202 +98,118 @@ public enum TransformFunctionType { // date type conversion functions CAST("cast"), - // object type - ARRAY_TO_MV("arrayToMV", - ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(SqlTypeFamily.ARRAY), "array_to_mv") { - - @Override - public boolean isDeterministic() { - // ARRAY_TO_MV is not deterministic. In fact, it has no implementation - // We need to explicitly set it as not deterministic in order to do not let Calcite to optimize expressions like - // `ARRAY_TO_MV(RandomAirports) = 'MFR' and ARRAY_TO_MV(RandomAirports) = 'GTR'` as `false`. - // If the function were deterministic, its value would never be MFR and GTR at the same time, so Calcite is - // smart enough to know there is no value that satisfies the condition. 
- // In fact what ARRAY_TO_MV does is just to trick Calcite typesystem, but then what the leaf stage executor - // receives is `RandomAirports = 'MFR' and RandomAirports = 'GTR'`, which in the V1 semantics means: - // true if and only if RandomAirports contains a value equal to 'MFR' and RandomAirports contains a value equal - // to 'GTR' - return false; - } - }, - - // string functions + // JSON extract functions JSON_EXTRACT_SCALAR("jsonExtractScalar", - ReturnTypes.cascade(opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2, - SqlTypeName.VARCHAR), SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER), ordinal -> ordinal > 2), "json_extract_scalar"), + opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2, SqlTypeName.VARCHAR), + OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER), + i -> i == 3)), JSON_EXTRACT_INDEX("jsonExtractIndex", - ReturnTypes.cascade(opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2, - SqlTypeName.VARCHAR), SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 2), "json_extract_index"), - + opBinding -> positionalReturnTypeInferenceFromStringLiteral(opBinding, 2, SqlTypeName.VARCHAR), + OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER), i -> i > 2)), JSON_EXTRACT_KEY("jsonExtractKey", ReturnTypes.TO_ARRAY, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER)), "json_extract_key"), - - // date time functions - TIME_CONVERT("timeConvert", - ReturnTypes.BIGINT_FORCE_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER)), - "time_convert"), - - DATE_TIME_CONVERT("dateTimeConvert", - ReturnTypes.cascade( - opBinding -> dateTimeConverterReturnTypeInference(opBinding), - SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER)), "date_time_convert"), - - DATE_TIME_CONVERT_WINDOW_HOP("dateTimeConvertWindowHop", ReturnTypes.TO_ARRAY, OperandTypes.family( - ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER)), "date_time_convert_window_hop"), - - DATE_TRUNC("dateTrunc", - ReturnTypes.BIGINT_FORCE_NULLABLE, + OperandTypes.family(List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))), + + // Date time functions + TIME_CONVERT("timeConvert", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))), + DATE_TIME_CONVERT("dateTimeConvert", TransformFunctionType::dateTimeConverterReturnTypeInference, OperandTypes.family( + List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER))), + DATE_TIME_CONVERT_WINDOW_HOP("dateTimeConvertWindowHop", + ReturnTypes.cascade(TransformFunctionType::dateTimeConverterReturnTypeInference, SqlTypeTransforms.TO_ARRAY), OperandTypes.family( - ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, 
SqlTypeFamily.CHARACTER, - SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 1), "date_trunc"), - - FROM_DATE_TIME("fromDateTime", ReturnTypes.TIMESTAMP_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 1), "from_date_time"), - - TO_DATE_TIME("toDateTime", ReturnTypes.VARCHAR_2000_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 1), "to_date_time"), - - TIMESTAMP_ADD("timestampAdd", ReturnTypes.TIMESTAMP_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC, SqlTypeFamily.ANY)), - "timestamp_add", "dateAdd", "date_add"), - - TIMESTAMP_DIFF("timestampDiff", ReturnTypes.BIGINT_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY, SqlTypeFamily.ANY)), - "timestamp_diff", "dateDiff", "date_diff"), - + List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER))), + DATE_TRUNC("dateTrunc", ReturnTypes.BIGINT, OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, + SqlTypeFamily.CHARACTER), i -> i > 1)), YEAR("year"), - YEAR_OF_WEEK("yearOfWeek", "year_of_week", "yow"), + YEAR_OF_WEEK("yearOfWeek", "yow"), QUARTER("quarter"), - MONTH_OF_YEAR("monthOfYear", "month_of_year", "month"), - WEEK_OF_YEAR("weekOfYear", "week_of_year", "week"), - DAY_OF_YEAR("dayOfYear", "day_of_year", "doy"), - DAY_OF_MONTH("dayOfMonth", "day_of_month", "day"), - DAY_OF_WEEK("dayOfWeek", "day_of_week", "dow"), + MONTH_OF_YEAR("monthOfYear", "month"), + WEEK_OF_YEAR("weekOfYear", "week"), + DAY_OF_YEAR("dayOfYear", "doy"), + DAY_OF_MONTH("dayOfMonth", "day"), + DAY_OF_WEEK("dayOfWeek", "dow"), HOUR("hour"), MINUTE("minute"), SECOND("second"), MILLISECOND("millisecond"), - EXTRACT("extract"), - // string functions - SPLIT("split", ReturnTypes.TO_ARRAY, OperandTypes.family( - ImmutableList.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER), - ordinal -> ordinal > 1), "split", "string_to_array"), - - // array functions + // Array functions // The only column accepted by "cardinality" function is multi-value array, thus putting "cardinality" as alias. 
// TODO: once we support other types of multiset, we should make CARDINALITY its own function - ARRAY_LENGTH("arrayLength", ReturnTypes.INTEGER, OperandTypes.family(SqlTypeFamily.ARRAY), "array_length", - "cardinality"), - ARRAY_AVERAGE("arrayAverage", ReturnTypes.DOUBLE, OperandTypes.family(SqlTypeFamily.ARRAY), "array_average"), - ARRAY_MIN("arrayMin", ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), - SqlTypeTransforms.FORCE_NULLABLE), OperandTypes.family(SqlTypeFamily.ARRAY), "array_min"), - ARRAY_MAX("arrayMax", ReturnTypes.cascade(opBinding -> positionalComponentReturnType(opBinding, 0), - SqlTypeTransforms.FORCE_NULLABLE), OperandTypes.family(SqlTypeFamily.ARRAY), "array_max"), - ARRAY_SUM("arraySum", ReturnTypes.DOUBLE, OperandTypes.family(SqlTypeFamily.ARRAY), "array_sum"), - ARRAY_SUM_INT("arraySumInt", ReturnTypes.INTEGER, OperandTypes.family(SqlTypeFamily.ARRAY), "array_sum_int"), - ARRAY_SUM_LONG("arraySumLong", ReturnTypes.BIGINT, OperandTypes.family(SqlTypeFamily.ARRAY), "array_sum_long"), - - VALUE_IN("valueIn", ReturnTypes.ARG0_FORCE_NULLABLE, OperandTypes.variadic(SqlOperandCountRanges.from(2)), - "value_in"), - MAP_VALUE("mapValue", ReturnTypes.cascade(opBinding -> - opBinding.getOperandType(2).getComponentType(), SqlTypeTransforms.FORCE_NULLABLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.ANY)), - "map_value"), - - // special functions - IN_ID_SET("inIdSet", "in_id_set"), + ARRAY_LENGTH("arrayLength", ReturnTypes.INTEGER, OperandTypes.ARRAY, "cardinality"), + ARRAY_AVERAGE("arrayAverage", ReturnTypes.DOUBLE, OperandTypes.ARRAY), + ARRAY_MIN("arrayMin", TransformFunctionType::componentType, OperandTypes.ARRAY), + ARRAY_MAX("arrayMax", TransformFunctionType::componentType, OperandTypes.ARRAY), + ARRAY_SUM("arraySum", ReturnTypes.DOUBLE, OperandTypes.ARRAY), + ARRAY_SUM_INT("arraySumInt", ReturnTypes.INTEGER, OperandTypes.ARRAY), + ARRAY_SUM_LONG("arraySumLong", ReturnTypes.BIGINT, OperandTypes.ARRAY), + + // Special functions + VALUE_IN("valueIn", ReturnTypes.ARG0, OperandTypes.variadic(SqlOperandCountRanges.from(2))), + MAP_VALUE("mapValue", + ReturnTypes.cascade(opBinding -> positionalComponentType(opBinding, 2), SqlTypeTransforms.FORCE_NULLABLE), + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY))), + IN_ID_SET("inIdSet", ReturnTypes.BOOLEAN, OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER)), LOOKUP("lookUp"), GROOVY("groovy"), + SCALAR("scalar"), // CLP functions - CLP_DECODE("clpDecode", ReturnTypes.VARCHAR_2000_NULLABLE, OperandTypes.family( - ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 2), "clp_decode"), - CLP_ENCODED_VARS_MATCH("clpEncodedVarsMatch", ReturnTypes.BOOLEAN_NOT_NULL, OperandTypes.family( - ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 2), "clp_encoded_vars_match"), + CLP_DECODE("clpDecode", ReturnTypes.VARCHAR_NULLABLE, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER), + i -> i == 3)), + CLP_ENCODED_VARS_MATCH("clpEncodedVarsMatch", ReturnTypes.BOOLEAN_NOT_NULL, + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER)), // Regexp functions - REGEXP_EXTRACT("regexpExtract", "regexp_extract"), - REGEXP_REPLACE("regexpReplace", - ReturnTypes.VARCHAR_2000_NULLABLE, - 
OperandTypes.family( - ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, - SqlTypeFamily.INTEGER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER), - ordinal -> ordinal > 2), - "regexp_replace"), - - // Special type for annotation based scalar functions - SCALAR("scalar"), + REGEXP_EXTRACT("regexpExtract", ReturnTypes.VARCHAR_NULLABLE, OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.INTEGER, SqlTypeFamily.CHARACTER), + i -> i > 1)), // Geo constructors - ST_GEOG_FROM_TEXT("ST_GeogFromText", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.STRING), - ST_GEOM_FROM_TEXT("ST_GeomFromText", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.STRING), - ST_GEOG_FROM_WKB("ST_GeogFromWKB", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.BINARY), - ST_GEOM_FROM_WKB("ST_GeomFromWKB", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.BINARY), - ST_POINT("ST_Point", ReturnTypes.explicit(SqlTypeName.VARBINARY), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.ANY), - ordinal -> ordinal > 1 && ordinal < 4), "stPoint"), - ST_POLYGON("ST_Polygon", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.STRING, "stPolygon"), + ST_GEOG_FROM_TEXT("ST_GeogFromText", ReturnTypes.VARBINARY, OperandTypes.CHARACTER), + ST_GEOM_FROM_TEXT("ST_GeomFromText", ReturnTypes.VARBINARY, OperandTypes.CHARACTER), + ST_GEOG_FROM_WKB("ST_GeogFromWKB", ReturnTypes.VARBINARY, OperandTypes.BINARY), + ST_GEOM_FROM_WKB("ST_GeomFromWKB", ReturnTypes.VARBINARY, OperandTypes.BINARY), + ST_POINT("ST_Point", ReturnTypes.VARBINARY, + OperandTypes.family(List.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.ANY), i -> i == 2)), + ST_POLYGON("ST_Polygon", ReturnTypes.VARBINARY, OperandTypes.CHARACTER), // Geo measurements - ST_AREA("ST_Area", ReturnTypes.DOUBLE_NULLABLE, OperandTypes.BINARY, "stArea"), - ST_DISTANCE("ST_Distance", ReturnTypes.DOUBLE_NULLABLE, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.BINARY, SqlTypeFamily.BINARY)), "stDistance"), - ST_GEOMETRY_TYPE("ST_GeometryType", ReturnTypes.VARCHAR_2000_NULLABLE, OperandTypes.BINARY, "stGeometryType"), + ST_AREA("ST_Area", ReturnTypes.DOUBLE, OperandTypes.BINARY), + ST_DISTANCE("ST_Distance", ReturnTypes.DOUBLE, OperandTypes.BINARY_BINARY), + ST_GEOMETRY_TYPE("ST_GeometryType", ReturnTypes.VARCHAR, OperandTypes.BINARY), // Geo outputs - ST_AS_BINARY("ST_AsBinary", ReturnTypes.explicit(SqlTypeName.VARBINARY), OperandTypes.BINARY, "stAsBinary"), - ST_AS_TEXT("ST_AsText", ReturnTypes.VARCHAR_2000_NULLABLE, OperandTypes.BINARY, "stAsText"), + ST_AS_BINARY("ST_AsBinary", ReturnTypes.VARBINARY, OperandTypes.BINARY), + ST_AS_TEXT("ST_AsText", ReturnTypes.VARCHAR, OperandTypes.BINARY), // Geo relationship - ST_CONTAINS("ST_Contains", ReturnTypes.INTEGER, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.BINARY, SqlTypeFamily.BINARY)), "stContains"), - ST_EQUALS("ST_Equals", ReturnTypes.INTEGER, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.BINARY, SqlTypeFamily.BINARY)), "stEquals"), - ST_WITHIN("ST_Within", ReturnTypes.INTEGER, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.BINARY, SqlTypeFamily.BINARY)), "stWithin"), + // TODO: Revisit whether we should return BOOLEAN instead + ST_CONTAINS("ST_Contains", ReturnTypes.INTEGER, OperandTypes.BINARY_BINARY), + ST_EQUALS("ST_Equals", ReturnTypes.INTEGER, OperandTypes.BINARY_BINARY), + ST_WITHIN("ST_Within", 
ReturnTypes.INTEGER, OperandTypes.BINARY_BINARY), // Geo indexing - GEO_TO_H3("geoToH3", ReturnTypes.explicit(SqlTypeName.BIGINT), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), - ordinal -> ordinal > 1 && ordinal < 4), "geo_to_h3"), + GEO_TO_H3("geoToH3", ReturnTypes.BIGINT, + OperandTypes.or(OperandTypes.family(SqlTypeFamily.BINARY, SqlTypeFamily.INTEGER), + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER))), // Vector functions // TODO: Once VECTOR type is defined, we should update here. - COSINE_DISTANCE("cosineDistance", ReturnTypes.explicit(SqlTypeName.DOUBLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC), - ordinal -> ordinal > 1 && ordinal < 4), "cosine_distance"), - INNER_PRODUCT("innerProduct", ReturnTypes.explicit(SqlTypeName.DOUBLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "inner_product"), - L1_DISTANCE("l1Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l1_distance"), - L2_DISTANCE("l2Distance", ReturnTypes.explicit(SqlTypeName.DOUBLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY)), "l2_distance"), - VECTOR_DIMS("vectorDims", ReturnTypes.explicit(SqlTypeName.INTEGER), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_dims"), - VECTOR_NORM("vectorNorm", ReturnTypes.explicit(SqlTypeName.DOUBLE), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY)), "vector_norm"), - - VECTOR_SIMILARITY("vectorSimilarity", ReturnTypes.BOOLEAN_NOT_NULL, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), - ordinal -> ordinal > 1 && ordinal < 4), "vector_similarity"), - - ARRAY_VALUE_CONSTRUCTOR("arrayValueConstructor", "array_value_constructor"), + COSINE_DISTANCE("cosineDistance", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC), id -> id == 2)), + INNER_PRODUCT("innerProduct", ReturnTypes.DOUBLE, OperandTypes.ARRAY_ARRAY), + L1_DISTANCE("l1Distance", ReturnTypes.DOUBLE, OperandTypes.ARRAY_ARRAY), + L2_DISTANCE("l2Distance", ReturnTypes.DOUBLE, OperandTypes.ARRAY_ARRAY), + VECTOR_DIMS("vectorDims", ReturnTypes.INTEGER, OperandTypes.ARRAY), + VECTOR_NORM("vectorNorm", ReturnTypes.DOUBLE, OperandTypes.ARRAY), // Trigonometry SIN("sin"), @@ -316,49 +227,41 @@ public boolean isDeterministic() { RADIANS("radians"); private final String _name; - private final List _alternativeNames; - private final SqlKind _sqlKind; + private final List _names; private final SqlReturnTypeInference _returnTypeInference; private final SqlOperandTypeChecker _operandTypeChecker; - private final SqlFunctionCategory _sqlFunctionCategory; TransformFunctionType(String name, String... alternativeNames) { - this(name, null, null, null, null, alternativeNames); - } - - TransformFunctionType(String name, SqlReturnTypeInference returnTypeInference, - SqlOperandTypeChecker operandTypeChecker, String... 
alternativeNames) { - this(name, SqlKind.OTHER_FUNCTION, returnTypeInference, operandTypeChecker, - SqlFunctionCategory.USER_DEFINED_FUNCTION, alternativeNames); + this(name, null, null, alternativeNames); } - /** - * Constructor to use for transform functions which are supported in both v1 and multistage engines - */ - TransformFunctionType(String name, SqlKind sqlKind, SqlReturnTypeInference returnTypeInference, - SqlOperandTypeChecker operandTypeChecker, SqlFunctionCategory sqlFunctionCategory, String... alternativeNames) { + TransformFunctionType(String name, @Nullable SqlReturnTypeInference returnTypeInference, + @Nullable SqlOperandTypeChecker operandTypeChecker, String... alternativeNames) { _name = name; - List all = new ArrayList<>(alternativeNames.length + 2); - all.add(name); - all.add(name()); - all.addAll(Arrays.asList(alternativeNames)); - _alternativeNames = Collections.unmodifiableList(all); - _sqlKind = sqlKind; + int numAlternativeNames = alternativeNames.length; + if (numAlternativeNames == 0) { + _names = List.of(name); + } else { + List names = new ArrayList<>(numAlternativeNames + 1); + names.add(name); + names.addAll(Arrays.asList(alternativeNames)); + _names = List.copyOf(names); + } _returnTypeInference = returnTypeInference; _operandTypeChecker = operandTypeChecker; - _sqlFunctionCategory = sqlFunctionCategory; } public String getName() { return _name; } - public List getAlternativeNames() { - return _alternativeNames; + public List getNames() { + return _names; } - public SqlKind getSqlKind() { - return _sqlKind; + @Deprecated + public List getAlternativeNames() { + return _names; } public SqlReturnTypeInference getReturnTypeInference() { @@ -369,14 +272,6 @@ public SqlOperandTypeChecker getOperandTypeChecker() { return _operandTypeChecker; } - public SqlFunctionCategory getSqlFunctionCategory() { - return _sqlFunctionCategory; - } - - public boolean isDeterministic() { - return true; - } - /** Returns the optional explicit returning type specification. 
*/ private static RelDataType positionalReturnTypeInferenceFromStringLiteral(SqlOperatorBinding opBinding, int pos) { return positionalReturnTypeInferenceFromStringLiteral(opBinding, pos, SqlTypeName.ANY); @@ -384,25 +279,24 @@ private static RelDataType positionalReturnTypeInferenceFromStringLiteral(SqlOpe private static RelDataType positionalReturnTypeInferenceFromStringLiteral(SqlOperatorBinding opBinding, int pos, SqlTypeName defaultSqlType) { - if (opBinding.getOperandCount() > pos - && opBinding.isOperandLiteral(pos, false)) { + if (opBinding.getOperandCount() > pos && opBinding.isOperandLiteral(pos, false)) { String operandType = opBinding.getOperandLiteralValue(pos, String.class).toUpperCase(); return inferTypeFromStringLiteral(operandType, opBinding.getTypeFactory()); } return opBinding.getTypeFactory().createSqlType(defaultSqlType); } - private static RelDataType positionalComponentReturnType(SqlOperatorBinding opBinding, int pos) { - if (opBinding.getOperandCount() > pos) { - return opBinding.getOperandType(pos).getComponentType(); - } - throw new IllegalArgumentException("Invalid number of arguments for function " + opBinding.getOperator().getName()); + private static RelDataType componentType(SqlOperatorBinding opBinding) { + return opBinding.getOperandType(0).getComponentType(); + } + + private static RelDataType positionalComponentType(SqlOperatorBinding opBinding, int pos) { + return opBinding.getOperandType(pos).getComponentType(); } private static RelDataType dateTimeConverterReturnTypeInference(SqlOperatorBinding opBinding) { int outputFormatPos = 2; - if (opBinding.getOperandCount() > outputFormatPos - && opBinding.isOperandLiteral(outputFormatPos, false)) { + if (opBinding.getOperandCount() > outputFormatPos && opBinding.isOperandLiteral(outputFormatPos, false)) { String outputFormatStr = opBinding.getOperandLiteralValue(outputFormatPos, String.class); DateTimeFormatSpec dateTimeFormatSpec = new DateTimeFormatSpec(outputFormatStr); if ((dateTimeFormatSpec.getTimeFormat() == DateTimeFieldSpec.TimeFormat.EPOCH) || ( diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java index b8d0cc240451..94489c92b1ef 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArithmeticFunctions.java @@ -20,14 +20,12 @@ import java.math.BigDecimal; import java.math.RoundingMode; -import org.apache.calcite.linq4j.function.Strict; import org.apache.pinot.spi.annotations.ScalarFunction; /** * Arithmetic scalar functions. */ -@Strict public class ArithmeticFunctions { private ArithmeticFunctions() { } @@ -137,14 +135,14 @@ public static double power(double a, double exponent) { // Big Decimal Implementation has been used here to avoid overflows // when multiplying by Math.pow(10, scale) for rounding - @ScalarFunction(names = {"roundDecimal", "round_decimal"}) + @ScalarFunction public static double roundDecimal(double a, int scale) { return BigDecimal.valueOf(a).setScale(scale, RoundingMode.HALF_UP).doubleValue(); } // TODO: The function should ideally be named 'round' // but it is not possible because of existing DateTimeFunction with same name. 
- @ScalarFunction(names = {"roundDecimal", "round_decimal"}) + @ScalarFunction public static double roundDecimal(double a) { return Math.round(a); } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java index 13f77b2c5a3d..160db193789e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ArrayFunctions.java @@ -25,7 +25,6 @@ import it.unimi.dsi.fastutil.objects.ObjectSet; import java.math.BigDecimal; import java.util.Arrays; -import org.apache.calcite.linq4j.function.SemiStrict; import org.apache.commons.lang3.ArrayUtils; import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.utils.CommonConstants.NullValuePlaceHolder; @@ -34,7 +33,6 @@ /** * Inbuilt array scalar functions. See {@link ArrayUtils} for details. */ -@SemiStrict public class ArrayFunctions { private ArrayFunctions() { } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java index ca7a94eb0c07..3a4eef70e856 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ComparisonFunctions.java @@ -18,10 +18,8 @@ */ package org.apache.pinot.common.function.scalar; -import org.apache.calcite.linq4j.function.Strict; import org.apache.pinot.spi.annotations.ScalarFunction; -@Strict public class ComparisonFunctions { private static final double DOUBLE_COMPARISON_TOLERANCE = 1e-7d; @@ -29,27 +27,27 @@ public class ComparisonFunctions { private ComparisonFunctions() { } - @ScalarFunction(names = {"greater_than", "greaterThan"}) + @ScalarFunction public static boolean greaterThan(double a, double b) { return a > b; } - @ScalarFunction(names = {"greater_than_or_equal", "greaterThanOrEqual"}) + @ScalarFunction public static boolean greaterThanOrEqual(double a, double b) { return a >= b; } - @ScalarFunction(names = {"less_than", "lessThan"}) + @ScalarFunction public static boolean lessThan(double a, double b) { return a < b; } - @ScalarFunction(names = {"less_than_or_equal", "lessThanOrEqual"}) + @ScalarFunction public static boolean lessThanOrEqual(double a, double b) { return a <= b; } - @ScalarFunction(names = {"not_equals", "notEquals"}) + @ScalarFunction public static boolean notEquals(double a, double b) { return Math.abs(a - b) >= DOUBLE_COMPARISON_TOLERANCE; } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeConvert.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeConvert.java index 859bc2171f6c..c99b437c2b3d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeConvert.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeConvert.java @@ -35,7 +35,7 @@ public class DateTimeConvert { private DateTimeFormatSpec _outputFormatSpec; private DateTimeGranularitySpec _granularitySpec; - @ScalarFunction(names = {"dateTimeConvert", "date_time_convert"}) + @ScalarFunction public Object dateTimeConvert(String timeValueStr, String inputFormatStr, String outputFormatStr, String outputGranularityStr) { if (_inputFormatSpec == null) { diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java index ac734dd4daf4..ebb31ab2cb23 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DateTimeFunctions.java @@ -21,7 +21,6 @@ import java.sql.Timestamp; import java.time.Duration; import java.util.concurrent.TimeUnit; -import org.apache.calcite.linq4j.function.Strict; import org.apache.pinot.common.function.DateTimePatternHandler; import org.apache.pinot.common.function.DateTimeUtils; import org.apache.pinot.common.function.TimeZoneKey; @@ -70,7 +69,6 @@ * }] * */ -@Strict public class DateTimeFunctions { private DateTimeFunctions() { } @@ -78,12 +76,12 @@ private DateTimeFunctions() { /** * Convert epoch millis to epoch seconds */ - @ScalarFunction(names = {"toEpochSeconds", "to_epoch_seconds"}) + @ScalarFunction public static long toEpochSeconds(long millis) { return TimeUnit.MILLISECONDS.toSeconds(millis); } - @ScalarFunction(names = {"toEpochSecondsMV", "to_epoch_seconds_mv"}) + @ScalarFunction public static long[] toEpochSecondsMV(long[] millis) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -95,12 +93,12 @@ public static long[] toEpochSecondsMV(long[] millis) { /** * Convert epoch millis to epoch minutes */ - @ScalarFunction(names = {"toEpochMinutes", "to_epoch_minutes"}) + @ScalarFunction public static long toEpochMinutes(long millis) { return TimeUnit.MILLISECONDS.toMinutes(millis); } - @ScalarFunction(names = {"toEpochMinutesMV", "to_epoch_minutes_mv"}) + @ScalarFunction public static long[] toEpochMinutesMV(long[] millis) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -112,12 +110,12 @@ public static long[] toEpochMinutesMV(long[] millis) { /** * Convert epoch millis to epoch hours */ - @ScalarFunction(names = {"toEpochHours", "to_epoch_hours"}) + @ScalarFunction public static long toEpochHours(long millis) { return TimeUnit.MILLISECONDS.toHours(millis); } - @ScalarFunction(names = {"toEpochHoursMV", "to_epoch_hours_mv"}) + @ScalarFunction public static long[] toEpochHoursMV(long[] millis) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -129,12 +127,12 @@ public static long[] toEpochHoursMV(long[] millis) { /** * Convert epoch millis to epoch days */ - @ScalarFunction(names = {"toEpochDays", "to_epoch_days"}) + @ScalarFunction public static long toEpochDays(long millis) { return TimeUnit.MILLISECONDS.toDays(millis); } - @ScalarFunction(names = {"toEpochDaysMV", "to_epoch_days_mv"}) + @ScalarFunction public static long[] toEpochDaysMV(long[] millis) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -146,12 +144,12 @@ public static long[] toEpochDaysMV(long[] millis) { /** * Convert epoch millis to epoch seconds, round to nearest rounding bucket */ - @ScalarFunction(names = {"toEpochSecondsRounded", "to_epoch_seconds_rounded"}) + @ScalarFunction public static long toEpochSecondsRounded(long millis, long roundToNearest) { return (TimeUnit.MILLISECONDS.toSeconds(millis) / roundToNearest) * roundToNearest; } - @ScalarFunction(names = {"toEpochSecondsRoundedMV", "to_epoch_seconds_rounded_mv"}) + @ScalarFunction public static long[] toEpochSecondsRoundedMV(long[] millis, long roundToNearest) { long[] results = new 
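Editor's note: from here on, the DateTimeFunctions changes are almost entirely mechanical — every `names = {camelCase, snake_case}` list collapses to a bare `@ScalarFunction`. That only preserves behavior if function lookup resolves both spellings to one canonical key. A plausible model of such a lookup, assuming canonicalization strips underscores and lower-cases (an assumption about the registry, not code from this PR):

```java
import java.util.Locale;

public class NameCanonicalizationSketch {
  // Hypothetical: if the registry canonicalizes like this, "toEpochSeconds"
  // and "to_epoch_seconds" collide on one key, so a single registered name
  // covers both spellings and the explicit aliases become redundant.
  static String canonicalize(String functionName) {
    return functionName.replace("_", "").toLowerCase(Locale.ROOT);
  }

  public static void main(String[] args) {
    System.out.println(canonicalize("toEpochSeconds"));    // toepochseconds
    System.out.println(canonicalize("to_epoch_seconds"));  // toepochseconds
  }
}
```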
long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -163,12 +161,12 @@ public static long[] toEpochSecondsRoundedMV(long[] millis, long roundToNearest) /** * Convert epoch millis to epoch minutes, round to nearest rounding bucket */ - @ScalarFunction(names = {"toEpochMinutesRounded", "to_epoch_minutes_rounded"}) + @ScalarFunction public static long toEpochMinutesRounded(long millis, long roundToNearest) { return (TimeUnit.MILLISECONDS.toMinutes(millis) / roundToNearest) * roundToNearest; } - @ScalarFunction(names = {"toEpochMinutesRoundedMV", "to_epoch_minutes_rounded_mv"}) + @ScalarFunction public static long[] toEpochMinutesRoundedMV(long[] millis, long roundToNearest) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -180,12 +178,12 @@ public static long[] toEpochMinutesRoundedMV(long[] millis, long roundToNearest) /** * Convert epoch millis to epoch hours, round to nearest rounding bucket */ - @ScalarFunction(names = {"toEpochHoursRounded", "to_epoch_hours_rounded"}) + @ScalarFunction public static long toEpochHoursRounded(long millis, long roundToNearest) { return (TimeUnit.MILLISECONDS.toHours(millis) / roundToNearest) * roundToNearest; } - @ScalarFunction(names = {"toEpochHoursRoundedMV", "to_epoch_hours_rounded_mv"}) + @ScalarFunction public static long[] toEpochHoursRoundedMV(long[] millis, long roundToNearest) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -197,12 +195,12 @@ public static long[] toEpochHoursRoundedMV(long[] millis, long roundToNearest) { /** * Convert epoch millis to epoch days, round to nearest rounding bucket */ - @ScalarFunction(names = {"toEpochDaysRounded", "to_epoch_days_rounded"}) + @ScalarFunction public static long toEpochDaysRounded(long millis, long roundToNearest) { return (TimeUnit.MILLISECONDS.toDays(millis) / roundToNearest) * roundToNearest; } - @ScalarFunction(names = {"toEpochDaysRoundedMV", "to_epoch_days_rounded_mv"}) + @ScalarFunction public static long[] toEpochDaysRoundedMV(long[] millis, long roundToNearest) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -214,12 +212,12 @@ public static long[] toEpochDaysRoundedMV(long[] millis, long roundToNearest) { /** * Convert epoch millis to epoch seconds, divided by given bucket, to get nSecondsSinceEpoch */ - @ScalarFunction(names = {"toEpochSecondsBucket", "to_epoch_seconds_bucket"}) + @ScalarFunction public static long toEpochSecondsBucket(long millis, long bucket) { return TimeUnit.MILLISECONDS.toSeconds(millis) / bucket; } - @ScalarFunction(names = {"toEpochSecondsBucketMV", "to_epoch_seconds_bucket_mv"}) + @ScalarFunction public static long[] toEpochSecondsBucketMV(long[] millis, long bucket) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -231,12 +229,12 @@ public static long[] toEpochSecondsBucketMV(long[] millis, long bucket) { /** * Convert epoch millis to epoch minutes, divided by given bucket, to get nMinutesSinceEpoch */ - @ScalarFunction(names = {"toEpochMinutesBucket", "to_epoch_minutes_bucket"}) + @ScalarFunction public static long toEpochMinutesBucket(long millis, long bucket) { return TimeUnit.MILLISECONDS.toMinutes(millis) / bucket; } - @ScalarFunction(names = {"toEpochMinutesBucketMV", "to_epoch_minutes_bucket_mv"}) + @ScalarFunction public static long[] toEpochMinutesBucketMV(long[] millis, long bucket) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -248,12 
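Editor's note: the `*Rounded` variants above all share one idiom — integer division by `roundToNearest` followed by re-multiplication, which floors the converted value to the bucket boundary at or below it (despite the "round to nearest" wording). A standalone check:

```java
import java.util.concurrent.TimeUnit;

public class EpochRoundingDemo {
  // Same arithmetic as toEpochSecondsRounded: integer division floors to the bucket.
  static long toEpochSecondsRounded(long millis, long roundToNearest) {
    return (TimeUnit.MILLISECONDS.toSeconds(millis) / roundToNearest) * roundToNearest;
  }

  public static void main(String[] args) {
    long millis = 1_718_000_123_456L;                            // arbitrary epoch millis
    System.out.println(TimeUnit.MILLISECONDS.toSeconds(millis)); // 1718000123
    System.out.println(toEpochSecondsRounded(millis, 60));       // 1718000100 (floored to the minute)
  }
}
```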
+246,12 @@ public static long[] toEpochMinutesBucketMV(long[] millis, long bucket) { /** * Convert epoch millis to epoch hours, divided by given bucket, to get nHoursSinceEpoch */ - @ScalarFunction(names = {"toEpochHoursBucket", "to_epoch_hours_bucket"}) + @ScalarFunction public static long toEpochHoursBucket(long millis, long bucket) { return TimeUnit.MILLISECONDS.toHours(millis) / bucket; } - @ScalarFunction(names = {"toEpochHoursBucketMV", "to_epoch_hours_bucket_mv"}) + @ScalarFunction public static long[] toEpochHoursBucketMV(long[] millis, long bucket) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -265,12 +263,12 @@ public static long[] toEpochHoursBucketMV(long[] millis, long bucket) { /** * Convert epoch millis to epoch days, divided by given bucket, to get nDaysSinceEpoch */ - @ScalarFunction(names = {"toEpochDaysBucket", "to_epoch_days_bucket"}) + @ScalarFunction public static long toEpochDaysBucket(long millis, long bucket) { return TimeUnit.MILLISECONDS.toDays(millis) / bucket; } - @ScalarFunction(names = {"toEpochDaysBucketMV", "to_epoch_days_bucket_mv"}) + @ScalarFunction public static long[] toEpochDaysBucketMV(long[] millis, long bucket) { long[] results = new long[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -282,12 +280,12 @@ public static long[] toEpochDaysBucketMV(long[] millis, long bucket) { /** * Converts epoch seconds to epoch millis */ - @ScalarFunction(names = {"fromEpochSeconds", "from_epoch_seconds"}) + @ScalarFunction public static long fromEpochSeconds(long seconds) { return TimeUnit.SECONDS.toMillis(seconds); } - @ScalarFunction(names = {"fromEpochSecondsMV", "from_epoch_seconds_mv"}) + @ScalarFunction public static long[] fromEpochSecondsMV(long[] seconds) { long[] results = new long[seconds.length]; for (int i = 0; i < seconds.length; i++) { @@ -299,12 +297,12 @@ public static long[] fromEpochSecondsMV(long[] seconds) { /** * Converts epoch minutes to epoch millis */ - @ScalarFunction(names = {"fromEpochMinutes", "from_epoch_minutes"}) + @ScalarFunction public static long fromEpochMinutes(long minutes) { return TimeUnit.MINUTES.toMillis(minutes); } - @ScalarFunction(names = {"fromEpochMinutesMV", "from_epoch_minutes_mv"}) + @ScalarFunction public static long[] fromEpochMinutesMV(long[] minutes) { long[] results = new long[minutes.length]; for (int i = 0; i < minutes.length; i++) { @@ -316,12 +314,12 @@ public static long[] fromEpochMinutesMV(long[] minutes) { /** * Converts epoch hours to epoch millis */ - @ScalarFunction(names = {"fromEpochHours", "from_epoch_hours"}) + @ScalarFunction public static long fromEpochHours(long hours) { return TimeUnit.HOURS.toMillis(hours); } - @ScalarFunction(names = {"fromEpochHoursMV", "from_epoch_hours_mv"}) + @ScalarFunction public static long[] fromEpochHoursMV(long[] hours) { long[] results = new long[hours.length]; for (int i = 0; i < hours.length; i++) { @@ -333,12 +331,12 @@ public static long[] fromEpochHoursMV(long[] hours) { /** * Converts epoch days to epoch millis */ - @ScalarFunction(names = {"fromEpochDays", "from_epoch_days"}) + @ScalarFunction public static long fromEpochDays(long days) { return TimeUnit.DAYS.toMillis(days); } - @ScalarFunction(names = {"fromEpochDaysMV", "from_epoch_days_mv"}) + @ScalarFunction public static long[] fromEpochDaysMV(long[] days) { long[] results = new long[days.length]; for (int i = 0; i < days.length; i++) { @@ -350,12 +348,12 @@ public static long[] fromEpochDaysMV(long[] days) { /** * Converts nSecondsSinceEpoch 
(seconds that have been divided by a bucket), to epoch millis */ - @ScalarFunction(names = {"fromEpochSecondsBucket", "from_epoch_seconds_bucket"}) + @ScalarFunction public static long fromEpochSecondsBucket(long seconds, long bucket) { return TimeUnit.SECONDS.toMillis(seconds * bucket); } - @ScalarFunction(names = {"fromEpochSecondsBucketMV", "from_epoch_seconds_bucket_mv"}) + @ScalarFunction public static long[] fromEpochSecondsBucketMV(long[] seconds, long bucket) { long[] results = new long[seconds.length]; for (int i = 0; i < seconds.length; i++) { @@ -367,12 +365,12 @@ public static long[] fromEpochSecondsBucketMV(long[] seconds, long bucket) { /** * Converts nMinutesSinceEpoch (minutes that have been divided by a bucket), to epoch millis */ - @ScalarFunction(names = {"fromEpochMinutesBucket", "from_epoch_minutes_bucket"}) + @ScalarFunction public static long fromEpochMinutesBucket(long minutes, long bucket) { return TimeUnit.MINUTES.toMillis(minutes * bucket); } - @ScalarFunction(names = {"fromEpochMinutesBucketMV", "from_epoch_minutes_bucket_mv"}) + @ScalarFunction public static long[] fromEpochMinutesBucketMV(long[] minutes, long bucket) { long[] results = new long[minutes.length]; for (int i = 0; i < minutes.length; i++) { @@ -384,12 +382,12 @@ public static long[] fromEpochMinutesBucketMV(long[] minutes, long bucket) { /** * Converts nHoursSinceEpoch (hours that have been divided by a bucket), to epoch millis */ - @ScalarFunction(names = {"fromEpochHoursBucket", "from_epoch_hours_bucket"}) + @ScalarFunction public static long fromEpochHoursBucket(long hours, long bucket) { return TimeUnit.HOURS.toMillis(hours * bucket); } - @ScalarFunction(names = {"fromEpochHoursBucketMV", "from_epoch_hours_bucket_mv"}) + @ScalarFunction public static long[] fromEpochHoursBucketMV(long[] hours, long bucket) { long[] results = new long[hours.length]; for (int i = 0; i < hours.length; i++) { @@ -401,12 +399,12 @@ public static long[] fromEpochHoursBucketMV(long[] hours, long bucket) { /** * Converts nDaysSinceEpoch (days that have been divided by a bucket), to epoch millis */ - @ScalarFunction(names = {"fromEpochDaysBucket", "from_epoch_days_bucket"}) + @ScalarFunction public static long fromEpochDaysBucket(long days, long bucket) { return TimeUnit.DAYS.toMillis(days * bucket); } - @ScalarFunction(names = {"fromEpochDaysBucketMV", "from_epoch_days_bucket_mv"}) + @ScalarFunction public static long[] fromEpochDaysBucketMV(long[] days, long bucket) { long[] results = new long[days.length]; for (int i = 0; i < days.length; i++) { @@ -418,12 +416,12 @@ public static long[] fromEpochDaysBucketMV(long[] days, long bucket) { /** * Converts epoch millis to Timestamp */ - @ScalarFunction(names = {"toTimestamp", "to_timestamp"}) + @ScalarFunction public static Timestamp toTimestamp(long millis) { return new Timestamp(millis); } - @ScalarFunction(names = {"toTimestampMV", "to_timestamp_mv"}) + @ScalarFunction public static Timestamp[] toTimestampMV(long[] millis) { Timestamp[] results = new Timestamp[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -435,12 +433,12 @@ public static Timestamp[] toTimestampMV(long[] millis) { /** * Converts Timestamp to epoch millis */ - @ScalarFunction(names = {"fromTimestamp", "from_timestamp"}) + @ScalarFunction public static long fromTimestamp(Timestamp timestamp) { return timestamp.getTime(); } - @ScalarFunction(names = {"fromTimestampMV", "from_timestamp_mv"}) + @ScalarFunction public static long[] fromTimestampMV(Timestamp[] timestamp) { long[] results = new 
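Editor's note: `toEpoch*Bucket` and `fromEpoch*Bucket` above are near-inverses. The forward direction truncates to a bucket count, and the reverse recovers the start of that bucket in millis, so sub-bucket precision is lost by design. A standalone round trip:

```java
import java.util.concurrent.TimeUnit;

public class EpochBucketRoundTripDemo {
  static long toEpochMinutesBucket(long millis, long bucket) {
    return TimeUnit.MILLISECONDS.toMinutes(millis) / bucket;   // n bucket-sized windows since epoch
  }

  static long fromEpochMinutesBucket(long minutes, long bucket) {
    return TimeUnit.MINUTES.toMillis(minutes * bucket);        // start of that window, in millis
  }

  public static void main(String[] args) {
    long millis = 1_718_000_123_456L;
    long bucketed = toEpochMinutesBucket(millis, 15);
    long restored = fromEpochMinutesBucket(bucketed, 15);
    System.out.println(restored <= millis);                    // true: the round trip floors to the window start
  }
}
```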
long[timestamp.length]; for (int i = 0; i < timestamp.length; i++) { @@ -452,12 +450,12 @@ public static long[] fromTimestampMV(Timestamp[] timestamp) { /** * Converts epoch millis to DateTime string represented by pattern */ - @ScalarFunction(names = {"toDateTime", "to_date_time"}) + @ScalarFunction public static String toDateTime(long millis, String pattern) { return DateTimePatternHandler.parseEpochMillisToDateTimeString(millis, pattern); } - @ScalarFunction(names = {"toDateTimeMV", "to_date_time_mv"}) + @ScalarFunction public static String[] toDateTimeMV(long[] millis, String pattern) { String[] results = new String[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -469,12 +467,12 @@ public static String[] toDateTimeMV(long[] millis, String pattern) { /** * Converts epoch millis to DateTime string represented by pattern and the time zone id. */ - @ScalarFunction(names = {"toDateTime", "to_date_time"}) + @ScalarFunction public static String toDateTime(long millis, String pattern, String timezoneId) { return DateTimePatternHandler.parseEpochMillisToDateTimeString(millis, pattern, timezoneId); } - @ScalarFunction(names = {"toDateTimeMV", "to_date_time_mv"}) + @ScalarFunction public static String[] toDateTimeMV(long[] millis, String pattern, String timezoneId) { String[] results = new String[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -486,12 +484,12 @@ public static String[] toDateTimeMV(long[] millis, String pattern, String timezo /** * Converts DateTime string represented by pattern to epoch millis */ - @ScalarFunction(names = {"fromDateTime", "from_date_time"}) + @ScalarFunction public static long fromDateTime(String dateTimeString, String pattern) { return DateTimePatternHandler.parseDateTimeStringToEpochMillis(dateTimeString, pattern); } - @ScalarFunction(names = {"fromDateTimeMV", "from_date_time_mv"}) + @ScalarFunction public static long[] fromDateTimeMV(String[] dateTimeString, String pattern) { long[] results = new long[dateTimeString.length]; for (int i = 0; i < dateTimeString.length; i++) { @@ -503,17 +501,17 @@ public static long[] fromDateTimeMV(String[] dateTimeString, String pattern) { /** * Converts DateTime string represented by pattern to epoch millis */ - @ScalarFunction(names = {"fromDateTime", "from_date_time"}) + @ScalarFunction public static long fromDateTime(String dateTimeString, String pattern, String timeZoneId) { return DateTimePatternHandler.parseDateTimeStringToEpochMillis(dateTimeString, pattern, timeZoneId); } - @ScalarFunction(names = {"fromDateTime", "from_date_time"}) + @ScalarFunction public static long fromDateTime(String dateTimeString, String pattern, String timeZoneId, long defaultVal) { return DateTimePatternHandler.parseDateTimeStringToEpochMillis(dateTimeString, pattern, timeZoneId, defaultVal); } - @ScalarFunction(names = {"fromDateTimeMV", "from_date_time_mv"}) + @ScalarFunction public static long[] fromDateTimeMV(String[] dateTimeString, String pattern, String timeZoneId) { long[] results = new long[dateTimeString.length]; for (int i = 0; i < dateTimeString.length; i++) { @@ -531,7 +529,7 @@ public static long round(long timeValue, long roundToNearest) { return (timeValue / roundToNearest) * roundToNearest; } - @ScalarFunction(names = {"roundMV", "round_mv"}) + @ScalarFunction public static long[] roundMV(long[] timeValue, long roundToNearest) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -567,7 +565,7 @@ public static long ago(String periodString) { return 
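Editor's note: `toDateTime`/`fromDateTime` above delegate to `DateTimePatternHandler`. Assuming the patterns are Joda-style (the file already depends on Joda-Time), the round trip reduces to the following sketch built directly on Joda:

```java
import org.joda.time.DateTimeZone;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;

public class DateTimePatternDemo {
  public static void main(String[] args) {
    DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss")
        .withZone(DateTimeZone.UTC);
    long millis = fmt.parseMillis("2024-06-10 06:15:23"); // fromDateTime analogue
    String back = fmt.print(millis);                      // toDateTime analogue
    System.out.println(back);                             // 2024-06-10 06:15:23
  }
}
```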
System.currentTimeMillis() - period.toMillis(); } - @ScalarFunction(names = {"agoMV", "ago_mv"}) + @ScalarFunction public static long[] agoMV(String[] periodString) { long[] results = new long[periodString.length]; for (int i = 0; i < periodString.length; i++) { @@ -584,12 +582,12 @@ public static long[] agoMV(String[] periodString) { /** * Returns the hour of the time zone offset. */ - @ScalarFunction(names = {"timezoneHour", "timezone_hour"}) + @ScalarFunction public static int timezoneHour(String timezoneId) { return timezoneHour(timezoneId, 0); } - @ScalarFunction(names = {"timezoneHourMV", "timezone_hour_mv"}) + @ScalarFunction public static int[] timezoneHourMV(String[] timezoneId) { int[] results = new int[timezoneId.length]; for (int i = 0; i < timezoneId.length; i++) { @@ -602,12 +600,12 @@ public static int[] timezoneHourMV(String[] timezoneId) { * Returns the hour of the time zone offset, for the UTC timestamp at {@code millis}. This will * properly handle daylight savings time. */ - @ScalarFunction(names = {"timezoneHour", "timezone_hour"}) + @ScalarFunction public static int timezoneHour(String timezoneId, long millis) { return (int) TimeUnit.MILLISECONDS.toHours(DateTimeZone.forID(timezoneId).getOffset(millis)); } - @ScalarFunction(names = {"timezoneHourMV", "timezone_hour_mv"}) + @ScalarFunction public static int[] timezoneHourMV(String timezoneId, long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -619,12 +617,12 @@ public static int[] timezoneHourMV(String timezoneId, long[] millis) { /** * Returns the minute of the time zone offset. */ - @ScalarFunction(names = {"timezoneMinute", "timezone_minute"}) + @ScalarFunction public static int timezoneMinute(String timezoneId) { return timezoneMinute(timezoneId, 0); } - @ScalarFunction(names = {"timezoneMinuteMV", "timezone_minute_mv"}) + @ScalarFunction public static int[] timezoneMinuteMV(String[] timezoneId) { int[] results = new int[timezoneId.length]; for (int i = 0; i < timezoneId.length; i++) { @@ -637,12 +635,12 @@ public static int[] timezoneMinuteMV(String[] timezoneId) { * Returns the minute of the time zone offset, for the UTC timestamp at {@code millis}. 
This will * properly handle daylight savings time */ - @ScalarFunction(names = {"timezoneMinute", "timezone_minute"}) + @ScalarFunction public static int timezoneMinute(String timezoneId, long millis) { return (int) TimeUnit.MILLISECONDS.toMinutes(DateTimeZone.forID(timezoneId).getOffset(millis)) % 60; } - @ScalarFunction(names = {"timezoneMinuteMV", "timezone_minute_mv"}) + @ScalarFunction public static int[] timezoneMinuteMV(String timezoneId, long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -659,7 +657,7 @@ public static int year(long millis) { return new DateTime(millis, DateTimeZone.UTC).getYear(); } - @ScalarFunction(names = {"yearMV", "year_mv"}) + @ScalarFunction public static int[] yearMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -676,7 +674,7 @@ public static int year(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getYear(); } - @ScalarFunction(names = {"yearMV", "year_mv"}) + @ScalarFunction public static int[] yearMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -688,12 +686,12 @@ public static int[] yearMV(long[] millis, String timezoneId) { /** * Returns the year of the ISO week from the given epoch millis in UTC timezone. */ - @ScalarFunction(names = {"yearOfWeek", "year_of_week", "yow"}) + @ScalarFunction(names = {"yearOfWeek", "yow"}) public static int yearOfWeek(long millis) { return new DateTime(millis, DateTimeZone.UTC).getWeekyear(); } - @ScalarFunction(names = {"yearOfWeekMV", "year_of_week_mv", "yowmv"}) + @ScalarFunction(names = {"yearOfWeekMV", "yowMV"}) public static int[] yearOfWeekMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -705,12 +703,12 @@ public static int[] yearOfWeekMV(long[] millis) { /** * Returns the year of the ISO week from the given epoch millis and timezone id. 
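Editor's note: `timezoneHour`/`timezoneMinute` above simply read the zone's UTC offset at a given instant, so daylight saving transitions come for free from Joda. The same two expressions, runnable standalone:

```java
import java.util.concurrent.TimeUnit;
import org.joda.time.DateTimeZone;

public class TimezoneOffsetDemo {
  public static void main(String[] args) {
    long millis = System.currentTimeMillis();
    // Asia/Kolkata is UTC+05:30 year-round, so the split is stable.
    int offsetMillis = DateTimeZone.forID("Asia/Kolkata").getOffset(millis);
    System.out.println(TimeUnit.MILLISECONDS.toHours(offsetMillis));        // 5
    System.out.println(TimeUnit.MILLISECONDS.toMinutes(offsetMillis) % 60); // 30
  }
}
```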
*/ - @ScalarFunction(names = {"yearOfWeek", "year_of_week", "yow"}) + @ScalarFunction(names = {"yearOfWeek", "yow"}) public static int yearOfWeek(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getWeekyear(); } - @ScalarFunction(names = {"yearOfWeekMV", "year_of_week_mv", "yowmv"}) + @ScalarFunction(names = {"yearOfWeekMV", "yowMV"}) public static int[] yearOfWeekMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -724,10 +722,10 @@ public static int[] yearOfWeekMV(long[] millis, String timezoneId) { */ @ScalarFunction public static int quarter(long millis) { - return (monthOfYear(millis) - 1) / 3 + 1; + return (month(millis) - 1) / 3 + 1; } - @ScalarFunction(names = {"quarterMV", "quarter_mv"}) + @ScalarFunction public static int[] quarterMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -741,10 +739,10 @@ public static int[] quarterMV(long[] millis) { */ @ScalarFunction public static int quarter(long millis, String timezoneId) { - return (monthOfYear(millis, timezoneId) - 1) / 3 + 1; + return (month(millis, timezoneId) - 1) / 3 + 1; } - @ScalarFunction(names = {"quarterMV", "quarter_mv"}) + @ScalarFunction public static int[] quarterMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -756,16 +754,16 @@ public static int[] quarterMV(long[] millis, String timezoneId) { /** * Returns the month of the year from the given epoch millis in UTC timezone. The value ranges from 1 to 12. */ - @ScalarFunction(names = {"month", "month_of_year", "monthOfYear"}) - public static int monthOfYear(long millis) { + @ScalarFunction(names = {"month", "monthOfYear"}) + public static int month(long millis) { return new DateTime(millis, DateTimeZone.UTC).getMonthOfYear(); } - @ScalarFunction(names = {"monthMV", "month_of_year_mv", "monthOfYearMV"}) - public static int[] monthOfYearMV(long[] millis) { + @ScalarFunction(names = {"monthMV", "monthOfYearMV"}) + public static int[] monthMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { - results[i] = monthOfYear(millis[i]); + results[i] = month(millis[i]); } return results; } @@ -773,16 +771,16 @@ public static int[] monthOfYearMV(long[] millis) { /** * Returns the month of the year from the given epoch millis and timezone id. The value ranges from 1 to 12. */ - @ScalarFunction(names = {"month", "month_of_year", "monthOfYear"}) - public static int monthOfYear(long millis, String timezoneId) { + @ScalarFunction(names = {"month", "monthOfYear"}) + public static int month(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getMonthOfYear(); } - @ScalarFunction(names = {"monthMV", "month_of_year_mv", "monthOfYearMV"}) - public static int[] monthOfYearMV(long[] millis, String timezoneId) { + @ScalarFunction(names = {"monthMV", "monthOfYearMV"}) + public static int[] monthMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { - results[i] = monthOfYear(millis[i], timezoneId); + results[i] = month(millis[i], timezoneId); } return results; } @@ -790,16 +788,16 @@ public static int[] monthOfYearMV(long[] millis, String timezoneId) { /** * Returns the ISO week of the year from the given epoch millis in UTC timezone.The value ranges from 1 to 53. 
*/ - @ScalarFunction(names = {"weekOfYear", "week_of_year", "week"}) - public static int weekOfYear(long millis) { + @ScalarFunction(names = {"week", "weekOfYear"}) + public static int week(long millis) { return new DateTime(millis, DateTimeZone.UTC).getWeekOfWeekyear(); } - @ScalarFunction(names = {"weekOfYearMV", "week_of_year_mv", "weekMV"}) - public static int[] weekOfYearMV(long[] millis) { + @ScalarFunction(names = {"weekMV", "weekOfYearMV"}) + public static int[] weekMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { - results[i] = weekOfYear(millis[i]); + results[i] = week(millis[i]); } return results; } @@ -807,16 +805,16 @@ public static int[] weekOfYearMV(long[] millis) { /** * Returns the ISO week of the year from the given epoch millis and timezone id. The value ranges from 1 to 53. */ - @ScalarFunction(names = {"weekOfYear", "week_of_year", "week"}) - public static int weekOfYear(long millis, String timezoneId) { + @ScalarFunction(names = {"week", "weekOfYear"}) + public static int week(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getWeekOfWeekyear(); } - @ScalarFunction(names = {"weekOfYearMV", "week_of_year_mv", "weekMV"}) - public static int[] weekOfYearMV(long[] millis, String timezoneId) { + @ScalarFunction(names = {"weekMV", "weekOfYearMV"}) + public static int[] weekMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { - results[i] = weekOfYear(millis[i], timezoneId); + results[i] = week(millis[i], timezoneId); } return results; } @@ -824,12 +822,12 @@ public static int[] weekOfYearMV(long[] millis, String timezoneId) { /** * Returns the day of the year from the given epoch millis in UTC timezone. The value ranges from 1 to 366. */ - @ScalarFunction(names = {"dayOfYear", "day_of_year", "doy"}) + @ScalarFunction(names = {"dayOfYear", "doy"}) public static int dayOfYear(long millis) { return new DateTime(millis, DateTimeZone.UTC).getDayOfYear(); } - @ScalarFunction(names = {"dayOfYearMV", "day_of_year_mv", "doyMV"}) + @ScalarFunction(names = {"dayOfYearMV", "doyMV"}) public static int[] dayOfYear(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -841,12 +839,12 @@ public static int[] dayOfYear(long[] millis) { /** * Returns the day of the year from the given epoch millis and timezone id. The value ranges from 1 to 366. */ - @ScalarFunction(names = {"dayOfYear", "day_of_year", "doy"}) + @ScalarFunction(names = {"dayOfYear", "doy"}) public static int dayOfYear(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getDayOfYear(); } - @ScalarFunction(names = {"dayOfYearMV", "day_of_year_mv", "doyMV"}) + @ScalarFunction(names = {"dayOfYearMV", "doyMV"}) public static int[] dayOfYear(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -858,12 +856,12 @@ public static int[] dayOfYear(long[] millis, String timezoneId) { /** * Returns the day of the month from the given epoch millis in UTC timezone. The value ranges from 1 to 31. 
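Editor's note: `quarter` above is pure integer arithmetic over the renamed `month` helper — months 1-3 map to quarter 1, 4-6 to quarter 2, and so on. A worked check:

```java
public class QuarterDemo {
  static int quarter(int month) {     // month in 1..12, as returned by month(millis)
    return (month - 1) / 3 + 1;
  }

  public static void main(String[] args) {
    System.out.println(quarter(1));   // 1 (January)
    System.out.println(quarter(6));   // 2 (June)
    System.out.println(quarter(12));  // 4 (December)
  }
}
```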
*/ - @ScalarFunction(names = {"day", "dayOfMonth", "day_of_month"}) + @ScalarFunction(names = {"dayOfMonth", "day"}) public static int dayOfMonth(long millis) { return new DateTime(millis, DateTimeZone.UTC).getDayOfMonth(); } - @ScalarFunction(names = {"dayMV", "dayOfMonthMV", "day_of_month_mv"}) + @ScalarFunction(names = {"dayOfMonthMV", "dayMV"}) public static int[] dayOfMonthMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -875,12 +873,12 @@ public static int[] dayOfMonthMV(long[] millis) { /** * Returns the day of the month from the given epoch millis and timezone id. The value ranges from 1 to 31. */ - @ScalarFunction(names = {"day", "dayOfMonth", "day_of_month"}) + @ScalarFunction(names = {"dayOfMonth", "day"}) public static int dayOfMonth(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getDayOfMonth(); } - @ScalarFunction(names = {"dayMV", "dayOfMonthMV", "day_of_month_mv"}) + @ScalarFunction(names = {"dayOfMonthMV", "dayMV"}) public static int[] dayOfMonthMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -893,12 +891,12 @@ public static int[] dayOfMonthMV(long[] millis, String timezoneId) { * Returns the day of the week from the given epoch millis in UTC timezone. The value ranges from 1 (Monday) to 7 * (Sunday). */ - @ScalarFunction(names = {"dayOfWeek", "day_of_week", "dow"}) + @ScalarFunction(names = {"dayOfWeek", "dow"}) public static int dayOfWeek(long millis) { return new DateTime(millis, DateTimeZone.UTC).getDayOfWeek(); } - @ScalarFunction(names = {"dayOfWeekMV", "day_of_week_mv", "dowMV"}) + @ScalarFunction(names = {"dayOfWeekMV", "dowMV"}) public static int[] dayOfWeekMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -911,12 +909,12 @@ public static int[] dayOfWeekMV(long[] millis) { * Returns the day of the week from the given epoch millis and timezone id. The value ranges from 1 (Monday) to 7 * (Sunday). 
*/ - @ScalarFunction(names = {"dayOfWeek", "day_of_week", "dow"}) + @ScalarFunction(names = {"dayOfWeek", "dow"}) public static int dayOfWeek(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getDayOfWeek(); } - @ScalarFunction(names = {"dayOfWeekMV", "day_of_week_mv", "dowMV"}) + @ScalarFunction(names = {"dayOfWeekMV", "dowMV"}) public static int[] dayOfWeekMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -933,7 +931,7 @@ public static int hour(long millis) { return new DateTime(millis, DateTimeZone.UTC).getHourOfDay(); } - @ScalarFunction(names = {"hourMV", "hour_mv"}) + @ScalarFunction public static int[] hourMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -950,7 +948,7 @@ public static int hour(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getHourOfDay(); } - @ScalarFunction(names = {"hourMV", "hour_mv"}) + @ScalarFunction public static int[] hourMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -967,7 +965,7 @@ public static int minute(long millis) { return new DateTime(millis, DateTimeZone.UTC).getMinuteOfHour(); } - @ScalarFunction(names = {"minuteMV", "minute_mv"}) + @ScalarFunction public static int[] minuteMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -984,7 +982,7 @@ public static int minute(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getMinuteOfHour(); } - @ScalarFunction(names = {"minuteMV", "minute_mv"}) + @ScalarFunction public static int[] minuteMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -1001,7 +999,7 @@ public static int second(long millis) { return new DateTime(millis, DateTimeZone.UTC).getSecondOfMinute(); } - @ScalarFunction(names = {"secondMV", "second_mv"}) + @ScalarFunction public static int[] secondMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -1018,7 +1016,7 @@ public static int second(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getSecondOfMinute(); } - @ScalarFunction(names = {"secondMV", "second_mv"}) + @ScalarFunction public static int[] secondMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -1035,7 +1033,7 @@ public static int millisecond(long millis) { return new DateTime(millis, DateTimeZone.UTC).getMillisOfSecond(); } - @ScalarFunction(names = {"millisecondMV", "millisecond_mv"}) + @ScalarFunction public static int[] millisecondMV(long[] millis) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -1052,7 +1050,7 @@ public static int millisecond(long millis, String timezoneId) { return new DateTime(millis, DateTimeZone.forID(timezoneId)).getMillisOfSecond(); } - @ScalarFunction(names = {"millisecondMV", "millisecond_mv"}) + @ScalarFunction public static int[] millisecondMV(long[] millis, String timezoneId) { int[] results = new int[millis.length]; for (int i = 0; i < millis.length; i++) { @@ -1068,13 +1066,13 @@ public static int[] millisecondMV(long[] millis, String timezoneId) { * @param timeValue value to truncate * @return truncated timeValue in 
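Editor's note: all of the `hour`/`minute`/`second`/`millisecond` variants above follow one shape — construct a Joda `DateTime` in either UTC or the requested zone and read one field. A standalone example showing why the timezone overloads matter:

```java
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;

public class FieldExtractionDemo {
  public static void main(String[] args) {
    long millis = 1_700_000_000_000L; // 2023-11-14T22:13:20Z
    DateTime utc = new DateTime(millis, DateTimeZone.UTC);
    DateTime tokyo = new DateTime(millis, DateTimeZone.forID("Asia/Tokyo"));
    System.out.println(utc.getHourOfDay());   // 22
    System.out.println(tokyo.getHourOfDay()); // 7 (already the next morning in UTC+9)
    System.out.println(utc.getDayOfWeek());   // 2 (Tuesday; 1=Monday .. 7=Sunday)
  }
}
```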
TimeUnit.MILLISECONDS */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) + @ScalarFunction public static long dateTrunc(String unit, long timeValue) { return dateTrunc(unit, timeValue, TimeUnit.MILLISECONDS.name(), ISOChronology.getInstanceUTC(), TimeUnit.MILLISECONDS.name()); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) + @ScalarFunction public static long[] dateTruncMV(String unit, long[] timeValue) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1091,12 +1089,12 @@ public static long[] dateTruncMV(String unit, long[] timeValue) { * @param inputTimeUnit TimeUnit of value, expressed in Java's joda TimeUnit * @return truncated timeValue in same TimeUnit as the input */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) + @ScalarFunction public static long dateTrunc(String unit, long timeValue, String inputTimeUnit) { return dateTrunc(unit, timeValue, inputTimeUnit, ISOChronology.getInstanceUTC(), inputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) + @ScalarFunction public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1114,13 +1112,13 @@ public static long[] dateTruncMV(String unit, long[] timeValue, String inputTime * @param timeZone timezone of the input * @return truncated timeValue in same TimeUnit as the input */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) + @ScalarFunction public static long dateTrunc(String unit, long timeValue, String inputTimeUnit, String timeZone) { return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), inputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) + @ScalarFunction public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit, String timeZone) { long[] results = new long[timeValue.length]; for (int i = 0; i < timeValue.length; i++) { @@ -1140,14 +1138,14 @@ public static long[] dateTruncMV(String unit, long[] timeValue, String inputTime * @return truncated timeValue * */ - @ScalarFunction(names = {"dateTrunc", "date_trunc"}) + @ScalarFunction public static long dateTrunc(String unit, long timeValue, String inputTimeUnit, String timeZone, String outputTimeUnit) { return dateTrunc(unit, timeValue, inputTimeUnit, DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(timeZone)), outputTimeUnit); } - @ScalarFunction(names = {"dateTruncMV", "date_trunc_mv"}) + @ScalarFunction public static long[] dateTruncMV(String unit, long[] timeValue, String inputTimeUnit, String timeZone, String outputTimeUnit) { long[] results = new long[timeValue.length]; @@ -1173,7 +1171,7 @@ private static long dateTrunc(String unit, long timeValue, String inputTimeUnit, * @param originTimestamp The origin timestamp from which binning starts. * @return A java.sql.Timestamp aligned to the nearest bin. 
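Editor's note: the `dateTrunc` overloads above resolve the unit name to a chronology field via `DateTimeUtils.getTimestampField` and truncate with it; the Joda primitive underneath is presumably `DateTimeField.roundFloor`, shown here directly (a sketch of the delegation, not Pinot's actual code path):

```java
import org.joda.time.chrono.ISOChronology;

public class DateTruncDemo {
  public static void main(String[] args) {
    long millis = 1_700_000_000_000L; // 2023-11-14T22:13:20Z
    // Truncate to the start of the hour, as dateTrunc('hour', millis) would in UTC.
    long hourFloor = ISOChronology.getInstanceUTC().hourOfDay().roundFloor(millis);
    System.out.println(hourFloor);    // 1699999200000 = 2023-11-14T22:00:00Z
  }
}
```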
*/ - @ScalarFunction(names = {"dateBin", "date_bin"}) + @ScalarFunction public static Timestamp dateBin(String binWidthStr, Timestamp sourceTimestamp, Timestamp originTimestamp) { long originMillis = originTimestamp.getTime(); long sourceMillis = sourceTimestamp.getTime(); @@ -1210,14 +1208,14 @@ public static long dateBin(String binWidthStr, long sourceMillisEpoch, long orig * @param timestamp * @return */ - @ScalarFunction(names = {"timestampAdd", "timestamp_add", "dateAdd", "date_add"}) + @ScalarFunction(names = {"timestampAdd", "dateAdd"}) public static long timestampAdd(String unit, long interval, long timestamp) { ISOChronology chronology = ISOChronology.getInstanceUTC(); long millis = DateTimeUtils.getTimestampField(chronology, unit).add(timestamp, interval); return millis; } - @ScalarFunction(names = {"timestampAddMV", "timestamp_add_mv", "dateAddMV", "date_add_mv"}) + @ScalarFunction(names = {"timestampAddMV", "dateAddMV"}) public static long[] timestampAddMV(String unit, long interval, long[] timestamp) { long[] results = new long[timestamp.length]; for (int i = 0; i < timestamp.length; i++) { @@ -1234,13 +1232,13 @@ public static long[] timestampAddMV(String unit, long interval, long[] timestamp * @param timestamp2 * @return */ - @ScalarFunction(names = {"timestampDiff", "timestamp_diff", "dateDiff", "date_diff"}) + @ScalarFunction(names = {"timestampDiff", "dateDiff"}) public static long timestampDiff(String unit, long timestamp1, long timestamp2) { ISOChronology chronology = ISOChronology.getInstanceUTC(); return DateTimeUtils.getTimestampField(chronology, unit).getDifferenceAsLong(timestamp2, timestamp1); } - @ScalarFunction(names = {"timestampDiffMV", "timestamp_diff_mv", "dateDiffMV", "date_diff_mv"}) + @ScalarFunction(names = {"timestampDiffMV", "dateDiffMV"}) public static long[] timestampDiffMV(String unit, long[] timestamp1, long timestamp2) { long[] results = new long[timestamp1.length]; for (int i = 0; i < timestamp1.length; i++) { @@ -1249,9 +1247,7 @@ public static long[] timestampDiffMV(String unit, long[] timestamp1, long timest return results; } - @ScalarFunction(names = { - "timestampDiffMVReverse", "timestamp_diff_mv_reverse", "dateDiffMVReverse", "date_diff_mv_reverse" - }) + @ScalarFunction(names = {"timestampDiffMVReverse", "dateDiffMVReverse"}) public static long[] timestampDiffMVReverse(String unit, long timestamp1, long[] timestamp2) { long[] results = new long[timestamp2.length]; for (int i = 0; i < timestamp2.length; i++) { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/GeohashFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/GeohashFunctions.java index 1c784254ce31..429a60a236b4 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/GeohashFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/GeohashFunctions.java @@ -121,7 +121,7 @@ private static void refineInterval(double[] interval, int cd, int mask) { * @param precision * @return the geohash value as a string */ - @ScalarFunction(names = {"encodeGeoHash", "encode_geohash"}) + @ScalarFunction public static String encodeGeoHash(double latitude, double longitude, int precision) { return longHashToStringGeohash(encode(latitude, longitude, precision)); } @@ -131,7 +131,7 @@ public static String encodeGeoHash(double latitude, double longitude, int precis * @param geohash * @return the latitude and longitude as a double array */ - @ScalarFunction(names = {"decodeGeoHash", 
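Editor's note: the `dateBin` body is elided by the hunk above. Typical `date_bin` semantics (as in PostgreSQL) align the source timestamp to the widest bin boundary at or before it, with bins `binWidth` wide and anchored at the origin. A hypothetical sketch of that arithmetic, not the actual implementation:

```java
public class DateBinSketch {
  static long dateBin(long binWidthMillis, long sourceMillis, long originMillis) {
    long delta = sourceMillis - originMillis;
    long binned = (delta / binWidthMillis) * binWidthMillis;
    if (delta < 0 && delta % binWidthMillis != 0) {
      binned -= binWidthMillis; // floor, not truncate, for sources before the origin
    }
    return originMillis + binned;
  }

  public static void main(String[] args) {
    long fiveMin = 5 * 60_000L;
    // Aligns down to the enclosing 5-minute bin anchored at the epoch.
    System.out.println(dateBin(fiveMin, 1_700_000_130_000L, 0L)); // 1700000100000
  }
}
```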
"decode_geohash"}) + @ScalarFunction public static double[] decodeGeoHash(String geohash) { return decode(geohash); } @@ -141,7 +141,7 @@ public static double[] decodeGeoHash(String geohash) { * @param geohash * @return the latitude as a double */ - @ScalarFunction(names = {"decodeGeoHashLatitude", "decode_geohash_latitude", "decode_geohash_lat"}) + @ScalarFunction(names = {"decodeGeoHashLatitude", "decodeGeoHashLat"}) public static double decodeGeoHashLatitude(String geohash) { double[] latLon = decode(geohash); return latLon[0]; @@ -152,7 +152,7 @@ public static double decodeGeoHashLatitude(String geohash) { * @param geohash * @return the longitude as a double */ - @ScalarFunction(names = {"decodeGeoHashLongitude", "decode_geohash_longitude", "decode_geohash_lon"}) + @ScalarFunction(names = {"decodeGeoHashLongitude", "decodeGeoHashLon"}) public static double decodeGeoHashLongitude(String geohash) { double[] latLon = decode(geohash); return latLon[1]; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java index 659f9e717dd5..5effbe3e5448 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/JsonFunctions.java @@ -76,7 +76,7 @@ public static String toJsonMapStr(@Nullable Map map) /** * Convert object to Json String */ - @ScalarFunction(names = {"jsonFormat", "json_format"}) + @ScalarFunction public static String jsonFormat(Object object) throws JsonProcessingException { return JsonUtils.objectToString(object); @@ -85,7 +85,7 @@ public static String jsonFormat(Object object) /** * Extract object based on Json path */ - @ScalarFunction(names = {"jsonPath", "json_path"}) + @ScalarFunction public static Object jsonPath(Object object, String jsonPath) { if (object instanceof String) { return PARSE_CONTEXT.parse((String) object).read(jsonPath, NO_PREDICATES); @@ -96,7 +96,7 @@ public static Object jsonPath(Object object, String jsonPath) { /** * Extract object array based on Json path */ - @ScalarFunction(names = {"jsonPathArray", "json_path_array"}) + @ScalarFunction public static Object[] jsonPathArray(Object object, String jsonPath) { if (object instanceof String) { return convertObjectToArray(PARSE_CONTEXT.parse((String) object).read(jsonPath, NO_PREDICATES)); @@ -104,7 +104,7 @@ public static Object[] jsonPathArray(Object object, String jsonPath) { return convertObjectToArray(PARSE_CONTEXT.parse(object).read(jsonPath, NO_PREDICATES)); } - @ScalarFunction(nullableParameters = true, names = {"jsonPathArrayDefaultEmpty", "json_path_array_default_empty"}) + @ScalarFunction(nullableParameters = true) public static Object[] jsonPathArrayDefaultEmpty(@Nullable Object object, String jsonPath) { try { Object[] result = object == null ? 
null : jsonPathArray(object, jsonPath); @@ -129,7 +129,7 @@ private static Object[] convertObjectToArray(Object arrayObject) { * Extract from Json with path to String */ @Nullable - @ScalarFunction(names = {"jsonPathString", "json_path_string"}) + @ScalarFunction public static String jsonPathString(Object object, String jsonPath) throws JsonProcessingException { Object jsonValue = jsonPath(object, jsonPath); @@ -142,7 +142,7 @@ public static String jsonPathString(Object object, String jsonPath) /** * Extract from Json with path to String */ - @ScalarFunction(nullableParameters = true, names = {"jsonPathString", "json_path_string"}) + @ScalarFunction(nullableParameters = true) public static String jsonPathString(@Nullable Object object, String jsonPath, String defaultValue) { try { Object jsonValue = jsonPath(object, jsonPath); @@ -158,7 +158,7 @@ public static String jsonPathString(@Nullable Object object, String jsonPath, St /** * Extract from Json with path to Long */ - @ScalarFunction(names = {"jsonPathLong", "json_path_long"}) + @ScalarFunction public static long jsonPathLong(Object object, String jsonPath) { return jsonPathLong(object, jsonPath, Long.MIN_VALUE); } @@ -166,7 +166,7 @@ public static long jsonPathLong(Object object, String jsonPath) { /** * Extract from Json with path to Long */ - @ScalarFunction(nullableParameters = true, names = {"jsonPathLong", "json_path_long"}) + @ScalarFunction(nullableParameters = true) public static long jsonPathLong(@Nullable Object object, String jsonPath, long defaultValue) { try { Object jsonValue = jsonPath(object, jsonPath); @@ -185,7 +185,7 @@ public static long jsonPathLong(@Nullable Object object, String jsonPath, long d /** * Extract from Json with path to Double */ - @ScalarFunction(names = {"jsonPathDouble", "json_path_double"}) + @ScalarFunction public static double jsonPathDouble(Object object, String jsonPath) { return jsonPathDouble(object, jsonPath, Double.NaN); } @@ -193,7 +193,7 @@ public static double jsonPathDouble(Object object, String jsonPath) { /** * Extract from Json with path to Double */ - @ScalarFunction(nullableParameters = true, names = {"jsonPathDouble", "json_path_double"}) + @ScalarFunction(nullableParameters = true) public static double jsonPathDouble(@Nullable Object object, String jsonPath, double defaultValue) { try { Object jsonValue = jsonPath(object, jsonPath); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/LogicalFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/LogicalFunctions.java index 648c8439f509..af6bbd95490a 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/LogicalFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/LogicalFunctions.java @@ -18,13 +18,11 @@ */ package org.apache.pinot.common.function.scalar; -import org.apache.calcite.linq4j.function.Strict; import org.apache.pinot.spi.annotations.ScalarFunction; /** * Logical transformation on boolean values. Currently, only not is supported. 
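Editor's note: the `jsonPath*` family above parses the value (string or already-parsed object) and evaluates a JsonPath expression against it; the `PARSE_CONTEXT` in this file comes from the Jayway json-path library. Minimal standalone usage of that library (the JSON document is made up for illustration):

```java
import com.jayway.jsonpath.JsonPath;

public class JsonPathDemo {
  public static void main(String[] args) {
    String json = "{\"name\":\"pinot\",\"tags\":[\"olap\",\"realtime\"],\"stars\":5}";
    String name = JsonPath.read(json, "$.name");                 // "pinot"
    Integer stars = JsonPath.read(json, "$.stars");              // 5
    java.util.List<String> tags = JsonPath.read(json, "$.tags"); // [olap, realtime]
    System.out.println(name + " " + stars + " " + tags);
  }
}
```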
*/ -@Strict public class LogicalFunctions { private LogicalFunctions() { } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java index 17223cb03938..7b8fa82d11fc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/ObjectFunctions.java @@ -26,17 +26,17 @@ public class ObjectFunctions { private ObjectFunctions() { } - @ScalarFunction(nullableParameters = true, names = {"isNull", "is_null"}) + @ScalarFunction(nullableParameters = true) public static boolean isNull(@Nullable Object obj) { return obj == null; } - @ScalarFunction(nullableParameters = true, names = {"isNotNull", "is_not_null"}) + @ScalarFunction(nullableParameters = true) public static boolean isNotNull(@Nullable Object obj) { return !isNull(obj); } - @ScalarFunction(nullableParameters = true, names = {"isDistinctFrom", "is_distinct_from"}) + @ScalarFunction(nullableParameters = true) public static boolean isDistinctFrom(@Nullable Object obj1, @Nullable Object obj2) { if (obj1 == null && obj2 == null) { return false; @@ -47,7 +47,7 @@ public static boolean isDistinctFrom(@Nullable Object obj1, @Nullable Object obj return !obj1.equals(obj2); } - @ScalarFunction(nullableParameters = true, names = {"isNotDistinctFrom", "is_not_distinct_from"}) + @ScalarFunction(nullableParameters = true) public static boolean isNotDistinctFrom(@Nullable Object obj1, @Nullable Object obj2) { return !isDistinctFrom(obj1, obj2); } @@ -94,34 +94,34 @@ private static Object coalesceVar(Object... objects) { } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Object oe) { return caseWhenVar(c1, o1, oe); } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Object oe) { return caseWhenVar(c1, o1, c2, o2, oe); } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Object oe) { return caseWhenVar(c1, o1, c2, o2, c3, o3, oe); } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Object oe) { return caseWhenVar(c1, o1, c2, o2, c3, o3, c4, o4, oe); } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean 
c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Object oe) { @@ -129,7 +129,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Object oe) { @@ -137,7 +137,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -146,7 +146,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -155,7 +155,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -164,7 +164,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -174,7 +174,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable 
Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -184,7 +184,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -196,7 +196,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -208,7 +208,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, @@ -220,7 +220,7 @@ public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullab } @Nullable - @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen", "case_when"}) + @ScalarFunction(nullableParameters = true, names = {"case", "caseWhen"}) public static Object caseWhen(@Nullable Boolean c1, @Nullable Object o1, @Nullable Boolean c2, @Nullable Object o2, @Nullable Boolean c3, @Nullable Object o3, @Nullable Boolean c4, @Nullable Object o4, @Nullable Boolean c5, @Nullable Object o5, @Nullable Boolean c6, @Nullable Object o6, @Nullable Boolean c7, @Nullable Object o7, diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java index 21c086ffb71e..2c1f8617cb3f 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/StringFunctions.java @@ -137,40 +137,27 @@ public static String substring(String input, int beginIndex, int length) { } /** - * Join two input string with seperator in between - * @param input1 - * @param input2 - * @param seperator - * @return The two input strings joined by the seperator + * Joins two input strings with separator in between. 
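Editor's note: on the ObjectFunctions hunks above — `isDistinctFrom` implements SQL's null-safe `IS DISTINCT FROM`: two NULLs are *not* distinct, unlike `!=`, which yields NULL. Its truth table is exactly `!Objects.equals(a, b)`, standalone:

```java
import java.util.Objects;

public class DistinctFromDemo {
  // Same truth table as ObjectFunctions.isDistinctFrom above.
  static boolean isDistinctFrom(Object a, Object b) {
    return !Objects.equals(a, b);
  }

  public static void main(String[] args) {
    System.out.println(isDistinctFrom(null, null)); // false: two NULLs are not distinct
    System.out.println(isDistinctFrom(null, 1));    // true
    System.out.println(isDistinctFrom(1, 1));       // false
  }
}
```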
*/ - @ScalarFunction(names = "concat_ws") - public static String concatws(String seperator, String input1, String input2) { - return concat(input1, input2, seperator); + @ScalarFunction + public static String concatWS(String separator, String input1, String input2) { + return input1 + separator + input2; } /** - * Join two input string with seperator in between - * @param input1 - * @param input2 - * @param seperator - * @return The two input strings joined by the seperator + * Joins two input strings with separator in between. */ @ScalarFunction - public static String concat(String input1, String input2, String seperator) { - String result = input1; - result = result + seperator + input2; - return result; + public static String concat(String input1, String input2, String separator) { + return input1 + separator + input2; } /** - * Join two input string with no seperator in between - * @param input1 - * @param input2 - * @return The two input strings joined + * Joins two input strings with no separator in between. */ @ScalarFunction public static String concat(String input1, String input2) { - return concat(input1, input2, ""); + return input1 + input2; } /** @@ -265,7 +252,7 @@ public static String rightSubStr(String input, int length) { * @param regexp * @return the matched result. */ - @ScalarFunction(names = {"regexp_extract", "regexpExtract"}) + @ScalarFunction public static String regexpExtract(String value, String regexp) { return regexpExtract(value, regexp, 0, ""); } @@ -277,7 +264,7 @@ public static String regexpExtract(String value, String regexp) { * @param group * @return the matched result. */ - @ScalarFunction(names = {"regexp_extract", "regexpExtract"}) + @ScalarFunction public static String regexpExtract(String value, String regexp, int group) { return regexpExtract(value, regexp, group, ""); } @@ -290,7 +277,7 @@ public static String regexpExtract(String value, String regexp, int group) { * @param defaultValue the default value if no match found * @return the matched result */ - @ScalarFunction(names = {"regexp_extract", "regexpExtract"}) + @ScalarFunction public static String regexpExtract(String value, String regexp, int group, String defaultValue) { Pattern p = Pattern.compile(regexp); Matcher matcher = p.matcher(value); @@ -367,7 +354,7 @@ public static int strrpos(String input, String find, int instance) { * @param prefix substring to check if it is the prefix * @return true if string starts with prefix, false o.w. */ - @ScalarFunction(names = {"startsWith", "starts_with"}) + @ScalarFunction public static boolean startsWith(String input, String prefix) { return StringUtils.startsWith(input, prefix); } @@ -378,7 +365,7 @@ public static boolean startsWith(String input, String prefix) { * @param suffix substring to check if it is the prefix * @return true if string ends with prefix, false o.w. */ - @ScalarFunction(names = {"endsWith", "ends_with"}) + @ScalarFunction public static boolean endsWith(String input, String suffix) { return StringUtils.endsWith(input, suffix); } @@ -566,7 +553,7 @@ public static String normalize(String input, String form) { * @param delimiter * @return splits string on specified delimiter and returns an array. 
*/ - @ScalarFunction(names = {"split", "string_to_array"}) + @ScalarFunction(names = {"split", "stringToArray"}) public static String[] split(String input, String delimiter) { return StringUtils.splitByWholeSeparator(input, delimiter); } @@ -577,7 +564,7 @@ public static String[] split(String input, String delimiter) { * @param limit * @return splits string on specified delimiter limiting the number of results till the specified limit */ - @ScalarFunction(names = {"split", "string_to_array"}) + @ScalarFunction(names = {"split", "stringToArray"}) public static String[] split(String input, String delimiter, int limit) { return StringUtils.splitByWholeSeparator(input, delimiter, limit); } @@ -603,7 +590,7 @@ public static String[] prefixes(String input, int maxlength) { * @param prefix the prefix to be prepended to prefix strings generated. e.g. '^' for regex matching * @return generate an array of prefix matchers of the string that are shorter than the specified length. */ - @ScalarFunction(nullableParameters = true, names = {"prefixesWithPrefix", "prefixes_with_prefix"}) + @ScalarFunction(nullableParameters = true) public static String[] prefixesWithPrefix(String input, int maxlength, @Nullable String prefix) { if (prefix == null) { return prefixes(input, maxlength); @@ -637,7 +624,7 @@ public static String[] suffixes(String input, int maxlength) { * @param suffix the suffix string to be appended for suffix strings generated. e.g. '$' for regex matching. * @return generate an array of suffix matchers of the string that are shorter than the specified length. */ - @ScalarFunction(nullableParameters = true, names = {"suffixesWithSuffix", "suffixes_with_suffix"}) + @ScalarFunction(nullableParameters = true) public static String[] suffixesWithSuffix(String input, int maxlength, @Nullable String suffix) { if (suffix == null) { return suffixes(input, maxlength); @@ -694,7 +681,7 @@ public static String[] uniqueNgrams(String input, int minGram, int maxGram) { * @param index we allow negative value for index which indicates the index from the end. * @return splits string on specified delimiter and returns String at specified index from the split. */ - @ScalarFunction(names = {"splitPart", "split_part"}) + @ScalarFunction public static String splitPart(String input, String delimiter, int index) { String[] splitString = StringUtils.splitByWholeSeparator(input, delimiter); if (index >= 0 && index < splitString.length) { @@ -853,7 +840,7 @@ public static byte[] fromBase64(String input) { * i -> Case insensitive * @return replaced input string */ - @ScalarFunction(names = {"regexpReplace", "regexp_replace"}) + @ScalarFunction public static String regexpReplace(String inputStr, String matchStr, String replaceStr, int matchStartPos, int occurence, String flag) { Integer patternFlag; @@ -907,7 +894,7 @@ public static String regexpReplace(String inputStr, String matchStr, String repl * @param replaceStr Regexp or string to replace if matchStr is found * @return replaced input string */ - @ScalarFunction(names = {"regexpReplace", "regexp_replace"}) + @ScalarFunction public static String regexpReplace(String inputStr, String matchStr, String replaceStr) { return regexpReplace(inputStr, matchStr, replaceStr, 0, -1, ""); } @@ -922,7 +909,7 @@ public static String regexpReplace(String inputStr, String matchStr, String repl * @param matchStartPos Index of inputStr from where matching should start. Default is 0. 
* @return replaced input string */ - @ScalarFunction(names = {"regexpReplace", "regexp_replace"}) + @ScalarFunction public static String regexpReplace(String inputStr, String matchStr, String replaceStr, int matchStartPos) { return regexpReplace(inputStr, matchStr, replaceStr, matchStartPos, -1, ""); } @@ -938,13 +925,13 @@ public static String regexpReplace(String inputStr, String matchStr, String repl * at 0. Default is -1 * @return replaced input string */ - @ScalarFunction(names = {"regexpReplace", "regexp_replace"}) + @ScalarFunction public static String regexpReplace(String inputStr, String matchStr, String replaceStr, int matchStartPos, int occurence) { return regexpReplace(inputStr, matchStr, replaceStr, matchStartPos, occurence, ""); } - @ScalarFunction(names = {"regexpLike", "regexp_like"}) + @ScalarFunction public static boolean regexpLike(String inputStr, String regexPatternStr) { Pattern pattern = Pattern.compile(regexPatternStr, Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE); return pattern.matcher(inputStr).find(); @@ -965,7 +952,7 @@ public static boolean like(String inputStr, String likePatternStr) { * @return true in case of valid json parsing else false * */ - @ScalarFunction(names = {"isJson", "is_json"}) + @ScalarFunction public static boolean isJson(String inputStr) { try { JsonUtils.stringToJsonNode(inputStr); diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/TrigonometricFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/TrigonometricFunctions.java index c3b327965f8e..c6214baadc1d 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/TrigonometricFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/TrigonometricFunctions.java @@ -18,10 +18,8 @@ */ package org.apache.pinot.common.function.scalar; -import org.apache.calcite.linq4j.function.Strict; import org.apache.pinot.spi.annotations.ScalarFunction; -@Strict public class TrigonometricFunctions { private TrigonometricFunctions() { } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java index 2de34d15c917..ba743cc0c4ec 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/VectorFunctions.java @@ -39,7 +39,7 @@ private VectorFunctions() { * @param vector2 vector2 * @return cosine distance */ - @ScalarFunction(names = {"cosinedistance", "cosine_distance"}) + @ScalarFunction public static double cosineDistance(float[] vector1, float[] vector2) { return cosineDistance(vector1, vector2, Double.NaN); } @@ -51,7 +51,7 @@ public static double cosineDistance(float[] vector1, float[] vector2) { * @param defaultValue default value when either vector has a norm of 0 * @return cosine distance */ - @ScalarFunction(names = {"cosinedistance", "cosine_distance"}) + @ScalarFunction public static double cosineDistance(float[] vector1, float[] vector2, double defaultValue) { validateVectors(vector1, vector2); double dotProduct = 0.0; @@ -74,7 +74,7 @@ public static double cosineDistance(float[] vector1, float[] vector2, double def * @param vector2 vector2 * @return inner product */ - @ScalarFunction(names = {"innerproduct", "inner_product"}) + @ScalarFunction public static double innerProduct(float[] vector1, float[] vector2) { validateVectors(vector1, 
vector2); double dotProduct = 0.0; @@ -90,7 +90,7 @@ public static double innerProduct(float[] vector1, float[] vector2) { * @param vector2 vector2 * @return L2 distance */ - @ScalarFunction(names = {"l2distance", "l2_distance"}) + @ScalarFunction public static double l2Distance(float[] vector1, float[] vector2) { validateVectors(vector1, vector2); double distance = 0.0; @@ -106,7 +106,7 @@ public static double l2Distance(float[] vector1, float[] vector2) { * @param vector2 vector2 * @return L1 distance */ - @ScalarFunction(names = {"l1distance", "l1_distance"}) + @ScalarFunction public static double l1Distance(float[] vector1, float[] vector2) { validateVectors(vector1, vector2); double distance = 0.0; @@ -122,7 +122,7 @@ public static double l1Distance(float[] vector1, float[] vector2) { * @param vector2 vector2 * @return Euclidean distance */ - @ScalarFunction(names = {"euclideandistance", "euclidean_distance"}) + @ScalarFunction public static double euclideanDistance(float[] vector1, float[] vector2) { validateVectors(vector1, vector2); double distance = 0; @@ -138,7 +138,7 @@ public static double euclideanDistance(float[] vector1, float[] vector2) { * @param vector2 vector2 * @return dot product */ - @ScalarFunction(names = {"dotproduct", "dot_product"}) + @ScalarFunction public static double dotProduct(float[] vector1, float[] vector2) { validateVectors(vector1, vector2); double dotProduct = 0.0; @@ -153,7 +153,7 @@ public static double dotProduct(float[] vector1, float[] vector2) { * @param vector input vector * @return number of dimensions */ - @ScalarFunction(names = {"vectordims", "vector_dims"}) + @ScalarFunction public static int vectorDims(float[] vector) { validateVector(vector); return vector.length; @@ -164,7 +164,7 @@ public static int vectorDims(float[] vector) { * @param vector input vector * @return norm */ - @ScalarFunction(names = {"vectornorm", "vector_norm"}) + @ScalarFunction public static double vectorNorm(float[] vector) { validateVector(vector); double norm = 0.0; diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java new file mode 100644 index 000000000000..3c497e424ade --- /dev/null +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlAggFunction.java @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.common.function.sql; + +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; +import org.apache.calcite.util.Optionality; + + +/** + * Pinot custom SqlAggFunction to be registered into SqlOperatorTable. + */ +public class PinotSqlAggFunction extends SqlAggFunction { + + public PinotSqlAggFunction(String name, SqlKind kind, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker, SqlFunctionCategory functionCategory) { + super(name.toUpperCase(), null, kind, returnTypeInference, null, operandTypeChecker, functionCategory, false, false, + Optionality.FORBIDDEN); + } + + public PinotSqlAggFunction(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker) { + this(name, SqlKind.OTHER_FUNCTION, returnTypeInference, operandTypeChecker, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlTransformFunction.java b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlFunction.java similarity index 57% rename from pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlTransformFunction.java rename to pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlFunction.java index ffd5b49667eb..628d739bef7b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlTransformFunction.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/sql/PinotSqlFunction.java @@ -16,32 +16,35 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.pinot.calcite.sql; +package org.apache.pinot.common.function.sql; import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlOperandTypeInference; import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.checkerframework.checker.nullness.qual.Nullable; /** - * Pinot SqlAggFunction class to register the Pinot aggregation functions with the Calcite operator table. + * Pinot custom SqlFunction to be registered into SqlOperatorTable. 
*/ -public class PinotSqlTransformFunction extends SqlFunction { - private final boolean _isDeterministic; +public class PinotSqlFunction extends SqlFunction { + private final boolean _deterministic; - public PinotSqlTransformFunction(String name, SqlKind kind, @Nullable SqlReturnTypeInference returnTypeInference, - @Nullable SqlOperandTypeInference operandTypeInference, @Nullable SqlOperandTypeChecker operandTypeChecker, - SqlFunctionCategory category, boolean isDeterministic) { - super(name, kind, returnTypeInference, operandTypeInference, operandTypeChecker, category); - _isDeterministic = isDeterministic; + public PinotSqlFunction(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker, boolean deterministic) { + super(name.toUpperCase(), SqlKind.OTHER_FUNCTION, returnTypeInference, null, operandTypeChecker, + SqlFunctionCategory.USER_DEFINED_FUNCTION); + _deterministic = deterministic; + } + + public PinotSqlFunction(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker) { + this(name, returnTypeInference, operandTypeChecker, true); } @Override public boolean isDeterministic() { - return _isDeterministic; + return _deterministic; } } diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java index f1a3b6aaf0f1..6a47fa827ef7 100644 --- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java +++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/CompileTimeFunctionsInvoker.java @@ -68,9 +68,9 @@ protected static Expression invokeCompileTimeFunctionExpression(@Nullable Expres } operands.set(i, operand); } - String functionName = function.getOperator(); if (compilable) { - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numOperands); + String canonicalName = FunctionRegistry.canonicalize(function.getOperator()); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, numOperands); if (functionInfo != null) { Object[] arguments = new Object[numOperands]; for (int i = 0; i < numOperands; i++) { diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java deleted file mode 100644 index 600016fd8f98..000000000000 --- a/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionDefinitionRegistryTest.java +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
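For orientation, a hedged sketch of how the two operator classes introduced above might be instantiated when building an operator table. The return-type and operand-type inferences chosen here are assumptions for illustration, not taken from this change:

```java
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.pinot.common.function.sql.PinotSqlAggFunction;
import org.apache.pinot.common.function.sql.PinotSqlFunction;

final class OperatorTableSketch {
  private OperatorTableSketch() {
  }

  // Deterministic scalar function: the constructor upper-cases the name and
  // defaults to SqlKind.OTHER_FUNCTION / USER_DEFINED_FUNCTION.
  static final PinotSqlFunction REVERSE =
      new PinotSqlFunction("reverse", ReturnTypes.VARCHAR_2000, OperandTypes.STRING);

  // Aggregation function registered through the parallel SqlAggFunction
  // subclass (hypothetical signature choices).
  static final PinotSqlAggFunction DISTINCT_COUNT_THETA_SKETCH =
      new PinotSqlAggFunction("distinctCountThetaSketch", ReturnTypes.BIGINT, OperandTypes.ANY);
}
```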
- */ -package org.apache.pinot.common.function; - -import com.google.common.collect.ImmutableList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.regex.Pattern; -import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.apache.pinot.spi.annotations.ScalarFunction; -import org.apache.pinot.sql.FilterKind; -import org.testng.annotations.Test; - -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertNotNull; -import static org.testng.Assert.assertNull; -import static org.testng.Assert.assertTrue; - - -public class FunctionDefinitionRegistryTest { - private static final List IGNORED_FUNCTION_NAME_PATTERNS = ImmutableList.of( - // Geo functions are defined in pinot-core - Pattern.compile("st_.*") - ); - private static final List IGNORED_FUNCTION_NAMES = ImmutableList.of( - // Geo functions are defined in pinot-core - "geotoh3", - // ArrayToMV, ArrayValueConstructor and GenerateArray are placeholder functions without implementation - "arraytomv", "arrayvalueconstructor", "generatearray", - // Scalar function - "scalar", - // Functions without scalar function counterpart as of now - "arraylength", "arrayaverage", "arraymin", "arraymax", "arraysum", "arraysumint", "arraysumlong", - "clpdecode", "clpencodedvarsmatch", "groovy", "inidset", - "jsonextractscalar", "jsonextractindex", "jsonextractkey", - "lookup", "mapvalue", "timeconvert", "valuein", "datetimeconvertwindowhop", - // functions not needed for register b/c they are in std sql table or they will not be composed directly. - "in", "not_in", "and", "or", "range", "extract", "is_true", "is_not_true", "is_false", "is_not_false" - ); - - @Test - public void testIsAggFunc() { - assertTrue(AggregationFunctionType.isAggregationFunction("count")); - assertTrue(AggregationFunctionType.isAggregationFunction("percentileRawEstMV")); - assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILERAWESTMV")); - assertTrue(AggregationFunctionType.isAggregationFunction("percentilerawestmv")); - assertTrue(AggregationFunctionType.isAggregationFunction("percentile_raw_est_mv")); - assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILE_RAW_EST_MV")); - assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILEEST90")); - assertTrue(AggregationFunctionType.isAggregationFunction("percentileest90")); - assertFalse(AggregationFunctionType.isAggregationFunction("toEpochSeconds")); - } - - @Test - public void testCalciteFunctionMapAllRegistered() { - Set registeredCalciteFunctionNameIgnoreCase = new HashSet<>(); - for (String funcNames : FunctionRegistry.getRegisteredCalciteFunctionNames()) { - registeredCalciteFunctionNameIgnoreCase.add(funcNames.toLowerCase()); - } - for (TransformFunctionType enumType : TransformFunctionType.values()) { - if (!isIgnored(enumType.getName().toLowerCase())) { - for (String funcName : enumType.getAlternativeNames()) { - assertTrue(registeredCalciteFunctionNameIgnoreCase.contains(funcName.toLowerCase()), - "Unable to find transform function signature for: " + funcName); - } - } - } - for (FilterKind enumType : FilterKind.values()) { - if (!isIgnored(enumType.name().toLowerCase())) { - assertTrue(registeredCalciteFunctionNameIgnoreCase.contains(enumType.name().toLowerCase()), - "Unable to find filter function signature for: " + enumType.name()); - } - } - } - - private boolean isIgnored(String funcName) { - if (IGNORED_FUNCTION_NAMES.contains(funcName)) { - return true; - } - for (Pattern signature : 
IGNORED_FUNCTION_NAME_PATTERNS) { - if (signature.matcher(funcName).find()) { - return true; - } - } - return false; - } - - @ScalarFunction(names = {"testFunc1", "testFunc2"}) - public static String testScalarFunction(long randomArg1, String randomArg2) { - return null; - } - - @Test - public void testScalarFunctionNames() { - assertNotNull(FunctionRegistry.getFunctionInfo("testFunc1", 2)); - assertNotNull(FunctionRegistry.getFunctionInfo("testFunc2", 2)); - assertNull(FunctionRegistry.getFunctionInfo("testScalarFunction", 2)); - assertNull(FunctionRegistry.getFunctionInfo("testFunc1", 1)); - assertNull(FunctionRegistry.getFunctionInfo("testFunc2", 1)); - } -} diff --git a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotQueryResource.java b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotQueryResource.java index d86dc7faed75..55ea95212169 100644 --- a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotQueryResource.java +++ b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotQueryResource.java @@ -53,7 +53,6 @@ import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.helix.model.InstanceConfig; -import org.apache.pinot.calcite.jdbc.CalciteSchemaBuilder; import org.apache.pinot.common.Utils; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.response.ProcessingException; @@ -70,10 +69,7 @@ import org.apache.pinot.core.auth.ManualAuthorization; import org.apache.pinot.core.query.executor.sql.SqlQueryExecutor; import org.apache.pinot.query.QueryEnvironment; -import org.apache.pinot.query.catalog.PinotCatalog; import org.apache.pinot.query.parser.utils.ParserUtils; -import org.apache.pinot.query.type.TypeFactory; -import org.apache.pinot.query.type.TypeSystem; import org.apache.pinot.spi.config.table.TableConfig; import org.apache.pinot.spi.exception.DatabaseConflictException; import org.apache.pinot.spi.utils.CommonConstants; @@ -114,7 +110,7 @@ public String handlePostSql(String requestJsonStr, @Context HttpHeaders httpHead JsonNode requestJson = JsonUtils.stringToJsonNode(requestJsonStr); if (!requestJson.has("sql")) { return constructQueryExceptionResponse(QueryException.getException(QueryException.JSON_PARSING_ERROR, - "JSON Payload is missing the query string field 'sql'")); + "JSON Payload is missing the query string field 'sql'")); } String sqlQuery = requestJson.get("sql").asText(); String traceEnabled = "false"; @@ -200,7 +196,8 @@ private String executeSqlQuery(@Context HttpHeaders httpHeaders, String sqlQuery } private String getMultiStageQueryResponse(String query, String queryOptions, HttpHeaders httpHeaders, - String endpointUrl, String traceEnabled) throws ProcessingException { + String endpointUrl, String traceEnabled) + throws ProcessingException { // Validate data access // we don't have a cross table access control rule so only ADMIN can make request to multi-stage engine. 
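The hunk below is representative of every call-site change in this PR: the new three-argument `QueryEnvironment` constructor hides the Calcite wiring (TypeFactory, TypeSystem, root schema built from PinotCatalog) that callers previously assembled by hand, and `ParserUtils.canCompileWithMultiStageEngine` wraps the same setup for the compile-check path. A hedged sketch of the simplified usage, assuming the third constructor argument is the WorkerManager and may be null on parse-only paths:

```java
import java.util.List;
import org.apache.pinot.common.config.provider.TableCache;
import org.apache.pinot.query.QueryEnvironment;

final class TableResolutionSketch {
  private TableResolutionSketch() {
  }

  // Resolve the tables referenced by a query without dispatching it; the
  // null WorkerManager is acceptable because nothing is executed here.
  static List<String> tablesForQuery(String database, TableCache tableCache, String query) {
    QueryEnvironment queryEnvironment = new QueryEnvironment(database, tableCache, null);
    return queryEnvironment.getTableNamesForQuery(query);
  }
}
```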
@@ -214,9 +211,8 @@ private String getMultiStageQueryResponse(String query, String queryOptions, Htt queryOptionsMap.putAll(RequestUtils.getOptionsFromString(queryOptions)); } String database = DatabaseUtils.extractDatabaseFromQueryRequest(queryOptionsMap, httpHeaders); - QueryEnvironment queryEnvironment = new QueryEnvironment(new TypeFactory(new TypeSystem()), - CalciteSchemaBuilder.asRootSchema(new PinotCatalog(database, _pinotHelixResourceManager.getTableCache()), - database), null, null); + QueryEnvironment queryEnvironment = + new QueryEnvironment(database, _pinotHelixResourceManager.getTableCache(), null); List tableNames; try { tableNames = queryEnvironment.getTableNamesForQuery(query); @@ -235,17 +231,18 @@ private String getMultiStageQueryResponse(String query, String queryOptions, Htt // find the unions of all the broker tenant tags of the queried tables. Set brokerTenantsUnion = getBrokerTenantsUnion(tableConfigList); if (brokerTenantsUnion.isEmpty()) { - return QueryException.getException(QueryException.BROKER_REQUEST_SEND_ERROR, new Exception( - String.format("Unable to dispatch multistage query for tables: [%s]", tableNames))).toString(); + return QueryException.getException(QueryException.BROKER_REQUEST_SEND_ERROR, + new Exception(String.format("Unable to dispatch multistage query for tables: [%s]", tableNames))) + .toString(); } instanceIds = findCommonBrokerInstances(brokerTenantsUnion); if (instanceIds.isEmpty()) { // No common broker found for table tenants - LOGGER.error("Unable to find a common broker instance for table tenants. Tables: {}, Tenants: {}", - tableNames, brokerTenantsUnion); - throw QueryException.getException(QueryException.BROKER_RESOURCE_MISSING_ERROR, - new Exception("Unable to find a common broker instance for table tenants. Tables: " - + tableNames + ", Tenants: " + brokerTenantsUnion)); + LOGGER.error("Unable to find a common broker instance for table tenants. Tables: {}, Tenants: {}", tableNames, + brokerTenantsUnion); + throw QueryException.getException(QueryException.BROKER_RESOURCE_MISSING_ERROR, new Exception( + "Unable to find a common broker instance for table tenants. Tables: " + tableNames + ", Tenants: " + + brokerTenantsUnion)); } } else { // TODO fail these queries going forward. Added this logic to take care of tautologies like BETWEEN 0 and -1. @@ -257,7 +254,8 @@ private String getMultiStageQueryResponse(String query, String queryOptions, Htt } private String getQueryResponse(String query, @Nullable SqlNode sqlNode, String traceEnabled, String queryOptions, - HttpHeaders httpHeaders) throws ProcessingException { + HttpHeaders httpHeaders) + throws ProcessingException { // Get resource table name. 
String tableName; Map queryOptionsMap = RequestUtils.parseQuery(query).getOptions(); @@ -279,8 +277,7 @@ private String getQueryResponse(String query, @Nullable SqlNode sqlNode, String LOGGER.error("Caught exception while compiling query: {}", query, e); // Check if the query is a v2 supported query - if (ParserUtils.canCompileQueryUsingV2Engine(query, CalciteSchemaBuilder.asRootSchema( - new PinotCatalog(database, _pinotHelixResourceManager.getTableCache()), database))) { + if (ParserUtils.canCompileWithMultiStageEngine(query, database, _pinotHelixResourceManager.getTableCache())) { return QueryException.getException(QueryException.SQL_PARSING_ERROR, new Exception( "It seems that the query is only supported by the multi-stage query engine, please retry the query using " + "the multi-stage query engine " @@ -322,7 +319,8 @@ private List getListTableConfigs(List tableNames) { return allTableConfigList; } - private String selectRandomInstanceId(List instanceIds) throws ProcessingException { + private String selectRandomInstanceId(List instanceIds) + throws ProcessingException { if (instanceIds.isEmpty()) { throw QueryException.getException(QueryException.BROKER_RESOURCE_MISSING_ERROR, "No broker found for query"); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/function/scalar/SketchFunctions.java b/pinot-core/src/main/java/org/apache/pinot/core/function/scalar/SketchFunctions.java index 90e313edb268..a3b4b65f021e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/function/scalar/SketchFunctions.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/function/scalar/SketchFunctions.java @@ -208,52 +208,52 @@ public static byte[] toIntegerSumTupleSketch(@Nullable Object key, Integer value return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(is.compact()); } - @ScalarFunction(names = {"getThetaSketchEstimate", "get_theta_sketch_estimate"}) + @ScalarFunction public static long getThetaSketchEstimate(Object sketchObject) { return Math.round(asThetaSketch(sketchObject).getEstimate()); } - @ScalarFunction(names = {"thetaSketchUnion", "theta_sketch_union"}) + @ScalarFunction public static Sketch thetaSketchUnion(Object o1, Object o2) { return thetaSketchUnionVar(o1, o2); } - @ScalarFunction(names = {"thetaSketchUnion", "theta_sketch_union"}) + @ScalarFunction public static Sketch thetaSketchUnion(Object o1, Object o2, Object o3) { return thetaSketchUnionVar(o1, o2, o3); } - @ScalarFunction(names = {"thetaSketchUnion", "theta_sketch_union"}) + @ScalarFunction public static Sketch thetaSketchUnion(Object o1, Object o2, Object o3, Object o4) { return thetaSketchUnionVar(o1, o2, o3, o4); } - @ScalarFunction(names = {"thetaSketchUnion", "theta_sketch_union"}) + @ScalarFunction public static Sketch thetaSketchUnion(Object o1, Object o2, Object o3, Object o4, Object o5) { return thetaSketchUnionVar(o1, o2, o3, o4, o5); } - @ScalarFunction(names = {"thetaSketchIntersect", "theta_sketch_intersect"}) + @ScalarFunction public static Sketch thetaSketchIntersect(Object o1, Object o2) { return thetaSketchIntersectVar(o1, o2); } - @ScalarFunction(names = {"thetaSketchIntersect", "theta_sketch_intersect"}) + @ScalarFunction public static Sketch thetaSketchIntersect(Object o1, Object o2, Object o3) { return thetaSketchIntersectVar(o1, o2, o3); } - @ScalarFunction(names = {"thetaSketchIntersect", "theta_sketch_intersect"}) + @ScalarFunction public static Sketch thetaSketchIntersect(Object o1, Object o2, Object o3, Object o4) { return thetaSketchIntersectVar(o1, o2, o3, o4); } - 
@ScalarFunction(names = {"thetaSketchIntersect", "theta_sketch_intersect"}) + @ScalarFunction public static Sketch thetaSketchIntersect(Object o1, Object o2, Object o3, Object o4, Object o5) { return thetaSketchIntersectVar(o1, o2, o3, o4, o5); } - @ScalarFunction(names = {"thetaSketchDiff", "theta_sketch_diff"}) + @ScalarFunction public static Sketch thetaSketchDiff(Object sketchObjectA, Object sketchObjectB) { AnotB diff = SET_OPERATION_BUILDER.buildANotB(); diff.setA(asThetaSketch(sketchObjectA)); @@ -261,7 +261,7 @@ public static Sketch thetaSketchDiff(Object sketchObjectA, Object sketchObjectB) return diff.getResult(false, null, false); } - @ScalarFunction(names = {"thetaSketchToString", "theta_sketch_to_string"}) + @ScalarFunction public static String thetaSketchToString(Object sketchObject) { return asThetaSketch(sketchObject).toString(); } @@ -296,32 +296,32 @@ private static Sketch asThetaSketch(Object sketchObj) { } } - @ScalarFunction(names = {"intSumTupleSketchUnion", "int_sum_tuple_sketch_union"}) + @ScalarFunction public static byte[] intSumTupleSketchUnion(Object o1, Object o2) { return intSumTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2); } - @ScalarFunction(names = {"intSumTupleSketchUnion", "int_sum_tuple_sketch_union"}) + @ScalarFunction public static byte[] intSumTupleSketchUnion(int nomEntries, Object o1, Object o2) { return intTupleSketchUnionVar(IntegerSummary.Mode.Sum, nomEntries, o1, o2); } - @ScalarFunction(names = {"intMinTupleSketchUnion", "int_min_tuple_sketch_union"}) + @ScalarFunction public static byte[] intMinTupleSketchUnion(Object o1, Object o2) { return intMinTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2); } - @ScalarFunction(names = {"intMinTupleSketchUnion", "int_min_tuple_sketch_union"}) + @ScalarFunction public static byte[] intMinTupleSketchUnion(int nomEntries, Object o1, Object o2) { return intTupleSketchUnionVar(IntegerSummary.Mode.Min, nomEntries, o1, o2); } - @ScalarFunction(names = {"intMaxTupleSketchUnion", "int_max_tuple_sketch_union"}) + @ScalarFunction public static byte[] intMaxTupleSketchUnion(Object o1, Object o2) { return intMaxTupleSketchUnion((int) Math.pow(2, CommonConstants.Helix.DEFAULT_TUPLE_SKETCH_LGK), o1, o2); } - @ScalarFunction(names = {"intMaxTupleSketchUnion", "int_max_tuple_sketch_union"}) + @ScalarFunction public static byte[] intMaxTupleSketchUnion(int nomEntries, Object o1, Object o2) { return intTupleSketchUnionVar(IntegerSummary.Mode.Max, nomEntries, o1, o2); } @@ -335,17 +335,17 @@ private static byte[] intTupleSketchUnionVar(IntegerSummary.Mode mode, int nomEn return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(union.getResult().compact()); } - @ScalarFunction(names = {"intSumTupleSketchIntersect", "int_sum_tuple_sketch_intersect"}) + @ScalarFunction public static byte[] intSumTupleSketchIntersect(Object o1, Object o2) { return intTupleSketchIntersectVar(IntegerSummary.Mode.Sum, o1, o2); } - @ScalarFunction(names = {"intMinTupleSketchIntersect", "int_min_tuple_sketch_intersect"}) + @ScalarFunction public static byte[] intMinTupleSketchIntersect(Object o1, Object o2) { return intTupleSketchIntersectVar(IntegerSummary.Mode.Min, o1, o2); } - @ScalarFunction(names = {"intMaxTupleSketchIntersect", "int_max_tuple_sketch_intersect"}) + @ScalarFunction public static byte[] intMaxTupleSketchIntersect(Object o1, Object o2) { return intTupleSketchIntersectVar(IntegerSummary.Mode.Max, o1, o2); } @@ -359,7 +359,7 @@ private 
static byte[] intTupleSketchIntersectVar(IntegerSummary.Mode mode, Objec return ObjectSerDeUtils.DATA_SKETCH_INT_TUPLE_SER_DE.serialize(intersection.getResult().compact()); } - @ScalarFunction(names = {"intTupleSketchDiff", "int_tuple_sketch_diff"}) + @ScalarFunction public static byte[] intSumTupleSketchDiff(Object o1, Object o2) { org.apache.datasketches.tuple.AnotB diff = new org.apache.datasketches.tuple.AnotB<>(); diff.setA(asIntegerSketch(o1)); @@ -381,7 +381,7 @@ private static org.apache.datasketches.tuple.Sketch asIntegerSke } } - @ScalarFunction(names = {"getIntTupleSketchEstimate", "get_int_tuple_sketch_estimate"}) + @ScalarFunction public static long getIntTupleSketchEstimate(Object o1) { return Math.round(asIntegerSketch(o1).getEstimate()); } @@ -397,32 +397,32 @@ public static byte[] toCpcSketch(@Nullable Object input) { return toCpcSketch(input, CommonConstants.Helix.DEFAULT_CPC_SKETCH_LGK); } - @ScalarFunction(names = {"getCpcSketchEstimate", "get_cpc_sketch_estimate"}) + @ScalarFunction public static long getCpcSketchEstimate(Object o1) { return Math.round(asCpcSketch(o1).getEstimate()); } - @ScalarFunction(names = {"cpcSketchUnion", "cpc_sketch_union"}) + @ScalarFunction public static byte[] cpcSketchUnion(Object o1, Object o2) { return cpcSketchUnionVar(o1, o2); } - @ScalarFunction(names = {"cpcSketchUnion", "cpc_sketch_union"}) + @ScalarFunction public static byte[] cpcSketchUnion(Object o1, Object o2, Object o3) { return cpcSketchUnionVar(o1, o2, o3); } - @ScalarFunction(names = {"cpcSketchUnion", "cpc_sketch_union"}) + @ScalarFunction public static byte[] cpcSketchUnion(Object o1, Object o2, Object o3, Object o4) { return cpcSketchUnionVar(o1, o2, o3, o4); } - @ScalarFunction(names = {"cpcSketchUnion", "cpc_sketch_union"}) + @ScalarFunction public static byte[] cpcSketchUnion(Object o1, Object o2, Object o3, Object o4, Object o5) { return cpcSketchUnionVar(o1, o2, o3, o4, o5); } - @ScalarFunction(names = {"cpcSketchToString", "cpc_sketch_to_string"}) + @ScalarFunction public static String cpcSketchToString(Object sketchObject) { return asCpcSketch(sketchObject).toString(); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/geospatial/transform/function/ScalarFunctions.java b/pinot-core/src/main/java/org/apache/pinot/core/geospatial/transform/function/ScalarFunctions.java index 2d259c03a7bc..fc6d256649e8 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/geospatial/transform/function/ScalarFunctions.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/geospatial/transform/function/ScalarFunctions.java @@ -43,7 +43,7 @@ private ScalarFunctions() { * @param y y * @return the created point */ - @ScalarFunction(names = {"stPoint", "ST_point"}) + @ScalarFunction public static byte[] stPoint(double x, double y) { return GeometrySerializer.serialize(GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(x, y))); } @@ -56,7 +56,7 @@ public static byte[] stPoint(double x, double y) { * @param isGeography if it's geography * @return the created point */ - @ScalarFunction(names = {"stPoint", "ST_point"}) + @ScalarFunction public static byte[] stPoint(double x, double y, Object isGeography) { Point point = GeometryUtils.GEOMETRY_FACTORY.createPoint(new Coordinate(x, y)); if (BooleanUtils.toBoolean(isGeography)) { @@ -68,7 +68,7 @@ public static byte[] stPoint(double x, double y, Object isGeography) { /** * Reads a geometry object from the WKT format. 
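The geospatial scalar functions in this hunk shed their explicit `ST_...` and snake_case aliases the same way. An illustrative check (not part of this change) that the historical spellings still resolve to one implementation once lookups are canonicalized, using the `lookupFunctionInfo` signature shown elsewhere in this diff:

```java
import org.apache.pinot.common.function.FunctionInfo;
import org.apache.pinot.common.function.FunctionRegistry;

final class AliasLookupSketch {
  private AliasLookupSketch() {
  }

  // Both spellings canonicalize to "stpoint", so the registry is expected to
  // return the same entry for each (illustrative assertion, 2-arg overload).
  static boolean samePointFunction() {
    FunctionInfo viaSnakeCase = FunctionRegistry.lookupFunctionInfo(FunctionRegistry.canonicalize("ST_point"), 2);
    FunctionInfo viaCamelCase = FunctionRegistry.lookupFunctionInfo(FunctionRegistry.canonicalize("stPoint"), 2);
    return viaSnakeCase == viaCamelCase;
  }
}
```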
*/ - @ScalarFunction(names = {"stGeomFromText", "ST_geom_from_text"}) + @ScalarFunction public static byte[] stGeomFromText(String wkt) throws ParseException { return GeometrySerializer.serialize(GeometryUtils.GEOMETRY_WKT_READER.read(wkt)); @@ -77,7 +77,7 @@ public static byte[] stGeomFromText(String wkt) /** * Reads a geography object from the WKT format. */ - @ScalarFunction(names = {"stGeogFromText", "ST_geog_from_text"}) + @ScalarFunction public static byte[] stGeogFromText(String wkt) throws ParseException { return GeometrySerializer.serialize(GeometryUtils.GEOGRAPHY_WKT_READER.read(wkt)); @@ -86,7 +86,7 @@ public static byte[] stGeogFromText(String wkt) /** * Reads a geometry object from the WKB format. */ - @ScalarFunction(names = {"stGeomFromWKB", "ST_geom_from_wkb"}) + @ScalarFunction public static byte[] stGeomFromWKB(byte[] wkb) throws ParseException { return GeometrySerializer.serialize(GeometryUtils.GEOMETRY_WKB_READER.read(wkb)); @@ -95,7 +95,7 @@ public static byte[] stGeomFromWKB(byte[] wkb) /** * Reads a geography object from the WKB format. */ - @ScalarFunction(names = {"stGeogFromWKB", "ST_geog_from_wkb"}) + @ScalarFunction public static byte[] stGeogFromWKB(byte[] wkb) throws ParseException { return GeometrySerializer.serialize(GeometryUtils.GEOGRAPHY_WKB_READER.read(wkb)); @@ -107,7 +107,7 @@ public static byte[] stGeogFromWKB(byte[] wkb) * @param bytes the serialized geometry object * @return the geometry in WKT */ - @ScalarFunction(names = {"stAsText", "ST_as_text"}) + @ScalarFunction public static String stAsText(byte[] bytes) { return GeometryUtils.WKT_WRITER.write(GeometrySerializer.deserialize(bytes)); } @@ -118,7 +118,7 @@ public static String stAsText(byte[] bytes) { * @param bytes the serialized geometry object * @return the geometry in WKB */ - @ScalarFunction(names = {"stAsBinary", "ST_as_binary"}) + @ScalarFunction public static byte[] stAsBinary(byte[] bytes) { return GeometryUtils.WKB_WRITER.write(GeometrySerializer.deserialize(bytes)); } @@ -129,7 +129,7 @@ public static byte[] stAsBinary(byte[] bytes) { * @param bytes the serialized geometry object * @return the geographical object */ - @ScalarFunction(names = {"toSphericalGeography", "to_spherical_geography"}) + @ScalarFunction public static byte[] toSphericalGeography(byte[] bytes) { Geometry geometry = GeometrySerializer.deserialize(bytes); GeometryUtils.setGeography(geometry); @@ -142,7 +142,7 @@ public static byte[] toSphericalGeography(byte[] bytes) { * @param bytes the serialized geographical object * @return the geometry object */ - @ScalarFunction(names = {"toGeometry", "to_geometry"}) + @ScalarFunction public static byte[] toGeometry(byte[] bytes) { Geometry geometry = GeometrySerializer.deserialize(bytes); GeometryUtils.setGeometry(geometry); @@ -156,7 +156,7 @@ public static byte[] toGeometry(byte[] bytes) { * @param resolution H3 index resolution * @return the H3 index address */ - @ScalarFunction(names = {"geoToH3", "geo_to_h3"}) + @ScalarFunction public static long geoToH3(double longitude, double latitude, int resolution) { return H3Utils.H3_CORE.geoToH3(latitude, longitude, resolution); } @@ -167,7 +167,7 @@ public static long geoToH3(double longitude, double latitude, int resolution) { * @param resolution H3 index resolution * @return the H3 index address */ - @ScalarFunction(names = {"geoToH3", "geo_to_h3"}) + @ScalarFunction public static long geoToH3(byte[] geoBytes, int resolution) { Geometry geometry = GeometrySerializer.deserialize(geoBytes); double latitude = 
geometry.getCoordinate().y; @@ -175,7 +175,7 @@ public static long geoToH3(byte[] geoBytes, int resolution) { return H3Utils.H3_CORE.geoToH3(latitude, longitude, resolution); } - @ScalarFunction(names = {"stDistance", "ST_distance"}) + @ScalarFunction public static double stDistance(byte[] firstPoint, byte[] secondPoint) { Geometry firstGeometry = GeometrySerializer.deserialize(firstPoint); Geometry secondGeometry = GeometrySerializer.deserialize(secondPoint); @@ -189,7 +189,7 @@ public static double stDistance(byte[] firstPoint, byte[] secondPoint) { } } - @ScalarFunction(names = {"stContains", "st_contains"}) + @ScalarFunction public static int stContains(byte[] first, byte[] second) { Geometry firstGeometry = GeometrySerializer.deserialize(first); Geometry secondGeometry = GeometrySerializer.deserialize(second); @@ -200,7 +200,7 @@ public static int stContains(byte[] first, byte[] second) { return firstGeometry.contains(secondGeometry) ? 1 : 0; } - @ScalarFunction(names = {"stEquals", "st_equals"}) + @ScalarFunction public static int stEquals(byte[] first, byte[] second) { Geometry firstGeometry = GeometrySerializer.deserialize(first); Geometry secondGeometry = GeometrySerializer.deserialize(second); @@ -208,13 +208,13 @@ public static int stEquals(byte[] first, byte[] second) { return firstGeometry.equals(secondGeometry) ? 1 : 0; } - @ScalarFunction(names = {"stGeometryType", "st_geometry_type"}) + @ScalarFunction public static String stGeometryType(byte[] bytes) { Geometry geometry = GeometrySerializer.deserialize(bytes); return geometry.getGeometryType(); } - @ScalarFunction(names = {"stWithin", "st_within"}) + @ScalarFunction public static int stWithin(byte[] first, byte[] second) { Geometry firstGeometry = GeometrySerializer.deserialize(first); Geometry secondGeometry = GeometrySerializer.deserialize(second); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index 7643959d113b..3067cdc5f59d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -158,7 +158,6 @@ private static Map> createRegistry() typeToImplementation.put(TransformFunctionType.ARRAY_MAX, ArrayMaxTransformFunction.class); typeToImplementation.put(TransformFunctionType.ARRAY_MIN, ArrayMinTransformFunction.class); typeToImplementation.put(TransformFunctionType.ARRAY_SUM, ArraySumTransformFunction.class); - typeToImplementation.put(TransformFunctionType.ARRAY_VALUE_CONSTRUCTOR, ArrayLiteralTransformFunction.class); typeToImplementation.put(TransformFunctionType.GROOVY, GroovyTransformFunction.class); typeToImplementation.put(TransformFunctionType.CASE, CaseTransformFunction.class); @@ -241,11 +240,11 @@ private static Map> createRegistry() typeToImplementation.put(TransformFunctionType.VECTOR_DIMS, VectorDimsTransformFunction.class); typeToImplementation.put(TransformFunctionType.VECTOR_NORM, VectorNormTransformFunction.class); - Map> registry - = new HashMap<>(HashUtil.getHashMapCapacity(typeToImplementation.size())); + Map> registry = + new HashMap<>(HashUtil.getHashMapCapacity(typeToImplementation.size())); for (Map.Entry> entry : typeToImplementation.entrySet()) { - for (String alias : entry.getKey().getAlternativeNames()) { - registry.put(canonicalize(alias), 
entry.getValue()); + for (String name : entry.getKey().getNames()) { + registry.put(canonicalize(name), entry.getValue()); } } return registry; @@ -292,11 +291,10 @@ public static TransformFunction get(ExpressionContext expression, Map arguments = function.getArguments(); int numArguments = arguments.size(); - // Check if the function is ArrayLiteraltransform function + // Check if the function is ArrayValueConstructor transform function if (functionName.equalsIgnoreCase(ArrayLiteralTransformFunction.FUNCTION_NAME)) { return queryContext.getOrComputeSharedValue(ArrayLiteralTransformFunction.class, - expression.getFunction().getArguments(), - ArrayLiteralTransformFunction::new); + expression.getFunction().getArguments(), ArrayLiteralTransformFunction::new); } // Check if the function is GenerateArray transform function @@ -317,13 +315,14 @@ public static TransformFunction get(ExpressionContext expression, Map dataSourceMap) { Map columnContextMap = new HashMap<>(HashUtil.getHashMapCapacity(dataSourceMap.size())); dataSourceMap.forEach((k, v) -> columnContextMap.put(k, ColumnContext.fromDataSource(v))); - QueryContext dummy = QueryContextConverterUtils.getQueryContext( - CalciteSqlParser.compileToPinotQuery("SELECT * from testTable;")); + QueryContext dummy = + QueryContextConverterUtils.getQueryContext(CalciteSqlParser.compileToPinotQuery("SELECT * from testTable;")); return get(expression, columnContextMap, dummy); } + // TODO: Move to a test util class @VisibleForTesting public static TransformFunction getNullHandlingEnabled(ExpressionContext expression, Map dataSourceMap) { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/StringPredicateFilterOptimizer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/StringPredicateFilterOptimizer.java index a9b7efe91cf0..d8494684acf7 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/StringPredicateFilterOptimizer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/optimizer/statement/StringPredicateFilterOptimizer.java @@ -115,8 +115,8 @@ private static boolean isString(Expression expression, Schema schema) { if (expressionType == ExpressionType.FUNCTION) { // Check if the function returns STRING as output. 
Function function = expression.getFunctionCall(); - FunctionInfo functionInfo = - FunctionRegistry.getFunctionInfo(function.getOperator(), function.getOperands().size()); + String canonicalName = FunctionRegistry.canonicalize(function.getOperator()); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, function.getOperands().size()); return functionInfo != null && functionInfo.getMethod().getReturnType() == String.class; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java index 094d891d1d27..370865227c47 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/postaggregation/PostAggregationFunction.java @@ -19,6 +19,7 @@ package org.apache.pinot.core.query.postaggregation; import com.google.common.base.Preconditions; +import java.util.Arrays; import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.function.FunctionInvoker; import org.apache.pinot.common.function.FunctionRegistry; @@ -36,20 +37,21 @@ public class PostAggregationFunction { private final ColumnDataType _resultType; public PostAggregationFunction(String functionName, ColumnDataType[] argumentTypes) { - int numArguments = argumentTypes.length; - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numArguments); + String canonicalName = FunctionRegistry.canonicalize(functionName); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes); if (functionInfo == null) { - if (FunctionRegistry.containsFunction(functionName)) { + if (FunctionRegistry.contains(canonicalName)) { throw new IllegalArgumentException( - String.format("Unsupported function: %s with %d parameters", functionName, numArguments)); + String.format("Unsupported function: %s with argument types: %s", functionName, + Arrays.toString(argumentTypes))); } else { - throw new IllegalArgumentException( - String.format("Unsupported function: %s not found", functionName)); + throw new IllegalArgumentException(String.format("Unsupported function: %s", functionName)); } } _functionInvoker = new FunctionInvoker(functionInfo); Class[] parameterClasses = _functionInvoker.getParameterClasses(); PinotDataType[] parameterTypes = _functionInvoker.getParameterTypes(); + int numArguments = argumentTypes.length; int numParameters = parameterClasses.length; Preconditions.checkArgument(numArguments == numParameters, "Wrong number of arguments for method: %s, expected: %s, actual: %s", functionInfo.getMethod(), numParameters, diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java index 07ba2e00c72d..65e930aae564 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/data/function/InbuiltFunctionEvaluatorTest.java @@ -18,10 +18,9 @@ */ package org.apache.pinot.core.data.function; -import java.lang.reflect.Method; import java.util.Collections; -import org.apache.pinot.common.function.FunctionRegistry; import org.apache.pinot.segment.local.function.InbuiltFunctionEvaluator; +import org.apache.pinot.spi.annotations.ScalarFunction; import 
org.apache.pinot.spi.data.readers.GenericRow; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; @@ -30,7 +29,6 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class InbuiltFunctionEvaluatorTest { @@ -127,11 +125,8 @@ public void testNestedFunction() { } @Test - public void testStateSharedBetweenRowsForExecution() - throws Exception { - MyFunc myFunc = new MyFunc(); - Method method = myFunc.getClass().getDeclaredMethod("appendToStringAndReturn", String.class); - FunctionRegistry.registerFunction(method, false, false, false); + public void testStateSharedBetweenRowsForExecution() { + // This function is auto registered with @ScalarFunction annotation under MyFunc class String expression = "appendToStringAndReturn('test ')"; InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); assertTrue(evaluator.getArguments().isEmpty()); @@ -158,24 +153,11 @@ public void testNullReturnedByInbuiltFunctionEvaluatorThatCannotTakeNull() { } } - @Test - public void testPlaceholderFunctionShouldNotBeRegistered() - throws Exception { - GenericRow row = new GenericRow(); - row.putValue("testColumn", "testValue"); - String expression = "text_match(testColumn, 'pattern')"; - try { - InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); - evaluator.evaluate(row); - fail(); - } catch (Throwable t) { - assertTrue(t.getMessage().contains("text_match")); - } - } - + @SuppressWarnings("unused") public static class MyFunc { String _baseString = ""; + @ScalarFunction public String appendToStringAndReturn(String addedString) { _baseString += addedString; return _baseString; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/function/FunctionRegistryTest.java b/pinot-core/src/test/java/org/apache/pinot/core/function/FunctionRegistryTest.java new file mode 100644 index 000000000000..cf2fd6e8c29e --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/function/FunctionRegistryTest.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.function; + +import java.util.EnumSet; +import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.spi.annotations.ScalarFunction; +import org.apache.pinot.sql.FilterKind; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + + +// NOTE: Keep this test in pinot-core to include all built-in scalar functions. 
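The test change above captures the registration model this PR standardizes on: rather than reflectively calling `FunctionRegistry.registerFunction(...)`, a method is annotated with `@ScalarFunction` and discovered automatically at startup, with explicit `names` reserved for aliases that cannot be derived from the method name. A minimal sketch under those assumptions (the class and functions here are hypothetical):

```java
import org.apache.pinot.spi.annotations.ScalarFunction;

public class MyScalarFunctions {
  private MyScalarFunctions() {
  }

  // Auto-registered under the canonical key "shout"; SHOUT and s_hout also
  // resolve to it because lookups are canonicalized.
  @ScalarFunction
  public static String shout(String input) {
    return input.toUpperCase() + "!";
  }

  // Explicit names are still warranted when an alias is not derivable from
  // the method name, mirroring the split/stringToArray case earlier.
  @ScalarFunction(names = {"exclaim", "shoutTwice"})
  public static String exclaim(String input) {
    return shout(input) + " " + shout(input);
  }
}
```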
+public class FunctionRegistryTest { + // TODO: Support these functions + private static final EnumSet IGNORED_TRANSFORM_FUNCTION_TYPES = EnumSet.of( + // Special placeholder functions without implementation + TransformFunctionType.SCALAR, + // Special functions that requires index + TransformFunctionType.JSON_EXTRACT_INDEX, TransformFunctionType.MAP_VALUE, TransformFunctionType.LOOKUP, + // TODO: Support these functions + TransformFunctionType.IN, TransformFunctionType.NOT_IN, TransformFunctionType.IS_TRUE, + TransformFunctionType.IS_NOT_TRUE, TransformFunctionType.IS_FALSE, TransformFunctionType.IS_NOT_FALSE, + TransformFunctionType.AND, TransformFunctionType.OR, TransformFunctionType.JSON_EXTRACT_SCALAR, + TransformFunctionType.JSON_EXTRACT_KEY, TransformFunctionType.TIME_CONVERT, + TransformFunctionType.DATE_TIME_CONVERT_WINDOW_HOP, TransformFunctionType.ARRAY_LENGTH, + TransformFunctionType.ARRAY_AVERAGE, TransformFunctionType.ARRAY_MIN, TransformFunctionType.ARRAY_MAX, + TransformFunctionType.ARRAY_SUM, TransformFunctionType.VALUE_IN, TransformFunctionType.IN_ID_SET, + TransformFunctionType.GROOVY, TransformFunctionType.CLP_DECODE, TransformFunctionType.CLP_ENCODED_VARS_MATCH, + TransformFunctionType.ST_POLYGON, TransformFunctionType.ST_AREA); + private static final EnumSet IGNORED_FILTER_KINDS = EnumSet.of( + // Special filter functions without implementation + FilterKind.TEXT_MATCH, FilterKind.TEXT_CONTAINS, FilterKind.JSON_MATCH, FilterKind.VECTOR_SIMILARITY, + // TODO: Support these functions + FilterKind.AND, FilterKind.OR, FilterKind.RANGE, FilterKind.IN, FilterKind.NOT_IN); + + @Test + public void testTransformAndFilterFunctionsRegistered() { + for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { + if (IGNORED_TRANSFORM_FUNCTION_TYPES.contains(transformFunctionType)) { + continue; + } + for (String name : transformFunctionType.getNames()) { + assertTrue(FunctionRegistry.contains(FunctionRegistry.canonicalize(name)), + "Unable to find transform function signature for: " + name); + } + } + for (FilterKind filterKind : FilterKind.values()) { + if (IGNORED_FILTER_KINDS.contains(filterKind)) { + continue; + } + assertTrue(FunctionRegistry.contains(FunctionRegistry.canonicalize(filterKind.name())), + "Unable to find filter function signature for: " + filterKind); + } + } + + @ScalarFunction(names = {"testFunc1", "testFunc2"}) + public static String testScalarFunction(long randomArg1, String randomArg2) { + return null; + } + + @Test + public void testScalarFunctionNames() { + assertNotNull(FunctionRegistry.lookupFunctionInfo("testfunc1", 2)); + assertNotNull(FunctionRegistry.lookupFunctionInfo("testfunc2", 2)); + assertNull(FunctionRegistry.lookupFunctionInfo("testscalarfunction", 2)); + assertNull(FunctionRegistry.lookupFunctionInfo("testfunc1", 1)); + assertNull(FunctionRegistry.lookupFunctionInfo("testfunc2", 1)); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/DateTimeTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/DateTimeTransformFunctionTest.java index f9325c0e04c0..91c8dca652de 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/DateTimeTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/DateTimeTransformFunctionTest.java @@ -39,9 +39,9 @@ public static Object[][] testCasesUTC() { {"year", (LongToIntFunction) DateTimeFunctions::year, 
DateTimeTransformFunction.Year.class}, {"yearOfWeek", (LongToIntFunction) DateTimeFunctions::yearOfWeek, DateTimeTransformFunction.YearOfWeek.class}, {"yow", (LongToIntFunction) DateTimeFunctions::yearOfWeek, DateTimeTransformFunction.YearOfWeek.class}, - {"month", (LongToIntFunction) DateTimeFunctions::monthOfYear, DateTimeTransformFunction.Month.class}, - {"week", (LongToIntFunction) DateTimeFunctions::weekOfYear, DateTimeTransformFunction.WeekOfYear.class}, - {"weekOfYear", (LongToIntFunction) DateTimeFunctions::weekOfYear, DateTimeTransformFunction.WeekOfYear.class}, + {"month", (LongToIntFunction) DateTimeFunctions::month, DateTimeTransformFunction.Month.class}, + {"week", (LongToIntFunction) DateTimeFunctions::week, DateTimeTransformFunction.WeekOfYear.class}, + {"weekOfYear", (LongToIntFunction) DateTimeFunctions::week, DateTimeTransformFunction.WeekOfYear.class}, {"quarter", (LongToIntFunction) DateTimeFunctions::quarter, DateTimeTransformFunction.Quarter.class}, {"dayOfWeek", (LongToIntFunction) DateTimeFunctions::dayOfWeek, DateTimeTransformFunction.DayOfWeek.class}, {"dow", (LongToIntFunction) DateTimeFunctions::dayOfWeek, DateTimeTransformFunction.DayOfWeek.class}, @@ -62,9 +62,9 @@ public static Object[][] testCasesZoned() { {"year", (ZonedTimeFunction) DateTimeFunctions::year, DateTimeTransformFunction.Year.class}, {"yearOfWeek", (ZonedTimeFunction) DateTimeFunctions::yearOfWeek, DateTimeTransformFunction.YearOfWeek.class}, {"yow", (ZonedTimeFunction) DateTimeFunctions::yearOfWeek, DateTimeTransformFunction.YearOfWeek.class}, - {"month", (ZonedTimeFunction) DateTimeFunctions::monthOfYear, DateTimeTransformFunction.Month.class}, - {"week", (ZonedTimeFunction) DateTimeFunctions::weekOfYear, DateTimeTransformFunction.WeekOfYear.class}, - {"weekOfYear", (ZonedTimeFunction) DateTimeFunctions::weekOfYear, DateTimeTransformFunction.WeekOfYear.class}, + {"month", (ZonedTimeFunction) DateTimeFunctions::month, DateTimeTransformFunction.Month.class}, + {"week", (ZonedTimeFunction) DateTimeFunctions::week, DateTimeTransformFunction.WeekOfYear.class}, + {"weekOfYear", (ZonedTimeFunction) DateTimeFunctions::week, DateTimeTransformFunction.WeekOfYear.class}, {"quarter", (ZonedTimeFunction) DateTimeFunctions::quarter, DateTimeTransformFunction.Quarter.class}, {"dayOfWeek", (ZonedTimeFunction) DateTimeFunctions::dayOfWeek, DateTimeTransformFunction.DayOfWeek.class}, {"dow", (ZonedTimeFunction) DateTimeFunctions::dayOfWeek, DateTimeTransformFunction.DayOfWeek.class}, diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ExtractTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ExtractTransformFunctionTest.java index 96eda1c96c68..42c56da02ea9 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ExtractTransformFunctionTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ExtractTransformFunctionTest.java @@ -38,9 +38,9 @@ public static Object[][] testCases() { return new Object[][]{ //@formatter:off {"year", (LongToIntFunction) DateTimeFunctions::year}, - {"quarter", (LongToIntFunction) timestamp -> (DateTimeFunctions.monthOfYear(timestamp) - 1) / 3 + 1}, - {"month", (LongToIntFunction) DateTimeFunctions::monthOfYear}, - {"week", (LongToIntFunction) DateTimeFunctions::weekOfYear}, + {"quarter", (LongToIntFunction) timestamp -> (DateTimeFunctions.month(timestamp) - 1) / 3 + 1}, + {"month", (LongToIntFunction) DateTimeFunctions::month}, + 
{"week", (LongToIntFunction) DateTimeFunctions::week}, {"day", (LongToIntFunction) DateTimeFunctions::dayOfMonth}, {"doy", (LongToIntFunction) DateTimeFunctions::dayOfYear}, {"dow", (LongToIntFunction) DateTimeFunctions::dayOfWeek}, diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/TimestampQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/TimestampQueriesTest.java index c96c993ed2ca..3a8da1d6c551 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/TimestampQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/TimestampQueriesTest.java @@ -223,10 +223,12 @@ public void testQueries() { } } + //@formatter:off @Test( expectedExceptions = BadQueryRequestException.class, - expectedExceptionsMessageRegExp = ".*attimezone not found.*" + expectedExceptionsMessageRegExp = "Unsupported function: attimezone" ) + //@formatter:on public void shouldThrowOnAtTimeZone() { // this isn't yet implemented but the syntax is supported, make sure the // degradation experience is clean diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java index 48c703b65358..6d0fc53284de 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/OfflineClusterIntegrationTest.java @@ -3009,9 +3009,10 @@ public void testDistinctCountHllPlus(boolean useMultiStageQueryEngine) assertEquals(postQuery(query).get("resultTable").get("rows").get(0).get(0).asLong(), expectedResults[10]); } - @Test - public void testAggregationFunctionsWithUnderscoreV1() + @Test(dataProvider = "useBothQueryEngines") + public void testAggregationFunctionsWithUnderscore(boolean useMultiStageQueryEngine) throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); String query; // The Accurate value is 6538. @@ -3023,21 +3024,6 @@ public void testAggregationFunctionsWithUnderscoreV1() assertEquals(postQuery(query).get("resultTable").get("rows").get(0).get(0).asInt(), 115545); } - @Test - public void testAggregationFunctionsWithUnderscoreV2() - throws Exception { - setUseMultiStageQueryEngine(true); - String query; - - // The Accurate value is 6538. - query = "SELECT distinct_count(FlightNum) FROM mytable"; - assertEquals(postQuery(query).get("resultTable").get("rows").get(0).get(0).asInt(), 6538); - - // This is not supported in V2. - query = "SELECT c_o_u_n_t(FlightNum) FROM mytable"; - testQueryError(query, QueryException.QUERY_PLANNING_ERROR_CODE); - } - @Test public void testExplainPlanQueryV1() throws Exception { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/jdbc/CalciteSchemaBuilder.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/jdbc/CalciteSchemaBuilder.java deleted file mode 100644 index 1da5e2d7f1a0..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/jdbc/CalciteSchemaBuilder.java +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.calcite.jdbc; - -import java.util.List; -import java.util.Map; -import org.apache.calcite.jdbc.CalciteSchema; -import org.apache.calcite.schema.Function; -import org.apache.calcite.schema.Schema; -import org.apache.calcite.schema.SchemaPlus; -import org.apache.pinot.common.function.FunctionRegistry; - - -/** - * This class is used to create a {@link CalciteSchema} with a given {@link Schema} as the root. - */ -public class CalciteSchemaBuilder { - private CalciteSchemaBuilder() { - } - - /** - * Creates a {@link CalciteSchema} with a given {@link Schema} as the root. - * - * @param root schema to use as a root schema - * @return calcite schema with given schema as the root - */ - public static CalciteSchema asRootSchema(Schema root, String name) { - CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, name, root); - SchemaPlus schemaPlus = rootSchema.plus(); - for (Map.Entry> e : FunctionRegistry.getRegisteredCalciteFunctionMap().entrySet()) { - for (Function f : e.getValue()) { - schemaPlus.add(e.getKey(), f); - } - } - return rootSchema; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/prepare/PinotCalciteCatalogReader.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/prepare/PinotCalciteCatalogReader.java deleted file mode 100644 index c345ca216cb1..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/prepare/PinotCalciteCatalogReader.java +++ /dev/null @@ -1,468 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.pinot.calcite.prepare; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; -import java.util.ArrayList; -import java.util.Collection; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.NavigableSet; -import java.util.Objects; -import java.util.function.Function; -import java.util.function.Predicate; -import org.apache.calcite.config.CalciteConnectionConfig; -import org.apache.calcite.jdbc.CalciteSchema; -import org.apache.calcite.jdbc.JavaTypeFactoryImpl; -import org.apache.calcite.linq4j.function.Hints; -import org.apache.calcite.model.ModelHandler; -import org.apache.calcite.plan.RelOptPlanner; -import org.apache.calcite.prepare.CalciteCatalogReader; -import org.apache.calcite.prepare.Prepare; -import org.apache.calcite.prepare.RelOptTableImpl; -import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rel.type.RelDataTypeFactoryImpl; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.schema.AggregateFunction; -import org.apache.calcite.schema.ScalarFunction; -import org.apache.calcite.schema.Table; -import org.apache.calcite.schema.TableFunction; -import org.apache.calcite.schema.TableMacro; -import org.apache.calcite.schema.Wrapper; -import org.apache.calcite.schema.impl.ScalarFunctionImpl; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.calcite.sql.SqlSyntax; -import org.apache.calcite.sql.parser.SqlParserPos; -import org.apache.calcite.sql.type.InferTypes; -import org.apache.calcite.sql.type.OperandTypes; -import org.apache.calcite.sql.type.ReturnTypes; -import org.apache.calcite.sql.type.SqlOperandMetadata; -import org.apache.calcite.sql.type.SqlOperandTypeInference; -import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.sql.type.SqlTypeFamily; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.util.ListSqlOperatorTable; -import org.apache.calcite.sql.validate.SqlMoniker; -import org.apache.calcite.sql.validate.SqlMonikerImpl; -import org.apache.calcite.sql.validate.SqlMonikerType; -import org.apache.calcite.sql.validate.SqlNameMatcher; -import org.apache.calcite.sql.validate.SqlNameMatchers; -import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction; -import org.apache.calcite.sql.validate.SqlUserDefinedFunction; -import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction; -import org.apache.calcite.sql.validate.SqlUserDefinedTableMacro; -import org.apache.calcite.sql.validate.SqlValidatorUtil; -import org.apache.calcite.util.Optionality; -import org.apache.calcite.util.Util; -import org.checkerframework.checker.nullness.qual.Nullable; - - -/** - * ============================================================================================================== - * THIS CLASS IS COPIED FROM Calcite's {@link CalciteCatalogReader} and modified the case sensitivity of Function - * lookup, which is ALWAYS case-insensitive regardless of conventions on column/table identifier. 
- * ============================================================================================================== - * - * Pinot's implementation of {@link Prepare.CatalogReader} and also {@link SqlOperatorTable} based on tables and - * functions defined schemas. - */ -//@formatter:off -public class PinotCalciteCatalogReader implements Prepare.CatalogReader { - protected final CalciteSchema _rootSchema; - protected final RelDataTypeFactory _typeFactory; - private final List> _schemaPaths; - protected final SqlNameMatcher _nameMatcher; - protected final CalciteConnectionConfig _config; - - public PinotCalciteCatalogReader(CalciteSchema rootSchema, - List defaultSchema, RelDataTypeFactory typeFactory, CalciteConnectionConfig config) { - this(rootSchema, SqlNameMatchers.withCaseSensitive(config != null && config.caseSensitive()), - ImmutableList.of(Objects.requireNonNull(defaultSchema, "defaultSchema")), - typeFactory, config); - } - - protected PinotCalciteCatalogReader(CalciteSchema rootSchema, - SqlNameMatcher nameMatcher, List> schemaPaths, - RelDataTypeFactory typeFactory, CalciteConnectionConfig config) { - _rootSchema = Objects.requireNonNull(rootSchema, "rootSchema"); - _nameMatcher = nameMatcher; - _schemaPaths = - Util.immutableCopy(Util.isDistinct(schemaPaths) - ? schemaPaths - : new LinkedHashSet<>(schemaPaths)); - _typeFactory = typeFactory; - _config = config; - } - - @Override public PinotCalciteCatalogReader withSchemaPath(List schemaPath) { - return new PinotCalciteCatalogReader(_rootSchema, _nameMatcher, - ImmutableList.of(schemaPath, ImmutableList.of()), _typeFactory, _config); - } - - @Override public Prepare.@Nullable PreparingTable getTable(final List names) { - // First look in the default schema, if any. - // If not found, look in the root schema. - CalciteSchema.TableEntry entry = SqlValidatorUtil.getTableEntry(this, names); - if (entry != null) { - final Table table = entry.getTable(); - if (table instanceof Wrapper) { - final Prepare.PreparingTable relOptTable = - ((Wrapper) table).unwrap(Prepare.PreparingTable.class); - if (relOptTable != null) { - return relOptTable; - } - } - return RelOptTableImpl.create(this, - table.getRowType(_typeFactory), entry, null); - } - return null; - } - - @Override public CalciteConnectionConfig getConfig() { - return _config; - } - - private Collection getFunctionsFrom( - List names) { - final List functions2 = - new ArrayList<>(); - final List> schemaNameList = new ArrayList<>(); - if (names.size() > 1) { - // Name qualified: ignore path. But we do look in "/catalog" and "/", - // the last 2 items in the path. 
- if (_schemaPaths.size() > 1) { - schemaNameList.addAll(Util.skip(_schemaPaths)); - } else { - schemaNameList.addAll(_schemaPaths); - } - } else { - for (List schemaPath : _schemaPaths) { - CalciteSchema schema = - SqlValidatorUtil.getSchema(_rootSchema, schemaPath, _nameMatcher); - if (schema != null) { - schemaNameList.addAll(schema.getPath()); - } - } - } - for (List schemaNames : schemaNameList) { - CalciteSchema schema = - SqlValidatorUtil.getSchema(_rootSchema, - Iterables.concat(schemaNames, Util.skipLast(names)), _nameMatcher); - if (schema != null) { - final String name = Util.last(names); - // ==================================================================== - // LINES CHANGED BELOW - // ==================================================================== - functions2.addAll(schema.getFunctions(name, false)); - // ==================================================================== - // LINES CHANGED ABOVE - // ==================================================================== - } - } - return functions2; - } - - @Override public @Nullable RelDataType getNamedType(SqlIdentifier typeName) { - CalciteSchema.TypeEntry typeEntry = SqlValidatorUtil.getTypeEntry(getRootSchema(), typeName); - if (typeEntry != null) { - return typeEntry.getType().apply(_typeFactory); - } else { - return null; - } - } - - @Override public List getAllSchemaObjectNames(List names) { - final CalciteSchema schema = - SqlValidatorUtil.getSchema(_rootSchema, names, _nameMatcher); - if (schema == null) { - return ImmutableList.of(); - } - final ImmutableList.Builder result = new ImmutableList.Builder<>(); - - // Add root schema if not anonymous - if (!schema.name.equals("")) { - result.add(moniker(schema, null, SqlMonikerType.SCHEMA)); - } - - final Map schemaMap = schema.getSubSchemaMap(); - - for (String subSchema : schemaMap.keySet()) { - result.add(moniker(schema, subSchema, SqlMonikerType.SCHEMA)); - } - - for (String table : schema.getTableNames()) { - result.add(moniker(schema, table, SqlMonikerType.TABLE)); - } - - final NavigableSet functions = schema.getFunctionNames(); - for (String function : functions) { // views are here as well - result.add(moniker(schema, function, SqlMonikerType.FUNCTION)); - } - return result.build(); - } - - private static SqlMonikerImpl moniker(CalciteSchema schema, @Nullable String name, - SqlMonikerType type) { - final List path = schema.path(name); - if (path.size() == 1 - && !schema.root().name.equals("") - && type == SqlMonikerType.SCHEMA) { - type = SqlMonikerType.CATALOG; - } - return new SqlMonikerImpl(path, type); - } - - @Override public List> getSchemaPaths() { - return _schemaPaths; - } - - @Override public Prepare.@Nullable PreparingTable getTableForMember(List names) { - return getTable(names); - } - - @SuppressWarnings("deprecation") - @Override public @Nullable RelDataTypeField field(RelDataType rowType, String alias) { - return _nameMatcher.field(rowType, alias); - } - - @SuppressWarnings("deprecation") - @Override public boolean matches(String string, String name) { - return _nameMatcher.matches(string, name); - } - - @Override public RelDataType createTypeFromProjection(final RelDataType type, - final List columnNameList) { - return SqlValidatorUtil.createTypeFromProjection(type, columnNameList, _typeFactory, - _nameMatcher.isCaseSensitive()); - } - - @Override public void lookupOperatorOverloads(final SqlIdentifier opName, - @Nullable SqlFunctionCategory category, - SqlSyntax syntax, - List operatorList, - SqlNameMatcher nameMatcher) { - if (syntax 
!= SqlSyntax.FUNCTION) { - return; - } - - final Predicate predicate; - if (category == null) { - predicate = function -> true; - } else if (category.isTableFunction()) { - predicate = function -> - function instanceof TableMacro - || function instanceof TableFunction; - } else { - predicate = function -> - !(function instanceof TableMacro - || function instanceof TableFunction); - } - getFunctionsFrom(opName.names) - .stream() - .filter(predicate) - .map(function -> toOp(opName, function)) - .forEachOrdered(operatorList::add); - } - - /** Creates an operator table that contains functions in the given class - * or classes. - * - * @see ModelHandler#addFunctions */ - public static SqlOperatorTable operatorTable(String... classNames) { - // Dummy schema to collect the functions - final CalciteSchema schema = - CalciteSchema.createRootSchema(false, false); - for (String className : classNames) { - ModelHandler.addFunctions(schema.plus(), null, ImmutableList.of(), - className, "*", true); - } - - final ListSqlOperatorTable table = new ListSqlOperatorTable(); - for (String name : schema.getFunctionNames()) { - schema.getFunctions(name, true).forEach(function -> { - final SqlIdentifier id = new SqlIdentifier(name, SqlParserPos.ZERO); - table.add(toOp(id, function)); - }); - } - return table; - } - - /** Converts a function to a {@link org.apache.calcite.sql.SqlOperator}. */ - private static SqlOperator toOp(SqlIdentifier name, - final org.apache.calcite.schema.Function function) { - final Function> argTypesFactory = - typeFactory -> function.getParameters() - .stream() - .map(o -> o.getType(typeFactory)) - .collect(Util.toImmutableList()); - final Function> typeFamiliesFactory = - typeFactory -> argTypesFactory.apply(typeFactory) - .stream() - .map(type -> - Util.first(type.getSqlTypeName().getFamily(), - SqlTypeFamily.ANY)) - .collect(Util.toImmutableList()); - final Function> paramTypesFactory = - typeFactory -> - argTypesFactory.apply(typeFactory) - .stream() - .map(type -> toSql(typeFactory, type)) - .collect(Util.toImmutableList()); - - // Use a short-lived type factory to populate "typeFamilies" and "argTypes". - // SqlOperandMetadata.paramTypes will use the real type factory, during - // validation. 
- final RelDataTypeFactory dummyTypeFactory = new JavaTypeFactoryImpl(); - final List argTypes = argTypesFactory.apply(dummyTypeFactory); - final List typeFamilies = - typeFamiliesFactory.apply(dummyTypeFactory); - - final SqlOperandTypeInference operandTypeInference = - InferTypes.explicit(argTypes); - - final SqlOperandMetadata operandMetadata = - OperandTypes.operandMetadata(typeFamilies, paramTypesFactory, - i -> function.getParameters().get(i).getName(), - i -> function.getParameters().get(i).isOptional()); - - final SqlKind kind = kind(function); - if (function instanceof ScalarFunction) { - final SqlReturnTypeInference returnTypeInference = - infer((ScalarFunction) function); - return new SqlUserDefinedFunction(name, kind, returnTypeInference, - operandTypeInference, operandMetadata, function); - } else if (function instanceof AggregateFunction) { - final SqlReturnTypeInference returnTypeInference = - infer((AggregateFunction) function); - return new SqlUserDefinedAggFunction(name, kind, - returnTypeInference, operandTypeInference, - operandMetadata, (AggregateFunction) function, false, false, - Optionality.FORBIDDEN); - } else if (function instanceof TableMacro) { - return new SqlUserDefinedTableMacro(name, kind, ReturnTypes.CURSOR, - operandTypeInference, operandMetadata, (TableMacro) function); - } else if (function instanceof TableFunction) { - return new SqlUserDefinedTableFunction(name, kind, ReturnTypes.CURSOR, - operandTypeInference, operandMetadata, (TableFunction) function); - } else { - throw new AssertionError("unknown function type " + function); - } - } - - /** Deduces the {@link org.apache.calcite.sql.SqlKind} of a user-defined - * function based on a {@link Hints} annotation, if present. */ - private static SqlKind kind(org.apache.calcite.schema.Function function) { - if (function instanceof ScalarFunctionImpl) { - Hints hints = - ((ScalarFunctionImpl) function).method.getAnnotation(Hints.class); - if (hints != null) { - for (String hint : hints.value()) { - if (hint.startsWith("SqlKind:")) { - return SqlKind.valueOf(hint.substring("SqlKind:".length())); - } - } - } - } - return SqlKind.OTHER_FUNCTION; - } - - private static SqlReturnTypeInference infer(final ScalarFunction function) { - return opBinding -> { - final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - final RelDataType type; - if (function instanceof ScalarFunctionImpl) { - type = ((ScalarFunctionImpl) function).getReturnType(typeFactory, - opBinding); - } else { - type = function.getReturnType(typeFactory); - } - return toSql(typeFactory, type); - }; - } - - private static SqlReturnTypeInference infer( - final AggregateFunction function) { - return opBinding -> { - final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - final RelDataType type = function.getReturnType(typeFactory); - return toSql(typeFactory, type); - }; - } - - private static RelDataType toSql(RelDataTypeFactory typeFactory, - RelDataType type) { - if (type instanceof RelDataTypeFactoryImpl.JavaType - && ((RelDataTypeFactoryImpl.JavaType) type).getJavaClass() - == Object.class) { - return typeFactory.createTypeWithNullability( - typeFactory.createSqlType(SqlTypeName.ANY), true); - } - return JavaTypeFactoryImpl.toSql(typeFactory, type); - } - - @Override public List getOperatorList() { - final ImmutableList.Builder builder = ImmutableList.builder(); - for (List schemaPath : _schemaPaths) { - CalciteSchema schema = - SqlValidatorUtil.getSchema(_rootSchema, schemaPath, _nameMatcher); - if (schema != null) { - 
for (String name : schema.getFunctionNames()) { - schema.getFunctions(name, true).forEach(f -> - builder.add(toOp(new SqlIdentifier(name, SqlParserPos.ZERO), f))); - } - } - } - return builder.build(); - } - - @Override public CalciteSchema getRootSchema() { - return _rootSchema; - } - - @Override public RelDataTypeFactory getTypeFactory() { - return _typeFactory; - } - - @Override public void registerRules(RelOptPlanner planner) { - } - - @SuppressWarnings("deprecation") - @Override public boolean isCaseSensitive() { - return _nameMatcher.isCaseSensitive(); - } - - @Override public SqlNameMatcher nameMatcher() { - return _nameMatcher; - } - - @Override public @Nullable C unwrap(Class aClass) { - if (aClass.isInstance(this)) { - return aClass.cast(this); - } - return null; - } -} -//@formatter:on diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java index b91facc487db..a727f0cefe1b 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/logical/PinotLogicalAggregate.java @@ -19,6 +19,7 @@ package org.apache.pinot.calcite.rel.logical; import java.util.List; +import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -27,7 +28,6 @@ import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.util.ImmutableBitSet; import org.apache.pinot.query.planner.plannode.AggregateNode.AggType; -import org.checkerframework.checker.nullness.qual.Nullable; public class PinotLogicalAggregate extends Aggregate { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java index 8a9f17917935..9d7b8211231e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java @@ -36,12 +36,17 @@ import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlOperandTypeChecker; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; @@ -54,7 +59,7 @@ import org.apache.pinot.calcite.rel.logical.PinotLogicalAggregate; import org.apache.pinot.calcite.rel.logical.PinotLogicalExchange; import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange; -import org.apache.pinot.calcite.sql.PinotSqlAggFunction; +import org.apache.pinot.common.function.sql.PinotSqlAggFunction; import 
org.apache.pinot.query.planner.plannode.AggregateNode.AggType; import org.apache.pinot.segment.spi.AggregationFunctionType; @@ -294,42 +299,39 @@ private static List buildAggCalls(Aggregate aggRel, AggType aggTy // - argList is replaced with rexList private static AggregateCall buildAggCall(RelNode input, AggregateCall orgAggCall, List rexList, int numGroups, AggType aggType) { - String functionName = orgAggCall.getAggregation().getName(); + SqlAggFunction orgAggFunction = orgAggCall.getAggregation(); + String functionName = orgAggFunction.getName(); + SqlKind kind = orgAggFunction.getKind(); + SqlFunctionCategory functionCategory = orgAggFunction.getFunctionType(); if (orgAggCall.isDistinct()) { - if (functionName.equals("COUNT")) { + if (kind == SqlKind.COUNT) { functionName = "DISTINCTCOUNT"; - } else if (functionName.equals("LISTAGG")) { + kind = SqlKind.OTHER_FUNCTION; + functionCategory = SqlFunctionCategory.USER_DEFINED_FUNCTION; + } else if (kind == SqlKind.LISTAGG) { rexList.add(input.getCluster().getRexBuilder().makeLiteral(true)); } } - AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); - SqlAggFunction sqlAggFunction; - switch (aggType) { - case DIRECT: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - ReturnTypes.explicit(orgAggCall.getType()), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory()); - break; - case LEAF: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - functionType.getIntermediateReturnTypeInference(), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory()); - break; - case INTERMEDIATE: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - functionType.getIntermediateReturnTypeInference(), null, OperandTypes.ANY, - functionType.getSqlFunctionCategory()); - break; - case FINAL: - sqlAggFunction = new PinotSqlAggFunction(functionName, null, functionType.getSqlKind(), - ReturnTypes.explicit(orgAggCall.getType()), null, OperandTypes.ANY, functionType.getSqlFunctionCategory()); - break; - default: - throw new IllegalStateException("Unsupported AggType: " + aggType); + SqlReturnTypeInference returnTypeInference = null; + RelDataType returnType = null; + // Override the intermediate result type inference if it is provided + if (aggType.isOutputIntermediateFormat()) { + AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName); + returnTypeInference = functionType.getIntermediateReturnTypeInference(); } + // When the output is not intermediate format, or intermediate result type inference is not provided (intermediate + // result type the same as final result type), use the explicit return type + if (returnTypeInference == null) { + returnType = orgAggCall.getType(); + returnTypeInference = ReturnTypes.explicit(returnType); + } + SqlOperandTypeChecker operandTypeChecker = + aggType.isInputIntermediateFormat() ? OperandTypes.ANY : orgAggFunction.getOperandTypeChecker(); + SqlAggFunction sqlAggFunction = + new PinotSqlAggFunction(functionName, kind, returnTypeInference, operandTypeChecker, functionCategory); return AggregateCall.create(sqlAggFunction, false, orgAggCall.isApproximate(), orgAggCall.ignoreNulls(), rexList, ImmutableList.of(), aggType.isInputIntermediateFormat() ? 
-1 : orgAggCall.filterArg, orgAggCall.distinctKeys,
-        orgAggCall.collation, numGroups, input, null, null);
+        orgAggCall.collation, numGroups, input, returnType, null);
   }

   @Nullable
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
new file mode 100644
index 000000000000..b27d1a636db4
--- /dev/null
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.calcite.rel.rules;
+
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
+import org.apache.calcite.sql.SqlKind;
+
+
+/**
+ * Pinot customized version of {@link AggregateReduceFunctionsRule} which only reduces SUM and AVG.
+ * We don't want to reduce STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, COVAR_POP and COVAR_SAMP because Pinot
+ * supports them natively, and reducing them can generate REGR_COUNT, which Pinot does not support.
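+ * <p>For illustration (example queries, not taken from the original javadoc): with this rule an AVG is still
+ * reduced, e.g. SELECT AVG(salary) FROM emp is rewritten to SUM(salary) / COUNT(salary), while
+ * SELECT STDDEV_POP(salary) FROM emp is left as-is and executed by Pinot's native STDDEV_POP implementation.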
+ */ +public class PinotAggregateReduceFunctionsRule extends AggregateReduceFunctionsRule { + public static final PinotAggregateReduceFunctionsRule INSTANCE = + new PinotAggregateReduceFunctionsRule(Config.DEFAULT); + + private PinotAggregateReduceFunctionsRule(Config config) { + super(config); + } + + @Override + public boolean canReduce(AggregateCall call) { + SqlKind kind = call.getAggregation().getKind(); + return kind == SqlKind.SUM || kind == SqlKind.AVG; + } +} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java index ebf61df801a6..e93609a38a3d 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotAggregateToSemiJoinRule.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import javax.annotation.Nullable; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -36,7 +37,6 @@ import org.apache.calcite.tools.RelBuilderFactory; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableIntList; -import org.checkerframework.checker.nullness.qual.Nullable; /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java index c0b7b4b21cb6..1d7f15ec5da9 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java @@ -38,6 +38,7 @@ import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.ArraySqlType; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilderFactory; @@ -45,6 +46,8 @@ import org.apache.pinot.common.function.FunctionInfo; import org.apache.pinot.common.function.FunctionInvoker; import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.query.planner.logical.RelToPlanNodeConverter; import org.apache.pinot.spi.utils.TimestampUtils; import org.apache.pinot.sql.parsers.SqlCompilationException; @@ -130,11 +133,9 @@ private static class EvaluateLiteralShuttle extends RexShuttle { public RexNode visitCall(RexCall call) { RexCall visitedCall = (RexCall) super.visitCall(call); // Check if all operands are RexLiteral - if (visitedCall.operands.stream().allMatch(operand -> - operand instanceof RexLiteral - || (operand instanceof RexCall && ((RexCall) operand).getOperands().stream() - .allMatch(op -> op instanceof RexLiteral)) - )) { + if (visitedCall.operands.stream().allMatch( + operand -> operand instanceof RexLiteral || (operand instanceof RexCall && ((RexCall) operand).getOperands() + .stream().allMatch(op -> op instanceof RexLiteral)))) { return evaluateLiteralOnlyFunction(visitedCall, _rexBuilder); } else { return visitedCall; @@ -147,38 +148,39 @@ public RexNode visitCall(RexCall call) { * itself (RexCall) if it cannot be evaluated. 
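   * For example, a literal-only call such as upper('abc') is folded into the literal 'ABC' at planning time,
   * while a call with any non-literal operand is returned unchanged (illustrative example, assuming the function
   * has a registered scalar implementation).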
*/ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder rexBuilder) { - String functionName = PinotRuleUtils.extractFunctionName(rexCall); List operands = rexCall.getOperands(); - assert operands.stream().allMatch(operand -> operand instanceof RexLiteral - || (operand instanceof RexCall && ((RexCall) operand).getOperands().stream() - .allMatch(op -> op instanceof RexLiteral))); - int numOperands = operands.size(); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numOperands); - if (functionInfo == null) { - // Function cannot be evaluated - return rexCall; - } - Object[] arguments = new Object[numOperands]; - for (int i = 0; i < numOperands; i++) { + assert operands.stream().allMatch( + operand -> operand instanceof RexLiteral || (operand instanceof RexCall && ((RexCall) operand).getOperands() + .stream().allMatch(op -> op instanceof RexLiteral))); + int numArguments = operands.size(); + ColumnDataType[] argumentTypes = new ColumnDataType[numArguments]; + Object[] arguments = new Object[numArguments]; + for (int i = 0; i < numArguments; i++) { RexNode rexNode = operands.get(i); - if (rexNode instanceof RexCall - && ((RexCall) rexNode).getOperator().getName().equalsIgnoreCase("CAST")) { - // this must be a cast function - RexCall operand = (RexCall) rexNode; - arguments[i] = getLiteralValue((RexLiteral) operand.getOperands().get(0)); + RexLiteral rexLiteral; + if (rexNode instanceof RexCall && ((RexCall) rexNode).getOperator().getKind() == SqlKind.CAST) { + rexLiteral = (RexLiteral) ((RexCall) rexNode).getOperands().get(0); } else if (rexNode instanceof RexLiteral) { - arguments[i] = getLiteralValue((RexLiteral) rexNode); + rexLiteral = (RexLiteral) rexNode; } else { // Function operands cannot be evaluated, skip return rexCall; } + argumentTypes[i] = RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType()); + arguments[i] = getLiteralValue(rexLiteral); + } + String canonicalName = FunctionRegistry.canonicalize(PinotRuleUtils.extractFunctionName(rexCall)); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes); + if (functionInfo == null) { + // Function cannot be evaluated + return rexCall; } RelDataType rexNodeType = rexCall.getType(); Object resultValue; try { FunctionInvoker invoker = new FunctionInvoker(functionInfo); if (functionInfo.getMethod().isVarArgs()) { - resultValue = invoker.invoke(new Object[] {arguments}); + resultValue = invoker.invoke(new Object[]{arguments}); } else { invoker.convertTypes(arguments); resultValue = invoker.invoke(arguments); @@ -188,8 +190,8 @@ private static RexNode evaluateLiteralOnlyFunction(RexCall rexCall, RexBuilder r if (componentType != null) { if (Objects.requireNonNull(componentType.getSqlTypeName()) == SqlTypeName.CHAR) { // Calcite uses CHAR for STRING, but we need to use VARCHAR for STRING - rexNodeType = rexBuilder.getTypeFactory().createArrayType( - rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1); + rexNodeType = rexBuilder.getTypeFactory() + .createArrayType(rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR), -1); } } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java index e75dadb26811..c85811fd742f 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java +++ 
b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotQueryRuleSets.java
@@ -88,8 +88,9 @@ private PinotQueryRuleSets() {
       // aggregate union rule
       CoreRules.AGGREGATE_UNION_AGGREGATE,

-      // reduce aggregate functions like AVG, STDDEV_POP etc.
-      CoreRules.AGGREGATE_REDUCE_FUNCTIONS,
+      // reduce SUM and AVG
+      // TODO: Consider not reducing at all.
+      PinotAggregateReduceFunctionsRule.INSTANCE,

       // convert CASE-style filtered aggregates into true filtered aggregates
       // put it after AGGREGATE_REDUCE_FUNCTIONS where SUM is converted to SUM0
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
index 9ced39e91a96..15b18317c3aa 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRule.java
@@ -35,15 +35,13 @@
 import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange;
 import org.apache.pinot.query.planner.logical.RexExpressionUtils;
 import org.apache.pinot.query.type.TypeFactory;
-import org.apache.pinot.query.type.TypeSystem;


 public class PinotSortExchangeCopyRule extends RelRule {
-
   public static final PinotSortExchangeCopyRule SORT_EXCHANGE_COPY = PinotSortExchangeCopyRule.Config.DEFAULT.toRule();

   private static final int DEFAULT_SORT_EXCHANGE_COPY_THRESHOLD = 10_000;
-  private static final TypeFactory TYPE_FACTORY = new TypeFactory(new TypeSystem());
+  private static final TypeFactory TYPE_FACTORY = new TypeFactory();
   private static final RexBuilder REX_BUILDER = new RexBuilder(TYPE_FACTORY);
   private static final RexLiteral REX_ZERO =
       REX_BUILDER.makeLiteral(0, TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER));
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlAggFunction.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlAggFunction.java
deleted file mode 100644
index 56a6cb7f0e04..000000000000
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/PinotSqlAggFunction.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ -package org.apache.pinot.calcite.sql; - -import org.apache.calcite.sql.SqlAggFunction; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.type.SqlOperandTypeChecker; -import org.apache.calcite.sql.type.SqlOperandTypeInference; -import org.apache.calcite.sql.type.SqlReturnTypeInference; -import org.apache.calcite.util.Optionality; -import org.checkerframework.checker.nullness.qual.Nullable; - - -/** - * Pinot SqlAggFunction class to register the Pinot aggregation functions with the Calcite operator table. - */ -public class PinotSqlAggFunction extends SqlAggFunction { - - public PinotSqlAggFunction(String name, @Nullable SqlIdentifier sqlIdentifier, SqlKind sqlKind, - SqlReturnTypeInference returnTypeInference, @Nullable SqlOperandTypeInference sqlOperandTypeInference, - @Nullable SqlOperandTypeChecker operandTypeChecker, SqlFunctionCategory sqlFunctionCategory) { - this(name, sqlIdentifier, sqlKind, returnTypeInference, sqlOperandTypeInference, operandTypeChecker, - sqlFunctionCategory, false, false, Optionality.FORBIDDEN); - } - - public PinotSqlAggFunction(String name, @Nullable SqlIdentifier sqlIdentifier, SqlKind sqlKind, - SqlReturnTypeInference returnTypeInference, @Nullable SqlOperandTypeInference sqlOperandTypeInference, - @Nullable SqlOperandTypeChecker operandTypeChecker, SqlFunctionCategory sqlFunctionCategory, - boolean requiresOrder, boolean requiresOver, Optionality optionality) { - super(name, sqlIdentifier, sqlKind, returnTypeInference, sqlOperandTypeInference, operandTypeChecker, - sqlFunctionCategory, requiresOrder, requiresOver, optionality); - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java index 68dbce88f362..fd6f9e6d8766 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotOperatorTable.java @@ -18,155 +18,349 @@ */ package org.apache.pinot.calcite.sql.fun; -import java.lang.reflect.Field; -import java.util.ArrayList; +import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; +import java.util.HashMap; import java.util.List; -import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; +import javax.annotation.Nullable; import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.SqlSyntax; import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.validate.SqlNameMatchers; -import org.apache.calcite.util.Util; -import org.apache.pinot.calcite.sql.PinotSqlAggFunction; -import org.apache.pinot.calcite.sql.PinotSqlTransformFunction; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeTransforms; +import org.apache.calcite.sql.validate.SqlNameMatcher; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.pinot.common.function.FunctionRegistry; +import org.apache.pinot.common.function.PinotScalarFunction; 
import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.function.sql.PinotSqlAggFunction; +import org.apache.pinot.common.function.sql.PinotSqlFunction; import org.apache.pinot.segment.spi.AggregationFunctionType; -import org.checkerframework.checker.nullness.qual.MonotonicNonNull; /** - * {@link PinotOperatorTable} defines the {@link SqlOperator} overrides on top of the {@link SqlStdOperatorTable}. - * - *
The main purpose of this Pinot specific SQL operator table is to
+ * This class defines all the {@link SqlOperator}s allowed by Pinot.
+ * <p>It contains the following types of operators: the supported standard operators defined in
+ * {@link SqlStdOperatorTable} (including aliased ones), the Pinot customized operators, and the operators created
+ * from the registered aggregate, transform and scalar functions.
+ * <p>The core method is {@link #lookupOperatorOverloads} which is used to look up the {@link SqlOperator} with the
+ * {@link SqlIdentifier} during query parsing.
 */
@SuppressWarnings("unused") // unused fields are accessed by reflection
-public class PinotOperatorTable extends SqlStdOperatorTable {
-
-  private static @MonotonicNonNull PinotOperatorTable _instance;
-
-  // TODO: clean up lazy init by using Suppliers.memorized(this::computeInstance) and make getter wrapped around
-  // supplier instance. this should replace all lazy init static objects in the codebase
-  public static synchronized PinotOperatorTable instance() {
-    if (_instance == null) {
-      // Creates and initializes the standard operator table.
-      // Uses two-phase construction, because we can't initialize the
-      // table until the constructor of the sub-class has completed.
-      _instance = new PinotOperatorTable();
-      _instance.initNoDuplicate();
-    }
-    return _instance;
+public class PinotOperatorTable implements SqlOperatorTable {
+  private static final Supplier<PinotOperatorTable> INSTANCE = Suppliers.memoize(PinotOperatorTable::new);
+
+  public static PinotOperatorTable instance() {
+    return INSTANCE.get();
   }

   /**
-   * Initialize without duplicate, e.g. when 2 duplicate operator is linked with the same op
-   * {@link org.apache.calcite.sql.SqlKind} it causes problem.
-   *
-   * <p>This is a direct copy of the {@link org.apache.calcite.sql.util.ReflectiveSqlOperatorTable} and can be hard to
-   * debug, suggest changing to a non-dynamic registration. Dynamic function support should happen via catalog.
-   *
-   * This also registers aggregation functions defined in {@link org.apache.pinot.segment.spi.AggregationFunctionType}
-   * which are multistage enabled.
+   * This list includes the supported standard {@link SqlOperator}s defined in {@link SqlStdOperatorTable}.
+   * NOTE: The operators follow the same order as defined in {@link SqlStdOperatorTable} for easier search.
+   * Some operators are commented out and re-declared in {@link #STANDARD_OPERATORS_WITH_ALIASES}.
+   * TODO: Add more operators as needed.
    */
-  public final void initNoDuplicate() {
-    // Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion.
-    register(new PinotSqlCoalesceFunction());
-    // Ensure ArrayValueConstructor is registered before ArrayQueryConstructor
-    register(ARRAY_VALUE_CONSTRUCTOR);
-
-    // TODO: reflection based registration is not ideal, we should use a static list of operators and register them
-    // Use reflection to register the expressions stored in public fields.
-    for (Field field : getClass().getFields()) {
-      try {
-        if (SqlFunction.class.isAssignableFrom(field.getType())) {
-          SqlFunction op = (SqlFunction) field.get(this);
-          if (op != null && notRegistered(op)) {
-            register(op);
-          }
-        } else if (SqlOperator.class.isAssignableFrom(field.getType())) {
-          SqlOperator op = (SqlOperator) field.get(this);
-          if (op != null && notRegistered(op)) {
-            register(op);
-          }
-        }
-      } catch (IllegalArgumentException | IllegalAccessException e) {
-        throw Util.throwAsRuntime(Util.causeOrSelf(e));
+  //@formatter:off
+  private static final List<SqlOperator> STANDARD_OPERATORS = List.of(
+      // SET OPERATORS
+      SqlStdOperatorTable.UNION,
+      SqlStdOperatorTable.UNION_ALL,
+      SqlStdOperatorTable.EXCEPT,
+      SqlStdOperatorTable.EXCEPT_ALL,
+      SqlStdOperatorTable.INTERSECT,
+      SqlStdOperatorTable.INTERSECT_ALL,
+
+      // BINARY OPERATORS
+      SqlStdOperatorTable.AND,
+      SqlStdOperatorTable.AS,
+      SqlStdOperatorTable.FILTER,
+      SqlStdOperatorTable.WITHIN_GROUP,
+      SqlStdOperatorTable.WITHIN_DISTINCT,
+      SqlStdOperatorTable.CONCAT,
+      SqlStdOperatorTable.DIVIDE,
+      SqlStdOperatorTable.PERCENT_REMAINDER,
+      SqlStdOperatorTable.DOT,
+      SqlStdOperatorTable.EQUALS,
+      SqlStdOperatorTable.GREATER_THAN,
+      SqlStdOperatorTable.IS_DISTINCT_FROM,
+      SqlStdOperatorTable.IS_NOT_DISTINCT_FROM,
+      SqlStdOperatorTable.GREATER_THAN_OR_EQUAL,
+      SqlStdOperatorTable.IN,
+      SqlStdOperatorTable.NOT_IN,
+      SqlStdOperatorTable.SEARCH,
+      SqlStdOperatorTable.LESS_THAN,
+      SqlStdOperatorTable.LESS_THAN_OR_EQUAL,
+      SqlStdOperatorTable.MINUS,
+      SqlStdOperatorTable.MULTIPLY,
+      SqlStdOperatorTable.NOT_EQUALS,
+      SqlStdOperatorTable.OR,
+      SqlStdOperatorTable.PLUS,
+      SqlStdOperatorTable.INTERVAL,
+
+      // POSTFIX OPERATORS
+      SqlStdOperatorTable.DESC,
+      SqlStdOperatorTable.NULLS_FIRST,
+      SqlStdOperatorTable.NULLS_LAST,
+      SqlStdOperatorTable.IS_NOT_NULL,
+      SqlStdOperatorTable.IS_NULL,
+      SqlStdOperatorTable.IS_NOT_TRUE,
+      SqlStdOperatorTable.IS_TRUE,
+      SqlStdOperatorTable.IS_NOT_FALSE,
+      SqlStdOperatorTable.IS_FALSE,
+      SqlStdOperatorTable.IS_NOT_UNKNOWN,
+      SqlStdOperatorTable.IS_UNKNOWN,
+
+      // PREFIX OPERATORS
+      SqlStdOperatorTable.EXISTS,
+      SqlStdOperatorTable.NOT,
+
+      // AGGREGATE OPERATORS
+      SqlStdOperatorTable.SUM,
+      SqlStdOperatorTable.COUNT,
+      SqlStdOperatorTable.MODE,
+      SqlStdOperatorTable.MIN,
+      SqlStdOperatorTable.MAX,
+      SqlStdOperatorTable.LAST_VALUE,
+      SqlStdOperatorTable.FIRST_VALUE,
+      SqlStdOperatorTable.LEAD,
+      SqlStdOperatorTable.LAG,
+      SqlStdOperatorTable.AVG,
+      SqlStdOperatorTable.STDDEV_POP,
+      SqlStdOperatorTable.COVAR_POP,
+      SqlStdOperatorTable.COVAR_SAMP,
+      SqlStdOperatorTable.STDDEV_SAMP,
+      SqlStdOperatorTable.VAR_POP,
+      SqlStdOperatorTable.VAR_SAMP,
+      SqlStdOperatorTable.SUM0,
+
+      // WINDOW Rank Functions
+      SqlStdOperatorTable.DENSE_RANK,
+      SqlStdOperatorTable.RANK,
+      SqlStdOperatorTable.ROW_NUMBER,
+
+      // SPECIAL OPERATORS
+      SqlStdOperatorTable.BETWEEN,
+      SqlStdOperatorTable.SYMMETRIC_BETWEEN,
+      SqlStdOperatorTable.NOT_BETWEEN,
+      SqlStdOperatorTable.SYMMETRIC_NOT_BETWEEN,
+      SqlStdOperatorTable.NOT_LIKE,
+      SqlStdOperatorTable.LIKE,
+//      SqlStdOperatorTable.CASE,
+      SqlStdOperatorTable.OVER,
+
+      // FUNCTIONS
+      // String functions
+      SqlStdOperatorTable.SUBSTRING,
+      SqlStdOperatorTable.REPLACE,
+      SqlStdOperatorTable.TRIM,
+      SqlStdOperatorTable.UPPER,
+      SqlStdOperatorTable.LOWER,
+      // Arithmetic functions
+      SqlStdOperatorTable.POWER,
+      SqlStdOperatorTable.SQRT,
+      SqlStdOperatorTable.MOD,
+//      SqlStdOperatorTable.LN,
+      SqlStdOperatorTable.LOG10,
+      SqlStdOperatorTable.ABS,
+      SqlStdOperatorTable.ACOS,
+      SqlStdOperatorTable.ASIN,
+      SqlStdOperatorTable.ATAN,
+      SqlStdOperatorTable.ATAN2,
+      SqlStdOperatorTable.COS,
+      SqlStdOperatorTable.COT,
+      SqlStdOperatorTable.DEGREES,
+      SqlStdOperatorTable.EXP,
+      SqlStdOperatorTable.RADIANS,
+      SqlStdOperatorTable.ROUND,
+      SqlStdOperatorTable.SIGN,
+      SqlStdOperatorTable.SIN,
+      SqlStdOperatorTable.TAN,
+      SqlStdOperatorTable.TRUNCATE,
+      SqlStdOperatorTable.FLOOR,
+      SqlStdOperatorTable.CEIL,
+      SqlStdOperatorTable.TIMESTAMP_ADD,
+      SqlStdOperatorTable.TIMESTAMP_DIFF,
+      SqlStdOperatorTable.CAST,
+
+      SqlStdOperatorTable.EXTRACT,
+      // TODO: The following operators are all rewritten to EXTRACT. Consider removing them because they are all
+      //       supported without rewrite.
+      SqlStdOperatorTable.YEAR,
+      SqlStdOperatorTable.QUARTER,
+      SqlStdOperatorTable.MONTH,
+      SqlStdOperatorTable.WEEK,
+      SqlStdOperatorTable.DAYOFYEAR,
+      SqlStdOperatorTable.DAYOFMONTH,
+      SqlStdOperatorTable.DAYOFWEEK,
+      SqlStdOperatorTable.HOUR,
+      SqlStdOperatorTable.MINUTE,
+      SqlStdOperatorTable.SECOND,
+
+      SqlStdOperatorTable.ITEM,
+      SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+      SqlStdOperatorTable.LISTAGG
+  );
+
+  private static final List<Pair<SqlOperator, List<String>>> STANDARD_OPERATORS_WITH_ALIASES = List.of(
+      Pair.of(SqlStdOperatorTable.CASE, List.of("CASE", "CASE_WHEN")),
+      Pair.of(SqlStdOperatorTable.LN, List.of("LN", "LOG"))
+  );
+
+  /**
+   * This list includes the customized {@link SqlOperator}s.
+   */
+  private static final List<SqlOperator> PINOT_OPERATORS = List.of(
+      // Placeholder for special predicates
+      new PinotSqlFunction("TEXT_MATCH", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER),
+      new PinotSqlFunction("TEXT_CONTAINS", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER),
+      new PinotSqlFunction("JSON_MATCH", ReturnTypes.BOOLEAN, OperandTypes.CHARACTER_CHARACTER),
+      new PinotSqlFunction("VECTOR_SIMILARITY", ReturnTypes.BOOLEAN,
+          OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 2)),
+
+      // Placeholder for special functions to handle MV
+      // NOTE:
+      // ARRAY_TO_MV is not deterministic. We need to explicitly set it as not deterministic in order not to let
+      // Calcite optimize expressions like
+      // `ARRAY_TO_MV(RandomAirports) = 'MFR' and ARRAY_TO_MV(RandomAirports) = 'GTR'` as `false`.
+ // If the function were deterministic, its value would never be MFR and GTR at the same time, so Calcite is + // smart enough to know there is no value that satisfies the condition. + // In fact what ARRAY_TO_MV does is just to trick Calcite typesystem, but then what the leaf stage executor + // receives is `RandomAirports = 'MFR' and RandomAirports = 'GTR'`, which in the V1 semantics means: + // true if and only if RandomAirports contains a value equal to 'MFR' and RandomAirports contains a value equal + // to 'GTR' + new PinotSqlFunction("ARRAY_TO_MV", opBinding -> opBinding.getOperandType(0).getComponentType(), + OperandTypes.ARRAY, false), + + // SqlStdOperatorTable.COALESCE without rewrite + new SqlFunction("COALESCE", SqlKind.COALESCE, + ReturnTypes.LEAST_RESTRICTIVE.andThen(SqlTypeTransforms.LEAST_NULLABLE), null, OperandTypes.SAME_VARIADIC, + SqlFunctionCategory.SYSTEM), + + // The scalar function version returns long instead of Timestamp + // TODO: Consider unifying the return type to Timestamp + new PinotSqlFunction("FROM_DATE_TIME", ReturnTypes.TIMESTAMP_NULLABLE, OperandTypes.family( + List.of(SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY), + i -> i > 1)) + ); + + private static final List>> PINOT_OPERATORS_WITH_ALIASES = List.of( + ); + //@formatter:on + + // Key is canonical name + private final Map _operatorMap; + private final List _operatorList; + + private PinotOperatorTable() { + Map operatorMap = new HashMap<>(); + + // Register standard operators + for (SqlOperator operator : STANDARD_OPERATORS) { + register(operator.getName(), operator, operatorMap); + } + for (Pair> pair : STANDARD_OPERATORS_WITH_ALIASES) { + SqlOperator operator = pair.getLeft(); + for (String name : pair.getRight()) { + register(name, operator, operatorMap); } } - // Walk through all the Pinot aggregation types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) { - if (aggregationFunctionType.getSqlKind() != null) { - // 1. Register the aggregation function with Calcite - registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType); - // 2. Register the aggregation function with Calcite on all alternative names - List alternativeFunctionNames = aggregationFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerAggregateFunction(alternativeFunctionName, aggregationFunctionType); - } + // Register Pinot operators + for (SqlOperator operator : PINOT_OPERATORS) { + register(operator.getName(), operator, operatorMap); + } + for (Pair> pair : PINOT_OPERATORS_WITH_ALIASES) { + SqlOperator operator = pair.getLeft(); + for (String name : pair.getRight()) { + register(name, operator, operatorMap); } } - // Walk through all the Pinot transform types and - // 1. register those that are supported in multistage in addition to calcite standard opt table. - // 2. register special handling that differs from calcite standard. - for (TransformFunctionType transformFunctionType : TransformFunctionType.values()) { - if (transformFunctionType.getSqlKind() != null) { - // 1. Register the transform function with Calcite - registerTransformFunction(transformFunctionType.getName(), transformFunctionType); - // 2. 
Register the transform function with Calcite on all alternative names - List<String> alternativeFunctionNames = transformFunctionType.getAlternativeNames(); - for (String alternativeFunctionName : alternativeFunctionNames) { - registerTransformFunction(alternativeFunctionName, transformFunctionType); - } + registerAggregateFunctions(operatorMap); + registerTransformFunctions(operatorMap); + registerScalarFunctions(operatorMap); + + _operatorMap = Map.copyOf(operatorMap); + _operatorList = List.copyOf(operatorMap.values()); + } + + private void register(String name, SqlOperator sqlOperator, Map<String, SqlOperator> operatorMap) { + Preconditions.checkState(operatorMap.put(FunctionRegistry.canonicalize(name), sqlOperator) == null, + "SqlOperator: %s is already registered", name); + } + + private void registerAggregateFunctions(Map<String, SqlOperator> operatorMap) { + for (AggregationFunctionType functionType : AggregationFunctionType.values()) { + if (functionType.getReturnTypeInference() != null) { + String functionName = functionType.getName(); + PinotSqlAggFunction function = new PinotSqlAggFunction(functionName, functionType.getReturnTypeInference(), + functionType.getOperandTypeChecker()); + Preconditions.checkState(operatorMap.put(FunctionRegistry.canonicalize(functionName), function) == null, + "Aggregate function: %s is already registered", functionName); } } } - private void registerAggregateFunction(String functionName, AggregationFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlAggFunction sqlAggFunction = - new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null, functionType.getSqlKind(), - functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory()); - if (notRegistered(sqlAggFunction)) { - register(sqlAggFunction); + private void registerTransformFunctions(Map<String, SqlOperator> operatorMap) { + for (TransformFunctionType functionType : TransformFunctionType.values()) { + if (functionType.getReturnTypeInference() != null) { + PinotSqlFunction function = new PinotSqlFunction(functionType.getName(), functionType.getReturnTypeInference(), + functionType.getOperandTypeChecker()); + for (String name : functionType.getNames()) { + Preconditions.checkState(operatorMap.put(FunctionRegistry.canonicalize(name), function) == null, + "Transform function: %s is already registered", name); + } } } } - private void registerTransformFunction(String functionName, TransformFunctionType functionType) { - // register function behavior that's different from Calcite - if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) { - PinotSqlTransformFunction sqlTransformFunction = - new PinotSqlTransformFunction(functionName.toUpperCase(Locale.ROOT), functionType.getSqlKind(), - functionType.getReturnTypeInference(), null, functionType.getOperandTypeChecker(), - functionType.getSqlFunctionCategory(), functionType.isDeterministic()); - if (notRegistered(sqlTransformFunction)) { - register(sqlTransformFunction); + private void registerScalarFunctions(Map<String, SqlOperator> operatorMap) { + for (Map.Entry<String, PinotScalarFunction> entry : FunctionRegistry.FUNCTION_MAP.entrySet()) { + String canonicalName = entry.getKey(); + PinotScalarFunction scalarFunction = entry.getValue(); + PinotSqlFunction sqlFunction = scalarFunction.toPinotSqlFunction(); + if (sqlFunction == null) { + continue; + } + if (operatorMap.containsKey(canonicalName)) { + // Skip
registering ArgumentCountBasedScalarFunction if it is already registered + Preconditions.checkState(scalarFunction instanceof FunctionRegistry.ArgumentCountBasedScalarFunction, + "Scalar function: %s is already registered", canonicalName); + continue; } + operatorMap.put(canonicalName, sqlFunction); } } - private boolean notRegistered(SqlFunction op) { - List<SqlOperator> operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), op.getFunctionType(), op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; + @Override + public void lookupOperatorOverloads(SqlIdentifier opName, @Nullable SqlFunctionCategory category, SqlSyntax syntax, + List<SqlOperator> operatorList, SqlNameMatcher nameMatcher) { + if (!opName.isSimple()) { + return; + } + String canonicalName = FunctionRegistry.canonicalize(opName.getSimple()); + SqlOperator operator = _operatorMap.get(canonicalName); + if (operator != null) { + operatorList.add(operator); + } } - private boolean notRegistered(SqlOperator op) { - List<SqlOperator> operatorList = new ArrayList<>(); - lookupOperatorOverloads(op.getNameAsId(), null, op.getSyntax(), operatorList, - SqlNameMatchers.withCaseSensitive(false)); - return operatorList.size() == 0; + @Override + public List<SqlOperator> getOperatorList() { + return _operatorList; } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotSqlCoalesceFunction.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotSqlCoalesceFunction.java deleted file mode 100644 index 086c42a57247..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/fun/PinotSqlCoalesceFunction.java +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.calcite.sql.fun; - -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.fun.SqlCoalesceFunction; -import org.apache.calcite.sql.validate.SqlValidator; - - -/** - * Pinot supports native COALESCE function, thus no need to create CASE WHEN conversion.
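The new lookupOperatorOverloads above turns every lookup into one canonicalization plus one hash probe, replacing the deleted notRegistered scans that re-ran a case-insensitive search per registration. A sketch of that flow under the same assumed canonicalization, with strings standing in for Calcite's SqlIdentifier and SqlOperator:

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Sketch of the O(1) operator lookup: canonicalize the simple name, probe the map,
// and append at most one operator to the output list.
public class OperatorLookupSketch {
  private static final Map<String, String> OPERATOR_MAP =
      Map.of("COALESCE", "native COALESCE", "ARRAYTOMV", "ARRAY_TO_MV");

  static String canonicalize(String name) {
    return name.replace("_", "").toUpperCase(Locale.ROOT); // assumed rule
  }

  static void lookupOperatorOverloads(String simpleName, List<String> operatorList) {
    String operator = OPERATOR_MAP.get(canonicalize(simpleName));
    if (operator != null) {
      operatorList.add(operator);
    }
  }

  public static void main(String[] args) {
    List<String> operatorList = new ArrayList<>();
    lookupOperatorOverloads("array_to_mv", operatorList);
    System.out.println(operatorList); // [ARRAY_TO_MV] -- found via its canonical name
  }
}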
- */ -public class PinotSqlCoalesceFunction extends SqlCoalesceFunction { - - @Override - public SqlNode rewriteCall(SqlValidator validator, SqlCall call) { - return call; - } -} diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/util/PinotChainedSqlOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/util/PinotChainedSqlOperatorTable.java deleted file mode 100644 index 8d4db390b712..000000000000 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql/util/PinotChainedSqlOperatorTable.java +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.pinot.calcite.sql.util; - -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.List; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.SqlOperatorTable; -import org.apache.calcite.sql.SqlSyntax; -import org.apache.calcite.sql.util.ChainedSqlOperatorTable; -import org.apache.calcite.sql.util.SqlOperatorTables; -import org.apache.calcite.sql.validate.SqlNameMatcher; -import org.checkerframework.checker.nullness.qual.Nullable; - - -/** - * ================================================================================================================= - * THIS CLASS IS COPIED FROM Calcite's {@link ChainedSqlOperatorTable} and modified the function lookup to terminate - * early once found from ordered SqlOperatorTable list. This is to avoid some hard-coded casting assuming all Sql - * identifier looked-up are of the same SqlOperator type. - * ================================================================================================================= - * - * PinotChainedSqlOperatorTable implements the {@link SqlOperatorTable} interface by chaining together any number of - * underlying operator table instances. - */ -//@formatter:off -public class PinotChainedSqlOperatorTable implements SqlOperatorTable { - //~ Instance fields -------------------------------------------------------- - - protected final List<SqlOperatorTable> _tableList; - - //~ Constructors ----------------------------------------------------------- - - public PinotChainedSqlOperatorTable(List<SqlOperatorTable> tableList) { - this(ImmutableList.copyOf(tableList)); - } - - /** Internal constructor; call {@link SqlOperatorTables#chain}.
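The file deleted below existed for a single behavioral tweak to Calcite's ChainedSqlOperatorTable: stop at the first table that yields a match rather than collecting overloads from every table in the chain. With one consolidated operator table there is nothing left to chain, so the class can go. A standalone sketch of that early-termination idea, with simple lookup lambdas standing in for SqlOperatorTable:

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

// Sketch of the early-termination chaining that PinotChainedSqlOperatorTable added
// on top of Calcite's ChainedSqlOperatorTable: the first table with a match wins.
public class ChainedLookupSketch {
  public static void main(String[] args) {
    // Each "table" is modeled as a lookup that may append matches for a name.
    List<BiConsumer<String, List<String>>> tables = List.of(
        (name, out) -> { if (name.equals("COALESCE")) { out.add("pinot COALESCE"); } },
        (name, out) -> { if (name.equals("COALESCE")) { out.add("catalog COALESCE"); } });

    List<String> operatorList = new ArrayList<>();
    for (BiConsumer<String, List<String>> table : tables) {
      table.accept("COALESCE", operatorList);
      if (!operatorList.isEmpty()) {
        break; // the modified behavior: later tables are never consulted
      }
    }
    System.out.println(operatorList); // [pinot COALESCE]
  }
}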
*/ - protected PinotChainedSqlOperatorTable(ImmutableList<SqlOperatorTable> tableList) { - _tableList = ImmutableList.copyOf(tableList); - } - - //~ Methods ---------------------------------------------------------------- - - @Deprecated // to be removed before 2.0 - public void add(SqlOperatorTable table) { - if (!_tableList.contains(table)) { - _tableList.add(table); - } - } - - @Override public void lookupOperatorOverloads(SqlIdentifier opName, - @Nullable SqlFunctionCategory category, SqlSyntax syntax, - List<SqlOperator> operatorList, SqlNameMatcher nameMatcher) { - for (SqlOperatorTable table : _tableList) { - table.lookupOperatorOverloads(opName, category, syntax, operatorList, - nameMatcher); - // ==================================================================== - // LINES CHANGED BELOW - // ==================================================================== - if (!operatorList.isEmpty()) { - break; - } - // ==================================================================== - // LINES CHANGED ABOVE - // ==================================================================== - } - } - - @Override public List<SqlOperator> getOperatorList() { - List<SqlOperator> list = new ArrayList<>(); - for (SqlOperatorTable table : _tableList) { - list.addAll(table.getOperatorList()); - } - return list; - } -} -//@formatter:on diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java index 5633d58a6972..f8e8545530f6 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/sql2rel/PinotConvertletTable.java @@ -19,6 +19,7 @@ package org.apache.pinot.calcite.sql2rel; import java.util.List; +import javax.annotation.Nullable; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlCall; @@ -27,7 +28,6 @@ import org.apache.calcite.sql2rel.SqlRexConvertlet; import org.apache.calcite.sql2rel.SqlRexConvertletTable; import org.apache.calcite.sql2rel.StandardConvertletTable; -import org.checkerframework.checker.nullness.qual.Nullable; /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java index 9c53cdee6a9a..1a54eec610bc 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/QueryEnvironment.java @@ -20,11 +20,11 @@ import com.google.common.annotations.VisibleForTesting; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Properties; import java.util.Set; import javax.annotation.Nullable; +import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionConfigImpl; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.jdbc.CalciteSchema; @@ -35,10 +35,9 @@ import org.apache.calcite.plan.hep.HepMatchOrder; import org.apache.calcite.plan.hep.HepProgram; import org.apache.calcite.plan.hep.HepProgramBuilder; -import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.prepare.CalciteCatalogReader; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import
org.apache.calcite.runtime.CalciteContextException; import org.apache.calcite.sql.SqlExplain; @@ -51,14 +50,13 @@ import org.apache.calcite.tools.FrameworkConfig; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; -import org.apache.pinot.calcite.prepare.PinotCalciteCatalogReader; import org.apache.pinot.calcite.rel.rules.PinotQueryRuleSets; import org.apache.pinot.calcite.rel.rules.PinotRelDistributionTraitRule; import org.apache.pinot.calcite.rel.rules.PinotRuleUtils; import org.apache.pinot.calcite.sql.fun.PinotOperatorTable; -import org.apache.pinot.calcite.sql.util.PinotChainedSqlOperatorTable; import org.apache.pinot.calcite.sql2rel.PinotConvertletTable; import org.apache.pinot.common.config.provider.TableCache; +import org.apache.pinot.query.catalog.PinotCatalog; import org.apache.pinot.query.context.PlannerContext; import org.apache.pinot.query.planner.PlannerUtils; import org.apache.pinot.query.planner.SubPlan; @@ -81,30 +79,34 @@ *

It provides the higher-level entry interface to convert a SQL string into a {@link DispatchableSubPlan}. */ public class QueryEnvironment { - // Calcite configurations - private final RelDataTypeFactory _typeFactory; - private final Prepare.CatalogReader _catalogReader; + private static final CalciteConnectionConfig CONNECTION_CONFIG; + + static { + Properties connectionConfigProperties = new Properties(); + connectionConfigProperties.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), "true"); + CONNECTION_CONFIG = new CalciteConnectionConfigImpl(connectionConfigProperties); + } + + private final TypeFactory _typeFactory = new TypeFactory(); private final FrameworkConfig _config; + private final CalciteCatalogReader _catalogReader; private final HepProgram _optProgram; private final HepProgram _traitProgram; // Pinot extensions - private final WorkerManager _workerManager; private final TableCache _tableCache; + private final WorkerManager _workerManager; - public QueryEnvironment(TypeFactory typeFactory, CalciteSchema rootSchema, WorkerManager workerManager, - TableCache tableCache) { - _typeFactory = typeFactory; - // Calcite extension/plugins - _workerManager = workerManager; - _tableCache = tableCache; - - // catalog & config - _catalogReader = getCatalogReader(_typeFactory, rootSchema); - _config = getConfig(_catalogReader); - // opt programs + public QueryEnvironment(String database, TableCache tableCache, @Nullable WorkerManager workerManager) { + PinotCatalog catalog = new PinotCatalog(database, tableCache); + CalciteSchema rootSchema = CalciteSchema.createRootSchema(false, false, database, catalog); + _config = Frameworks.newConfigBuilder().traitDefs().operatorTable(PinotOperatorTable.instance()) + .defaultSchema(rootSchema.plus()).sqlToRelConverterConfig(PinotRuleUtils.PINOT_SQL_TO_REL_CONFIG).build(); + _catalogReader = new CalciteCatalogReader(rootSchema, List.of(database), _typeFactory, CONNECTION_CONFIG); _optProgram = getOptProgram(); _traitProgram = getTraitProgram(); + _tableCache = tableCache; + _workerManager = workerManager; } private PlannerContext getPlannerContext() { @@ -301,20 +303,6 @@ private DispatchableSubPlan toDispatchableSubPlan(RelRoot relRoot, PlannerContex // utils // -------------------------------------------------------------------------- - private static Prepare.CatalogReader getCatalogReader(RelDataTypeFactory typeFactory, CalciteSchema rootSchema) { - Properties properties = new Properties(); - properties.setProperty(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), "true"); - return new PinotCalciteCatalogReader(rootSchema, List.of(rootSchema.getName()), typeFactory, - new CalciteConnectionConfigImpl(properties)); - } - - private static FrameworkConfig getConfig(Prepare.CatalogReader catalogReader) { - return Frameworks.newConfigBuilder().traitDefs() - .operatorTable(new PinotChainedSqlOperatorTable(Arrays.asList(PinotOperatorTable.instance(), catalogReader))) - .defaultSchema(catalogReader.getRootSchema().plus()) - .sqlToRelConverterConfig(PinotRuleUtils.PINOT_SQL_TO_REL_CONFIG).build(); - } - private static HepProgram getOptProgram() { HepProgramBuilder hepProgramBuilder = new HepProgramBuilder(); // Set the match order as DEPTH_FIRST.
The default is arbitrary which works the same as DEPTH_FIRST, but it's diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java index cb9630e7244d..7a364b56afab 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/catalog/PinotCatalog.java @@ -20,7 +20,6 @@ import com.google.common.base.Preconditions; import java.util.Collection; -import java.util.Collections; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -32,7 +31,6 @@ import org.apache.calcite.schema.SchemaVersion; import org.apache.calcite.schema.Schemas; import org.apache.calcite.schema.Table; -import org.apache.pinot.calcite.jdbc.CalciteSchemaBuilder; import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.common.utils.DatabaseUtils; import org.apache.pinot.spi.utils.builder.TableNameBuilder; @@ -87,8 +85,7 @@ public Table getTable(String name) { */ @Override public Set<String> getTableNames() { - return _tableCache.getTableNameMap().keySet().stream() - .filter(n -> DatabaseUtils.isPartOfDatabase(n, _databaseName)) + return _tableCache.getTableNameMap().keySet().stream().filter(n -> DatabaseUtils.isPartOfDatabase(n, _databaseName)) .collect(Collectors.toSet()); } @@ -99,25 +96,17 @@ public RelProtoDataType getType(String name) { @Override public Set<String> getTypeNames() { - return Collections.emptySet(); + return Set.of(); } - /** - * {@code PinotCatalog} doesn't need to return function collections b/c they are already registered. - * see: {@link CalciteSchemaBuilder#asRootSchema(Schema, String)} - */ @Override public Collection<Function> getFunctions(String name) { - return Collections.emptyList(); + return Set.of(); } - /** - * {@code PinotCatalog} doesn't need to return function name set b/c they are already registered. - * see: {@link CalciteSchemaBuilder#asRootSchema(Schema, String)} - */ @Override public Set<String> getFunctionNames() { - return Collections.emptySet(); + return Set.of(); } @Override @@ -127,7 +116,7 @@ public Schema getSubSchema(String name) { @Override public Set<String> getSubSchemaNames() { - return Collections.emptySet(); + return Set.of(); } @Override diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/utils/ParserUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/utils/ParserUtils.java index 9e897483a4d2..f914e6ba8d3e 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/utils/ParserUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/utils/ParserUtils.java @@ -18,10 +18,8 @@ */ package org.apache.pinot.query.parser.utils; -import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.query.QueryEnvironment; -import org.apache.pinot.query.type.TypeFactory; -import org.apache.pinot.query.type.TypeSystem; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,16 +31,13 @@ private ParserUtils() { } /** - * @param query the query string to be parsed and compiled - * @param calciteSchema the Calcite schema to be used for compilation - * @return true if the query can be parsed and compiled using the v2 multi-stage query engine + * Returns whether the query can be parsed and compiled using the multi-stage query engine.
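The helper described above reduces to: build a planner environment, attempt to compile, and map success or failure to a boolean. A hedged sketch of that pattern follows; the Compiler interface and the exception type are stand-ins for QueryEnvironment and the planner's real exception types, not the actual API.

// Sketch of the compile-check pattern in ParserUtils.canCompileWithMultiStageEngine:
// any exception thrown during planning means "not compilable with this engine".
public class CompileCheckSketch {
  interface Compiler {
    void compile(String query) throws Exception;
  }

  static boolean canCompile(Compiler compiler, String query) {
    try {
      compiler.compile(query);
      return true;
    } catch (Exception e) {
      // Parsing/validation errors simply mean the engine cannot take this query.
      return false;
    }
  }

  public static void main(String[] args) {
    Compiler alwaysFails = query -> { throw new IllegalArgumentException("cannot plan: " + query); };
    System.out.println(canCompile(alwaysFails, "SELECT 1")); // false
  }
}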
*/ - public static boolean canCompileQueryUsingV2Engine(String query, CalciteSchema calciteSchema) { + public static boolean canCompileWithMultiStageEngine(String query, String database, TableCache tableCache) { // try to parse and compile the query with the Calcite planner used by the multi-stage query engine try { LOGGER.info("Trying to compile query `{}` using the multi-stage query engine", query); - QueryEnvironment queryEnvironment = - new QueryEnvironment(new TypeFactory(new TypeSystem()), calciteSchema, null, null); + QueryEnvironment queryEnvironment = new QueryEnvironment(database, tableCache, null); queryEnvironment.getTableNamesForQuery(query); LOGGER.info("Successfully compiled query using the multi-stage query engine: `{}`", query); return true; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java index d06ee0473a2c..3419fdc5bb03 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java @@ -21,9 +21,9 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import javax.annotation.Nullable; import org.apache.calcite.rex.RexNode; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; -import org.checkerframework.checker.nullness.qual.Nullable; /** diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java index 0c6c4ee4b2e8..0004e5d1c426 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java @@ -26,6 +26,7 @@ import java.util.Calendar; import java.util.List; import java.util.Set; +import javax.annotation.Nullable; import org.apache.calcite.avatica.util.ByteString; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Window; @@ -41,7 +42,6 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.spi.utils.BooleanUtils; import org.apache.pinot.spi.utils.ByteArray; -import org.checkerframework.checker.nullness.qual.Nullable; @SuppressWarnings({"rawtypes", "unchecked"}) diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeFactory.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeFactory.java index 054d19113492..8e3ec003e8ba 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeFactory.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeFactory.java @@ -21,10 +21,8 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Map; -import java.util.function.Predicate; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; @@ -41,44 +39,34 @@ * upgrading Calcite versions. 
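The TypeFactory rewrite just below replaces the old Predicate<FieldSpec> plumbing with a direct check. The effective rule, sketched here with booleans standing in for Schema and FieldSpec, is that a column type becomes nullable only when schema-level column-based null handling is enabled and the field itself is declared nullable:

// Sketch of the nullability decision in the new TypeFactory#toRelDataType.
public class NullabilitySketch {
  static boolean isNullableColumn(boolean enableColumnBasedNullHandling, boolean fieldNullable) {
    return enableColumnBasedNullHandling && fieldNullable;
  }

  public static void main(String[] args) {
    System.out.println(isNullableColumn(true, true));   // true  -> createTypeWithNullability(type, true)
    System.out.println(isNullableColumn(true, false));  // false -> plain type
    System.out.println(isNullableColumn(false, true));  // false -> null handling disabled for the schema
  }
}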
*/ public class TypeFactory extends JavaTypeFactoryImpl { - private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8; - public TypeFactory(RelDataTypeSystem typeSystem) { - super(typeSystem); + public TypeFactory() { + super(TypeSystem.INSTANCE); } @Override public Charset getDefaultCharset() { - return DEFAULT_CHARSET; + return StandardCharsets.UTF_8; } public RelDataType createRelDataTypeFromSchema(Schema schema) { Builder builder = new Builder(this); - Predicate<FieldSpec> isNullable; - if (schema.isEnableColumnBasedNullHandling()) { - isNullable = FieldSpec::isNullable; - } else { - isNullable = fieldSpec -> false; - } - for (Map.Entry<String, FieldSpec> e : schema.getFieldSpecMap().entrySet()) { - builder.add(e.getKey(), toRelDataType(e.getValue(), isNullable)); + boolean enableNullHandling = schema.isEnableColumnBasedNullHandling(); + for (Map.Entry<String, FieldSpec> entry : schema.getFieldSpecMap().entrySet()) { + builder.add(entry.getKey(), toRelDataType(entry.getValue(), enableNullHandling)); } return builder.build(); } - private RelDataType toRelDataType(FieldSpec fieldSpec, Predicate<FieldSpec> isNullable) { + private RelDataType toRelDataType(FieldSpec fieldSpec, boolean enableNullHandling) { RelDataType type = createSqlType(getSqlTypeName(fieldSpec)); - boolean isArray = !fieldSpec.isSingleValueField(); - if (isArray) { + if (!fieldSpec.isSingleValueField()) { type = createArrayType(type, -1); } - if (isNullable.test(fieldSpec)) { - type = createTypeWithNullability(type, true); - } - return type; + return enableNullHandling && fieldSpec.isNullable() ? createTypeWithNullability(type, true) : type; } - private SqlTypeName getSqlTypeName(FieldSpec fieldSpec) { + private static SqlTypeName getSqlTypeName(FieldSpec fieldSpec) { switch (fieldSpec.getDataType()) { case INT: return SqlTypeName.INTEGER; diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeSystem.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeSystem.java index fd0bab96a601..5268e91edf0a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeSystem.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/type/TypeSystem.java @@ -29,6 +29,8 @@ * The {@code TypeSystem} overwrites Calcite type system with Pinot-specific logic.
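The deriveDecimal*Type overrides in the class that follows all share one pattern: delegate to Calcite, then clamp any derived DECIMAL wider than precision 19 down to DECIMAL(19, 1). A sketch of just the clamping rule, with a plain int pair standing in for RelDataType:

// Sketch of the decimal clamping shared by TypeSystem#deriveDecimalPlusType,
// deriveDecimalMultiplyType, deriveDecimalDivideType and deriveDecimalModType.
public class DecimalClampSketch {
  static final int DERIVED_DECIMAL_PRECISION = 19;
  static final int DERIVED_DECIMAL_SCALE = 1;

  static int[] clamp(int precision, int scale) {
    if (precision > DERIVED_DECIMAL_PRECISION) {
      return new int[]{DERIVED_DECIMAL_PRECISION, DERIVED_DECIMAL_SCALE};
    }
    return new int[]{precision, scale};
  }

  public static void main(String[] args) {
    // A wide derived type, e.g. from multiplying two DECIMAL(19, 5) values, is clamped:
    int[] clamped = clamp(38, 10);
    System.out.println(clamped[0] + ", " + clamped[1]); // 19, 1
    // Narrow decimals pass through unchanged:
    int[] kept = clamp(10, 2);
    System.out.println(kept[0] + ", " + kept[1]); // 10, 2
  }
}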
*/ public class TypeSystem extends RelDataTypeSystemImpl { + public static final TypeSystem INSTANCE = new TypeSystem(); + private static final int MAX_DECIMAL_SCALE = 1000; private static final int MAX_DECIMAL_PRECISION = 1000; @@ -44,6 +46,9 @@ public class TypeSystem extends RelDataTypeSystemImpl { private static final int DERIVED_DECIMAL_PRECISION = 19; private static final int DERIVED_DECIMAL_SCALE = 1; + private TypeSystem() { + } + @Override public boolean shouldConvertRaggedUnionTypesToVarying() { // A "ragged" union refers to a union of two or more data types that don't all @@ -68,8 +73,7 @@ public int getMaxNumericPrecision() { } @Override - public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, - RelDataType argumentType) { + public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, RelDataType argumentType) { assert SqlTypeUtil.isNumeric(argumentType); switch (argumentType.getSqlTypeName()) { @@ -84,8 +88,7 @@ public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, } @Override - public RelDataType deriveSumType(RelDataTypeFactory typeFactory, - RelDataType argumentType) { + public RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType argumentType) { assert SqlTypeUtil.isNumeric(argumentType); switch (argumentType.getSqlTypeName()) { case TINYINT: @@ -100,44 +103,40 @@ public RelDataType deriveSumType(RelDataTypeFactory typeFactory, } @Override - public RelDataType deriveDecimalPlusType(RelDataTypeFactory typeFactory, - RelDataType type1, RelDataType type2) { + public RelDataType deriveDecimalPlusType(RelDataTypeFactory typeFactory, RelDataType type1, RelDataType type2) { RelDataType dataType = super.deriveDecimalPlusType(typeFactory, type1, type2); - if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) - && (dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { + if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) && ( + dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { return typeFactory.createSqlType(SqlTypeName.DECIMAL, DERIVED_DECIMAL_PRECISION, DERIVED_DECIMAL_SCALE); } return dataType; } @Override - public RelDataType deriveDecimalMultiplyType(RelDataTypeFactory typeFactory, - RelDataType type1, RelDataType type2) { + public RelDataType deriveDecimalMultiplyType(RelDataTypeFactory typeFactory, RelDataType type1, RelDataType type2) { RelDataType dataType = super.deriveDecimalMultiplyType(typeFactory, type1, type2); - if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) - && (dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { + if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) && ( + dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { return typeFactory.createSqlType(SqlTypeName.DECIMAL, DERIVED_DECIMAL_PRECISION, DERIVED_DECIMAL_SCALE); } return dataType; } @Override - public RelDataType deriveDecimalDivideType(RelDataTypeFactory typeFactory, - RelDataType type1, RelDataType type2) { + public RelDataType deriveDecimalDivideType(RelDataTypeFactory typeFactory, RelDataType type1, RelDataType type2) { RelDataType dataType = super.deriveDecimalDivideType(typeFactory, type1, type2); - if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) - && (dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { + if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) && 
( + dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { return typeFactory.createSqlType(SqlTypeName.DECIMAL, DERIVED_DECIMAL_PRECISION, DERIVED_DECIMAL_SCALE); } return dataType; } @Override - public RelDataType deriveDecimalModType(RelDataTypeFactory typeFactory, - RelDataType type1, RelDataType type2) { + public RelDataType deriveDecimalModType(RelDataTypeFactory typeFactory, RelDataType type1, RelDataType type2) { RelDataType dataType = super.deriveDecimalModType(typeFactory, type1, type2); - if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) - && (dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { + if (dataType != null && SqlTypeUtil.isExactNumeric(dataType) && SqlTypeUtil.isDecimal(dataType) && ( + dataType.getPrecision() > DERIVED_DECIMAL_PRECISION)) { return typeFactory.createSqlType(SqlTypeName.DECIMAL, DERIVED_DECIMAL_PRECISION, DERIVED_DECIMAL_SCALE); } return dataType; diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRuleTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRuleTest.java index d470f7a4251c..b938b8656997 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRuleTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/calcite/rel/rules/PinotSortExchangeCopyRuleTest.java @@ -35,7 +35,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.pinot.calcite.rel.logical.PinotLogicalSortExchange; import org.apache.pinot.query.type.TypeFactory; -import org.apache.pinot.query.type.TypeSystem; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; @@ -47,8 +46,7 @@ public class PinotSortExchangeCopyRuleTest { - - public static final TypeFactory TYPE_FACTORY = new TypeFactory(new TypeSystem()); + private static final TypeFactory TYPE_FACTORY = new TypeFactory(); private static final RexBuilder REX_BUILDER = new RexBuilder(TYPE_FACTORY); private AutoCloseable _mocks; diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java index b776d0001e67..a66514e98798 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java @@ -31,16 +31,12 @@ import javax.annotation.Nullable; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.tuple.Pair; -import org.apache.pinot.calcite.jdbc.CalciteSchemaBuilder; import org.apache.pinot.common.config.provider.TableCache; import org.apache.pinot.core.routing.RoutingManager; import org.apache.pinot.core.routing.TablePartitionInfo; import org.apache.pinot.core.routing.TablePartitionInfo.PartitionInfo; -import org.apache.pinot.query.catalog.PinotCatalog; import org.apache.pinot.query.routing.WorkerManager; import org.apache.pinot.query.testutils.MockRoutingManagerFactory; -import org.apache.pinot.query.type.TypeFactory; -import org.apache.pinot.query.type.TypeSystem; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; import org.apache.pinot.spi.utils.CommonConstants; @@ -278,9 +274,8 @@ public static QueryEnvironment getQueryEnvironment(int reducerPort, int port1, i } RoutingManager routingManager = factory.buildRoutingManager(partitionInfoMap); TableCache tableCache = 
factory.buildTableCache(); - return new QueryEnvironment(new TypeFactory(new TypeSystem()), - CalciteSchemaBuilder.asRootSchema(new PinotCatalog(tableCache), CommonConstants.DEFAULT_DATABASE), - new WorkerManager("localhost", reducerPort, routingManager), tableCache); + return new QueryEnvironment(CommonConstants.DEFAULT_DATABASE, tableCache, + new WorkerManager("localhost", reducerPort, routingManager)); } /** diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/type/TypeFactoryTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/type/TypeFactoryTest.java index 543e1c3b70e7..fdc3dae3e013 100644 --- a/pinot-query-planner/src/test/java/org/apache/pinot/query/type/TypeFactoryTest.java +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/type/TypeFactoryTest.java @@ -39,8 +39,7 @@ public class TypeFactoryTest { - private static final TypeSystem TYPE_SYSTEM = new TypeSystem(); - private static final JavaTypeFactory TYPE_FACTORY = new TestJavaTypeFactoryImpl(TYPE_SYSTEM); + private static final JavaTypeFactory TYPE_FACTORY = new TestJavaTypeFactoryImpl(); @DataProvider(name = "relDataTypeConversion") public Iterator<Object[]> relDataTypeConversion() { @@ -120,7 +119,7 @@ public Iterator<Object[]> relDataTypeConversion() { @Test(dataProvider = "relDataTypeConversion") public void testScalarTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addSingleValueDimension("col", dataType) .setEnableColumnBasedNullHandling(columnNullMode) @@ -135,7 +134,7 @@ public void testScalarTypes(FieldSpec.DataType dataType, RelDataType scalarType, @Test(dataProvider = "relDataTypeConversion") public void testNullableScalarTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addDimensionField("col", dataType, field -> field.setNullable(true)) .setEnableColumnBasedNullHandling(columnNullMode) @@ -154,7 +153,7 @@ public void testNullableScalarTypes(FieldSpec.DataType dataType, RelDataType sca @Test(dataProvider = "relDataTypeConversion") public void testNotNullableScalarTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addDimensionField("col", dataType, field -> field.setNullable(false)) .setEnableColumnBasedNullHandling(columnNullMode) @@ -173,7 +172,7 @@ private boolean isColNullable(Schema schema) { @Test(dataProvider = "relDataTypeConversion") public void testArrayTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addMultiValueDimension("col", dataType) .setEnableColumnBasedNullHandling(columnNullMode) @@ -192,7 +191,7 @@ public void testArrayTypes(FieldSpec.DataType dataType, RelDataType scalarType, @Test(dataProvider = "relDataTypeConversion") public void testNullableArrayTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean
columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addDimensionField("col", dataType, field -> { field.setNullable(true); @@ -214,7 +213,7 @@ public void testNullableArrayTypes(FieldSpec.DataType dataType, RelDataType scal @Test(dataProvider = "relDataTypeConversion") public void testNotNullableArrayTypes(FieldSpec.DataType dataType, RelDataType scalarType, RelDataType arrayType, boolean columnNullMode) { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder() .addDimensionField("col", dataType, field -> { field.setNullable(false); @@ -233,7 +232,7 @@ public void testNotNullableArrayTypes(FieldSpec.DataType dataType, RelDataType s @Test public void testRelDataTypeConversion() { - TypeFactory typeFactory = new TypeFactory(TYPE_SYSTEM); + TypeFactory typeFactory = new TypeFactory(); Schema testSchema = new Schema.SchemaBuilder().addSingleValueDimension("INT_COL", FieldSpec.DataType.INT) .addSingleValueDimension("LONG_COL", FieldSpec.DataType.LONG) .addSingleValueDimension("FLOAT_COL", FieldSpec.DataType.FLOAT) @@ -253,54 +252,54 @@ public void testRelDataTypeConversion() { for (RelDataTypeField field : fieldList) { switch (field.getName()) { case "INT_COL": - BasicSqlType intBasicSqlType = new BasicSqlType(TYPE_SYSTEM, SqlTypeName.INTEGER); + BasicSqlType intBasicSqlType = new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.INTEGER); Assert.assertEquals(field.getType(), intBasicSqlType); checkPrecisionScale(field, intBasicSqlType); break; case "LONG_COL": - BasicSqlType bigIntBasicSqlType = new BasicSqlType(TYPE_SYSTEM, SqlTypeName.BIGINT); + BasicSqlType bigIntBasicSqlType = new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.BIGINT); Assert.assertEquals(field.getType(), bigIntBasicSqlType); checkPrecisionScale(field, bigIntBasicSqlType); break; case "FLOAT_COL": case "DOUBLE_COL": - BasicSqlType doubleBasicSqlType = new BasicSqlType(TYPE_SYSTEM, SqlTypeName.DOUBLE); + BasicSqlType doubleBasicSqlType = new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.DOUBLE); Assert.assertEquals(field.getType(), doubleBasicSqlType); checkPrecisionScale(field, doubleBasicSqlType); break; case "STRING_COL": case "JSON_COL": Assert.assertEquals(field.getType(), - TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR), + TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.VARCHAR), StandardCharsets.UTF_8, SqlCollation.IMPLICIT)); break; case "BYTES_COL": - Assert.assertEquals(field.getType(), new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARBINARY)); + Assert.assertEquals(field.getType(), new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.VARBINARY)); break; case "INT_ARRAY_COL": Assert.assertEquals(field.getType(), - new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.INTEGER), false)); + new ArraySqlType(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.INTEGER), false)); break; case "LONG_ARRAY_COL": Assert.assertEquals(field.getType(), - new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.BIGINT), false)); + new ArraySqlType(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.BIGINT), false)); break; case "FLOAT_ARRAY_COL": Assert.assertEquals(field.getType(), - new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.REAL), false)); + new ArraySqlType(new BasicSqlType(TypeSystem.INSTANCE, 
SqlTypeName.REAL), false)); break; case "DOUBLE_ARRAY_COL": Assert.assertEquals(field.getType(), - new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.DOUBLE), false)); + new ArraySqlType(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.DOUBLE), false)); break; case "STRING_ARRAY_COL": Assert.assertEquals(field.getType(), new ArraySqlType( - TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARCHAR), + TYPE_FACTORY.createTypeWithCharsetAndCollation(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.VARCHAR), StandardCharsets.UTF_8, SqlCollation.IMPLICIT), false)); break; case "BYTES_ARRAY_COL": Assert.assertEquals(field.getType(), - new ArraySqlType(new BasicSqlType(TYPE_SYSTEM, SqlTypeName.VARBINARY), false)); + new ArraySqlType(new BasicSqlType(TypeSystem.INSTANCE, SqlTypeName.VARBINARY), false)); break; default: Assert.fail("Unexpected column name: " + field.getName()); @@ -310,8 +309,8 @@ public void testRelDataTypeConversion() { } private static class TestJavaTypeFactoryImpl extends JavaTypeFactoryImpl { - public TestJavaTypeFactoryImpl(TypeSystem typeSystem) { - super(typeSystem); + public TestJavaTypeFactoryImpl() { + super(TypeSystem.INSTANCE); } @Override diff --git a/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json b/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json index e822c536a908..7c5f0de3d051 100644 --- a/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json +++ b/pinot-query-planner/src/test/resources/queries/BasicQueryPlans.json @@ -105,7 +105,7 @@ "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", "\n PinotLogicalExchange(distribution=[hash[0]])", "\n PinotLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])", - "\n LogicalProject(EXPR$1=[ARRAYTOMV($6)], col3=[$2])", + "\n LogicalProject(EXPR$1=[ARRAY_TO_MV($6)], col3=[$2])", "\n LogicalTableScan(table=[[default, e]])", "\n" ] diff --git a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json index 88994cba3bdd..6298709bf524 100644 --- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json +++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json @@ -35,7 +35,7 @@ "sql": "EXPLAIN PLAN FOR SELECT dateTrunc('MONTH', FROMDATETIME( '1997-02-01 00:00:00', 'yyyy-MM-dd HH:mm:ss')) FROM d", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[CAST(854755200000:BIGINT):BIGINT])", + "\nLogicalProject(EXPR$0=[854755200000:BIGINT])", "\n LogicalTableScan(table=[[default, d]])", "\n" ] @@ -45,7 +45,7 @@ "sql": "EXPLAIN PLAN FOR SELECT timestampDiff(DAY, CAST(ts as TIMESTAMP), CAST(dateTrunc('MONTH', FROMDATETIME('1997-02-01 00:00:00', 'yyyy-MM-dd HH:mm:ss')) as TIMESTAMP)) FROM d", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[TIMESTAMPDIFF(FLAG(DAY), CAST($6):TIMESTAMP(0) NOT NULL, CAST(854755200000:BIGINT):TIMESTAMP(0))])", + "\nLogicalProject(EXPR$0=[TIMESTAMPDIFF(FLAG(DAY), CAST($6):TIMESTAMP(0) NOT NULL, CAST(854755200000:BIGINT):TIMESTAMP(0) NOT NULL)])", "\n LogicalTableScan(table=[[default, d]])", "\n" ] @@ -65,7 +65,7 @@ "sql": "EXPLAIN PLAN FOR SELECT dateTrunc('MONTH', 854755200000) AS day FROM a", "output": [ "Execution Plan", - "\nLogicalProject(day=[CAST(854755200000:BIGINT):BIGINT])", + "\nLogicalProject(day=[854755200000:BIGINT])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -95,7 +95,7 @@ "sql": "EXPLAIN PLAN FOR SELECT concat('month', ' 1') FROM 
a", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[CAST(_UTF-8'month 1':VARCHAR CHARACTER SET \"UTF-8\"):VARCHAR CHARACTER SET \"UTF-8\"])", + "\nLogicalProject(EXPR$0=[_UTF-8'month 1':VARCHAR CHARACTER SET \"UTF-8\"])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -105,7 +105,7 @@ "sql": "EXPLAIN PLAN FOR SELECT substr('month',2) FROM a", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[CAST(_UTF-8'nth':VARCHAR CHARACTER SET \"UTF-8\"):VARCHAR CHARACTER SET \"UTF-8\"])", + "\nLogicalProject(EXPR$0=[_UTF-8'nth':VARCHAR CHARACTER SET \"UTF-8\"])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] @@ -115,7 +115,7 @@ "sql": "EXPLAIN PLAN FOR SELECT upper(lower(upper(substr('month',2)))) FROM a", "output": [ "Execution Plan", - "\nLogicalProject(EXPR$0=[CAST(_UTF-8'NTH':VARCHAR CHARACTER SET \"UTF-8\"):VARCHAR CHARACTER SET \"UTF-8\"])", + "\nLogicalProject(EXPR$0=[_UTF-8'NTH':VARCHAR CHARACTER SET \"UTF-8\"])", "\n LogicalTableScan(table=[[default, a]])", "\n" ] diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java index 89c5585a318b..903a763968bf 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FunctionOperand.java @@ -20,6 +20,7 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import javax.annotation.Nullable; import org.apache.pinot.common.function.FunctionInfo; @@ -47,9 +48,32 @@ public FunctionOperand(RexExpression.FunctionCall functionCall, DataSchema dataS _resultType = functionCall.getDataType(); List operands = functionCall.getFunctionOperands(); int numOperands = operands.size(); - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionCall.getFunctionName(), numOperands); - Preconditions.checkState(functionInfo != null, "Cannot find function with name: %s", - functionCall.getFunctionName()); + ColumnDataType[] argumentTypes = new ColumnDataType[numOperands]; + for (int i = 0; i < numOperands; i++) { + RexExpression operand = operands.get(i); + ColumnDataType argumentType; + if (operand instanceof RexExpression.InputRef) { + argumentType = dataSchema.getColumnDataType(((RexExpression.InputRef) operand).getIndex()); + } else if (operand instanceof RexExpression.Literal) { + argumentType = ((RexExpression.Literal) operand).getDataType(); + } else { + assert operand instanceof RexExpression.FunctionCall; + argumentType = ((RexExpression.FunctionCall) operand).getDataType(); + } + argumentTypes[i] = argumentType; + } + String functionName = functionCall.getFunctionName(); + String canonicalName = FunctionRegistry.canonicalize(functionName); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes); + if (functionInfo == null) { + if (FunctionRegistry.contains(canonicalName)) { + throw new IllegalArgumentException( + String.format("Unsupported function: %s with argument types: %s", functionName, + Arrays.toString(argumentTypes))); + } else { + throw new IllegalArgumentException(String.format("Unsupported function: %s", functionName)); + } + } _functionInvoker = new FunctionInvoker(functionInfo); if (!_functionInvoker.getMethod().isVarArgs()) { Class[] parameterClasses = _functionInvoker.getParameterClasses(); 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java index ee2441fe7079..a72a166bc73f 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java @@ -178,8 +178,8 @@ public static class Accumulator { "MIN", cdt -> AggregationUtils::mergeMin, "MAX", cdt -> AggregationUtils::mergeMax, "COUNT", cdt -> new AggregationUtils.MergeCounts(), - "BOOL_AND", cdt -> AggregationUtils::mergeBoolAnd, - "BOOL_OR", cdt -> AggregationUtils::mergeBoolOr + "BOOLAND", cdt -> AggregationUtils::mergeBoolAnd, + "BOOLOR", cdt -> AggregationUtils::mergeBoolOr ); //@formatter:on diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java index 1fb29c79f07a..8f2dddcabb82 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/FilterOperatorTest.java @@ -218,8 +218,12 @@ public void shouldHandleBooleanFunction() { assertEquals(resultRows.get(0), new Object[]{"starTree"}); } - @Test(expectedExceptions = IllegalStateException.class, expectedExceptionsMessageRegExp = "Cannot find function " - + "with name: startsWithError") + //@formatter:off + @Test( + expectedExceptions = IllegalArgumentException.class, + expectedExceptionsMessageRegExp = "Unsupported function: startsWithError" + ) + //@formatter:on public void shouldThrowOnInvalidFunction() { DataSchema inputSchema = new DataSchema(new String[]{"string1"}, new ColumnDataType[]{ ColumnDataType.STRING diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java index 65f6906b6896..7f7b4e836d05 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/queries/QueryRunnerTest.java @@ -51,10 +51,15 @@ * pattern goes here. */ public class QueryRunnerTest extends QueryRunnerTestBase { + //@formatter:off public static final Object[][] ROWS = new Object[][]{ - new Object[]{"foo", "foo", 1}, new Object[]{"bar", "bar", 42}, new Object[]{"alice", "alice", 1}, new Object[]{ - "bob", "foo", 42}, new Object[]{"charlie", "bar", 1}, + new Object[]{"foo", "foo", 1}, + new Object[]{"bar", "bar", 42}, + new Object[]{"alice", "alice", 1}, + new Object[]{"bob", "foo", 42}, + new Object[]{"charlie", "bar", 1} }; + //@formatter:on public static final Schema.SchemaBuilder SCHEMA_BUILDER; static { @@ -199,18 +204,21 @@ public void testSqlWithExceptionMsgChecker(String sql, String exceptionMsg) { @DataProvider(name = "testDataWithSqlToFinalRowCount") private Object[][] provideTestSqlAndRowCount() { + //@formatter:off return new Object[][]{ // special hint test, the table is not actually partitioned by col1, thus this hint gives wrong result. 
but // b/c in order to test whether this hint produces the proper optimized plan, we are making this assumption new Object[]{ - "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ col1, COUNT(*) FROM a " - + " GROUP BY 1 ORDER BY 2", 10 + "SELECT /*+ aggOptions(is_partitioned_by_group_by_keys='true') */ col1, COUNT(*) FROM a GROUP BY 1 " + + "ORDER BY 2", + 10 }, // special hint test, we want to try if dynamic broadcast works for just any random table */ new Object[]{ - "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ col1 FROM a " - + " WHERE a.col1 IN (SELECT b.col2 FROM b WHERE b.col3 < 10) AND a.col3 > 0", 9 + "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ col1 FROM a WHERE a.col1 IN " + + "(SELECT b.col2 FROM b WHERE b.col3 < 10) AND a.col3 > 0", + 9 }, // using join clause @@ -231,13 +239,14 @@ private Object[][] provideTestSqlAndRowCount() { // test function can be used in predicate/leaf/intermediate stage (using regexpLike) new Object[]{"SELECT a.col1, b.col1 FROM a JOIN b ON a.col3 = b.col3 WHERE regexpLike(a.col2, b.col1)", 9}, new Object[]{"SELECT a.col1, b.col1 FROM a JOIN b ON a.col3 = b.col3 WHERE regexp_like(a.col2, b.col1)", 9}, - new Object[]{"SELECT regexpLike(a.col1, b.col1) FROM a JOIN b ON a.col3 = b.col3", 39}, new Object[]{"SELECT " - + "regexp_like(a.col1, b.col1) FROM a JOIN b ON a.col3 = b.col3", 39}, + new Object[]{"SELECT regexpLike(a.col1, b.col1) FROM a JOIN b ON a.col3 = b.col3", 39}, + new Object[]{"SELECT regexp_like(a.col1, b.col1) FROM a JOIN b ON a.col3 = b.col3", 39}, // test function with @ScalarFunction annotation and alias works (using round_decimal) - new Object[]{"SELECT roundDecimal(col3) FROM a", 15}, new Object[]{"SELECT round_decimal(col3) FROM a", 15}, - new Object[]{"SELECT col1, roundDecimal(COUNT(*)) FROM a GROUP BY col1", 5}, new Object[]{"SELECT col1, " - + "round_decimal(COUNT(*)) FROM a GROUP BY col1", 5}, + new Object[]{"SELECT roundDecimal(col3) FROM a", 15}, + new Object[]{"SELECT round_decimal(col3) FROM a", 15}, + new Object[]{"SELECT col1, roundDecimal(COUNT(*)) FROM a GROUP BY col1", 5}, + new Object[]{"SELECT col1, round_decimal(COUNT(*)) FROM a GROUP BY col1", 5}, // test queries with special query options attached // - when leaf limit is set, each server returns multiStageLeafLimit number of rows only. 
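The numGroupsLimit cases above rely on the engine dropping rows that would open a new group once the limit is reached (rows for existing groups keep aggregating), which is why SET numGroupsLimit = 1 yields exactly one row. An illustrative sketch of that truncation, using the five col1 values from the test table; the real server-side mechanics are more involved than this:

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Sketch of what numGroupsLimit = N means for GROUP BY execution: once N distinct
// group keys exist, rows that would create a new group are ignored.
public class GroupsLimitSketch {
  public static void main(String[] args) {
    List<String> col1 = List.of("foo", "bar", "alice", "bob", "charlie");
    int numGroupsLimit = 2;
    Map<String, Integer> counts = new LinkedHashMap<>();
    for (String key : col1) {
      if (counts.size() >= numGroupsLimit && !counts.containsKey(key)) {
        continue; // limit reached: skip rows that would open a new group
      }
      counts.merge(key, 1, Integer::sum);
    }
    System.out.println(counts); // {foo=1, bar=1} -- only 2 groups survive
  }
}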
@@ -246,61 +255,74 @@ private Object[][] provideTestSqlAndRowCount() { // test groups limit in both leaf and intermediate stage new Object[]{"SET numGroupsLimit = 1; SELECT col1, COUNT(*) FROM a GROUP BY col1", 1}, new Object[]{"SET numGroupsLimit = 2; SELECT col1, COUNT(*) FROM a GROUP BY col1", 2}, - new Object[]{"SET numGroupsLimit = 1; " - + "SELECT a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) GROUP BY a.col2, b.col2", 1}, - new Object[]{"SET numGroupsLimit = 2; " - + "SELECT a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) GROUP BY a.col2, b.col2", 2}, + new Object[]{ + "SET numGroupsLimit = 1; " + + "SELECT a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) GROUP BY a.col2, b.col2", + 1 + }, + new Object[]{ + "SET numGroupsLimit = 2; " + + "SELECT a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) GROUP BY a.col2, b.col2", + 2 + }, // TODO: Consider pushing down hint to the leaf stage - new Object[]{"SET numGroupsLimit = 2; SELECT /*+ aggOptions(num_groups_limit='1') */ " - + "col1, COUNT(*) FROM a GROUP BY col1", 2}, - new Object[]{"SET numGroupsLimit = 2; SELECT /*+ aggOptions(num_groups_limit='1') */ " - + "a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) GROUP BY a.col2, b.col2", 1}, + new Object[]{ + "SET numGroupsLimit = 2; " + + "SELECT /*+ aggOptions(num_groups_limit='1') */ col1, COUNT(*) FROM a GROUP BY col1", + 2 + }, + new Object[]{ + "SET numGroupsLimit = 2; " + + "SELECT /*+ aggOptions(num_groups_limit='1') */ a.col2, b.col2, COUNT(*) FROM a JOIN b USING (col1) " + + "GROUP BY a.col2, b.col2", + 1 + }, new Object[]{"SELECT * FROM \"default.tbl-escape-naming\"", 5} }; + //@formatter:on } @DataProvider(name = "testDataWithSqlExecutionExceptions") private Object[][] provideTestSqlWithExecutionException() { + //@formatter:off return new Object[][]{ // Missing index - new Object[]{ - "SELECT col1 FROM a WHERE textMatch(col1, 'f') LIMIT 10", "without text index" - }, + new Object[]{"SELECT col1 FROM a WHERE textMatch(col1, 'f') LIMIT 10", "without text index"}, // Query hint with dynamic broadcast pipeline breaker should return error upstream new Object[]{ - "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ col1 FROM a " - + " WHERE a.col1 IN (SELECT b.col2 FROM b WHERE textMatch(col1, 'f')) AND a.col3 > 0", "without text " - + "index" + "SELECT /*+ joinOptions(join_strategy='dynamic_broadcast') */ col1 FROM a WHERE a.col1 IN " + + "(SELECT b.col2 FROM b WHERE textMatch(col1, 'f')) AND a.col3 > 0", + "without text index" }, // Timeout exception should occur with this option: - new Object[]{ - "SET timeoutMs = 1; SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col1 = c.col1", "Timeout" - }, + new Object[]{"SET timeoutMs = 1; SELECT * FROM a JOIN b ON a.col1 = b.col1 JOIN c ON a.col1 = c.col1", + "Timeout"}, // Function with incorrect argument signature should throw runtime exception when casting string to numeric - new Object[]{ - "SELECT least(a.col2, b.col3) FROM a JOIN b ON a.col1 = b.col1", "For input string:" - }, + new Object[]{"SELECT least(a.col2, b.col3) FROM a JOIN b ON a.col1 = b.col1", "For input string:"}, // Scalar function that doesn't have a valid use should throw an exception on the leaf stage // - predicate only functions: - new Object[]{"SELECT * FROM a WHERE textMatch(col1, 'f')", "without text index"}, new Object[]{"SELECT * FROM" - + " a WHERE text_match(col1, 'f')", "without text index"}, new Object[]{"SELECT * FROM a WHERE textContains" - + "(col1, 'f')", "supported only on native text index"}, new Object[]{"SELECT * 
FROM a WHERE text_contains" - + "(col1, 'f')", "supported only on native text index"}, + new Object[]{"SELECT * FROM a WHERE textMatch(col1, 'f')", "without text index"}, + new Object[]{"SELECT * FROM a WHERE text_match(col1, 'f')", "without text index"}, + new Object[]{"SELECT * FROM a WHERE textContains(col1, 'f')", "supported only on native text index"}, + new Object[]{"SELECT * FROM a WHERE text_contains(col1, 'f')", "supported only on native text index"}, // - transform only functions - new Object[]{"SELECT jsonExtractKey(col1, 'path') FROM a", "was expecting (JSON String"}, new Object[]{ - "SELECT json_extract_key(col1, 'path') FROM a", "was expecting (JSON String"}, + new Object[]{"SELECT jsonExtractKey(col1, 'path') FROM a", "was expecting (JSON String"}, + new Object[]{"SELECT json_extract_key(col1, 'path') FROM a", "was expecting (JSON String"}, // - a registered PlaceholderScalarFunction will throw on the intermediate stage, but works on the leaf stage. // - checked "Cannot resolve JSON path" as col1 is not actually a JSON string, but the call is correctly triggered. new Object[]{"SELECT CAST(jsonExtractScalar(col1, 'path', 'INT') AS INT) FROM a", "Cannot resolve JSON path"}, // - checked the function is reported as unsupported b/c there's no intermediate stage impl for json_extract_scalar - new Object[]{"SELECT CAST(json_extract_scalar(a.col1, b.col2, 'INT') AS INT)" - + "FROM a JOIN b ON a.col1 = b.col1", "Cannot find function with name: JSON_EXTRACT_SCALAR"}, + new Object[]{ + "SELECT CAST(json_extract_scalar(a.col1, b.col2, 'INT') AS INT) FROM a JOIN b ON a.col1 = b.col1", + "Unsupported function: JSONEXTRACTSCALAR" + } }; + //@formatter:on } } diff --git a/pinot-query-runtime/src/test/resources/queries/StatisticAggregates.json b/pinot-query-runtime/src/test/resources/queries/StatisticAggregates.json index 9e05805887c2..4eb29419afc8 100644 --- a/pinot-query-runtime/src/test/resources/queries/StatisticAggregates.json +++ b/pinot-query-runtime/src/test/resources/queries/StatisticAggregates.json @@ -144,31 +144,13 @@ }, "queries": [ { - "sql": "SELECT groupingCol, AVG(val), COVARPOP(val, val), COVARSAMP(val, val), VARPOP(val), VARSAMP(val), STDDEVPOP(val), STDDEVSAMP(val) FROM {tbl} GROUP BY groupingCol", - "h2Sql": "SELECT groupingCol, AVG(val), COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STDDEV_POP(val), STDDEV_SAMP(val) FROM {tbl} GROUP BY groupingCol" + "sql": "SELECT groupingCol, AVG(val), COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STDDEV_POP(val), STDDEV_SAMP(val) FROM {tbl} GROUP BY groupingCol" }, { - "sql": "SELECT AVG(val), COVARPOP(val, val), COVARSAMP(val, val), VARPOP(val), VARSAMP(val), STDDEVPOP(val), STDDEVSAMP(val) FROM {tbl} WHERE groupingCol='a'", - "h2Sql": "SELECT AVG(val), COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STDDEV_POP(val), STDDEV_SAMP(val) FROM {tbl} WHERE groupingCol='a'" + "sql": "SELECT AVG(val), COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STDDEV_POP(val), STDDEV_SAMP(val) FROM {tbl} WHERE groupingCol='a'" }, { - "sql": "SELECT t1.groupingCol, AVG(t1.val), COVARPOP(t1.val, t2.val), COVARSAMP(t1.val, t2.val), VARPOP(t1.val + t2.val), VARSAMP(t1.val + t2.val), STDDEVPOP(t1.val + t2.val), STDDEVSAMP(t1.val + t2.val) FROM {tbl} AS t1 LEFT JOIN {tbl2} AS t2 USING (partitionCol) GROUP BY t1.groupingCol", - "h2Sql": "SELECT t1.groupingCol, AVG(t1.val), COVAR_POP(t1.val, t2.val), COVAR_SAMP(t1.val, t2.val), VAR_POP(t1.val + t2.val), VAR_SAMP(t1.val + t2.val),
STDDEV_POP(t1.val + t2.val), STDDEV_SAMP(t1.val + t2.val) FROM {tbl} AS t1 LEFT JOIN {tbl2} AS t2 USING (partitionCol) GROUP BY t1.groupingCol" - }, - { - "ignored": true, - "comments": "standard name with underscore creates un-supported function from agg-reduce-function calcite rule", - "sql": "SELECT groupingCol, COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STD_DEV_POP(val), STD_DEV_SAMP(val) FROM {tbl} GROUP BY groupingCol" - }, - { - "ignored": true, - "comments": "standard name with underscore creates un-supported function from agg-reduce-function calcite rule", - "sql": "SELECT COVAR_POP(val, val), COVAR_SAMP(val, val), VAR_POP(val), VAR_SAMP(val), STD_DEV_POP(val), STD_DEV_SAMP(val) FROM {tbl} WHERE groupingCol='a'" - }, - { - "ignored": true, - "comments": "standard name with underscore creates un-supported function from agg-reduce-function calcite rule", - "sql": "SELECT t1.groupingCol, COVAR_POP(t1.val, t2.val), COVAR_SAMP(t1.val, t2.val), VAR_POP(t1.val + t2.val), VAR_SAMP(t1.val + t2.val), STD_DEV_POP(t1.val + t2.val), STD_DEV_SAMP(t1.val + t2.val) FROM {tbl} AS t1 LEFT JOIN {tbl2} AS t2 USING (partitionCol) GROUP BY t1.groupingCol" + "sql": "SELECT t1.groupingCol, AVG(t1.val), COVAR_POP(t1.val, t2.val), COVAR_SAMP(t1.val, t2.val), VAR_POP(t1.val + t2.val), VAR_SAMP(t1.val + t2.val), STDDEV_POP(t1.val + t2.val), STDDEV_SAMP(t1.val + t2.val) FROM {tbl} AS t1 LEFT JOIN {tbl2} AS t2 USING (partitionCol) GROUP BY t1.groupingCol" } ] } diff --git a/pinot-query-runtime/src/test/resources/queries/UDFAggregates.json b/pinot-query-runtime/src/test/resources/queries/UDFAggregates.json index fa61530dcf01..b11143f2cdb9 100644 --- a/pinot-query-runtime/src/test/resources/queries/UDFAggregates.json +++ b/pinot-query-runtime/src/test/resources/queries/UDFAggregates.json @@ -150,15 +150,15 @@ }, { "sql": "SELECT PERCENTILE_TDIGEST(float_col, 50), PERCENTILE_TDIGEST(double_col, 5), PERCENTILE_TDIGEST(int_col, 75), PERCENTILE_TDIGEST(long_col, 75) FROM {tbl}", - "outputs": [[1.75, 1.0, 137, 137]] + "outputs": [[1.75, 1.0, 137.75, 137.75]] }, { "sql": "SELECT bool_col, PERCENTILE_TDIGEST(float_col, 50), PERCENTILE_TDIGEST(double_col, 5), PERCENTILE_TDIGEST(int_col, 75), PERCENTILE_TDIGEST(long_col, 75) FROM {tbl} GROUP BY bool_col", - "outputs": [[false, 1.255, 1.0, 125, 125], [true, 300, 1.75, 131, 131]] + "outputs": [[false, 1.255, 1.0, 125.5, 125.5], [true, 300, 1.75, 131.75, 131.75]] }, { "sql": "SELECT /*+ aggOptions(is_skip_leaf_stage_aggregate='true') */ string_col, PERCENTILE_TDIGEST(float_col, 50), PERCENTILE_TDIGEST(double_col, 5), PERCENTILE_TDIGEST(int_col, 75), PERCENTILE_TDIGEST(long_col, 75) FROM {tbl} GROUP BY string_col", - "outputs": [["a", 350, 300, 2, 2], ["b", 50.5, 1, 100, 100], ["c", 1.5, 1.01, 168, 168]] + "outputs": [["a", 350, 300, 2, 2], ["b", 50.5, 1, 100, 100], ["c", 1.5, 1.01, 168.75, 168.75]] }, { "sql": "SELECT PERCENTILE_KLL(float_col, 50), PERCENTILE_KLL(double_col, 5), PERCENTILE_KLL(int_col, 75), PERCENTILE_KLL(long_col, 75) FROM {tbl}", diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java index 7b28ce7850f8..de804d104e92 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/function/InbuiltFunctionEvaluator.java @@ -70,7 +70,8 @@ 
private ExecutableNode planExecution(ExpressionContext expression) { childNodes[i] = planExecution(arguments.get(i)); } String functionName = function.getFunctionName(); - switch (functionName) { + String canonicalName = FunctionRegistry.canonicalize(functionName); + switch (canonicalName) { case "and": return new AndExecutionNode(childNodes); case "or": @@ -86,14 +87,13 @@ private ExecutableNode planExecution(ExpressionContext expression) { } return new ArrayConstantExecutionNode(values); default: - FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName, numArguments); + FunctionInfo functionInfo = FunctionRegistry.lookupFunctionInfo(canonicalName, numArguments); if (functionInfo == null) { - if (FunctionRegistry.containsFunction(functionName)) { + if (FunctionRegistry.contains(canonicalName)) { throw new IllegalStateException( - String.format("Unsupported function: %s with %d parameters", functionName, numArguments)); + String.format("Unsupported function: %s with %d arguments", functionName, numArguments)); } else { - throw new IllegalStateException( - String.format("Unsupported function: %s not found", functionName)); + throw new IllegalStateException(String.format("Unsupported function: %s", functionName)); } } return new FunctionExecutionNode(functionInfo, childNodes); diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index 97ef7adf2f63..9b722a515ef4 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -18,9 +18,7 @@ */ package org.apache.pinot.segment.spi; -import com.google.common.collect.ImmutableList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -28,8 +26,6 @@ import javax.annotation.Nullable; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.sql.SqlFunctionCategory; -import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.ReturnTypes; @@ -53,396 +49,230 @@ */ public enum AggregationFunctionType { // Aggregation functions for single-valued columns - COUNT("count", null, SqlKind.COUNT, SqlFunctionCategory.NUMERIC, OperandTypes.ONE_OR_MORE, - ReturnTypes.explicit(SqlTypeName.BIGINT), ReturnTypes.explicit(SqlTypeName.BIGINT)), + COUNT("count"), // TODO: min/max only supports NUMERIC in Pinot, where Calcite supports COMPARABLE_ORDERED - MIN("min", null, SqlKind.MIN, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY, - ReturnTypes.explicit(SqlTypeName.DOUBLE)), - MAX("max", null, SqlKind.MAX, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY, - ReturnTypes.explicit(SqlTypeName.DOUBLE)), - SUM("sum", null, SqlKind.SUM, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, ReturnTypes.AGG_SUM, - ReturnTypes.explicit(SqlTypeName.DOUBLE)), - SUM0("$sum0", null, SqlKind.SUM0, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, - ReturnTypes.AGG_SUM_EMPTY_IS_ZERO, ReturnTypes.explicit(SqlTypeName.DOUBLE)), - SUMPRECISION("sumPrecision", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.ANY, 
ReturnTypes.explicit(SqlTypeName.DECIMAL), ReturnTypes.explicit(SqlTypeName.OTHER)), - // NO NEEDED in v2, AVG is compiled as SUM/COUNT - AVG("avg"), - MODE("mode"), - FIRSTWITHTIME("firstWithTime", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.or( - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER))), - ReturnTypes.ARG0, ReturnTypes.explicit(SqlTypeName.OTHER)), - LASTWITHTIME("lastWithTime", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.or( - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP, SqlTypeFamily.CHARACTER)), - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER))), - ReturnTypes.ARG0, ReturnTypes.explicit(SqlTypeName.OTHER)), - MINMAXRANGE("minMaxRange", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, - ReturnTypes.ARG0, ReturnTypes.explicit(SqlTypeName.OTHER)), + MIN("min", SqlTypeName.DOUBLE), + MAX("max", SqlTypeName.DOUBLE), + SUM("sum", SqlTypeName.DOUBLE), + SUM0("$sum0", SqlTypeName.DOUBLE), + SUMPRECISION("sumPrecision", ReturnTypes.explicit(SqlTypeName.DECIMAL), OperandTypes.ANY, SqlTypeName.OTHER), + AVG("avg", SqlTypeName.OTHER), + MODE("mode", SqlTypeName.OTHER), + FIRSTWITHTIME("firstWithTime", ReturnTypes.ARG0, + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), SqlTypeName.OTHER), + LASTWITHTIME("lastWithTime", ReturnTypes.ARG0, + OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), SqlTypeName.OTHER), + MINMAXRANGE("minMaxRange", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), + /** * for all distinct count family functions: * (1) distinct_count only supports single argument; * (2) count(distinct ...) 
support multi-argument and will be converted into DISTINCT + COUNT */ - DISTINCTCOUNT("distinctCount", ImmutableList.of("DISTINCT_COUNT"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTBITMAP("distinctCountBitmap", ImmutableList.of("DISTINCT_COUNT_BITMAP"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount", - ImmutableList.of("SEGMENT_PARTITIONED_DISTINCT_COUNT"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT, ReturnTypes.BIGINT), - DISTINCTCOUNTHLL("distinctCountHLL", ImmutableList.of("DISTINCT_COUNT_HLL"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWHLL("distinctCountRawHLL", ImmutableList.of("DISTINCT_COUNT_RAW_HLL"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), ordinal -> ordinal > 0), - ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTSMARTHLL("distinctCountSmartHLL", ImmutableList.of("DISTINCT_COUNT_SMART_HLL"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - // DEPRECATED in v2 + DISTINCTCOUNT("distinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER), + DISTINCTSUM("distinctSum", ReturnTypes.AGG_SUM, OperandTypes.NUMERIC, SqlTypeName.OTHER), + DISTINCTAVG("distinctAvg", ReturnTypes.DOUBLE, OperandTypes.NUMERIC, SqlTypeName.OTHER), + DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER), + SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, + SqlTypeName.OTHER), + DISTINCTCOUNTHLL("distinctCountHLL", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWHLL("distinctCountRawHLL", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTSMARTHLL("distinctCountSmartHLL", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), i -> i == 1), SqlTypeName.OTHER), @Deprecated FASTHLL("fastHLL"), - DISTINCTCOUNTTHETASKETCH("distinctCountThetaSketch", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWTHETASKETCH("distinctCountRawThetaSketch", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 0), - ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTSUM("distinctSum", null, SqlKind.OTHER_FUNCTION, 
SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, - ReturnTypes.AGG_SUM, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTAVG("distinctAvg", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, - ReturnTypes.explicit(SqlTypeName.DOUBLE), ReturnTypes.explicit(SqlTypeName.OTHER)), - - PERCENTILE("percentile", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILEEST("percentileEst", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWEST("percentileRawEst", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILETDIGEST("percentileTDigest", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWTDIGEST("percentileRawTDigest", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - // DEPRECATED in v2 - @Deprecated PERCENTILESMARTTDIGEST("percentileSmartTDigest"), - PERCENTILEKLL("percentileKLL", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWKLL("percentileRawKLL", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - // hyper log log plus plus functions - DISTINCTCOUNTHLLPLUS("distinctCountHLLPlus", ImmutableList.of("DISTINCT_COUNT_HLL_PLUS"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWHLLPLUS("distinctCountRawHLLPlus", ImmutableList.of("DISTINCT_COUNT_RAW_HLL_PLUS"), - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), ordinal -> ordinal > 0), - ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), - - DISTINCTCOUNTULL("distinctCountULL", ImmutableList.of("DISTINCT_COUNT_ULL"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - - DISTINCTCOUNTRAWULL("distinctCountRawULL", ImmutableList.of("DISTINCT_COUNT_RAW_ULL"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC), ordinal -> ordinal > 0), - 
ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), - - // DEPRECATED in v2 - @Deprecated IDSET("idSet"), - - // TODO: support histogram requires solving ARRAY constructor and multi-function signature without optional ordinal - HISTOGRAM("histogram"), - - // TODO: support underscore separated version of the stats functions, resolving conflict in SqlStdOptTable - // currently Pinot is missing generated agg functions impl from Calcite's AggregateReduceFunctionsRule - COVARPOP("covarPop", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - COVARSAMP("covarSamp", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - VARPOP("varPop", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - VARSAMP("varSamp", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - STDDEVPOP("stdDevPop", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - STDDEVSAMP("stdDevSamp", Collections.emptyList(), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - SKEWNESS("skewness", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.NUMERIC, - ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - KURTOSIS("kurtosis", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.NUMERIC, - ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - FOURTHMOMENT("fourthMoment"), - - // DataSketches Tuple Sketch support - DISTINCTCOUNTTUPLESKETCH("distinctCountTupleSketch", ImmutableList.of("DISTINCT_COUNT_TUPLE_SKETCH"), - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BINARY, ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - - // DataSketches Tuple Sketch support for Integer based Tuple Sketches - DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH("distinctCountRawIntegerSumTupleSketch", - ImmutableList.of("DISTINCT_COUNT_RAW_INTEGER_SUM_TUPLE_SKETCH"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BINARY, ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - - SUMVALUESINTEGERSUMTUPLESKETCH("sumValuesIntegerSumTupleSketch", - ImmutableList.of("SUM_VALUES_INTEGER_SUM_TUPLE_SKETCH"), SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BINARY, ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - AVGVALUEINTEGERSUMTUPLESKETCH("avgValueIntegerSumTupleSketch", ImmutableList.of("AVG_VALUE_INTEGER_SUM_TUPLE_SKETCH"), - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BINARY, ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTCOUNTHLLPLUS("distinctCountHLLPlus", ReturnTypes.BIGINT, + 
OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWHLLPLUS("distinctCountRawHLLPlus", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTULL("distinctCountULL", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWULL("distinctCountRawULL", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTTHETASKETCH("distinctCountThetaSketch", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWTHETASKETCH("distinctCountRawThetaSketch", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTTUPLESKETCH("distinctCountTupleSketch", ReturnTypes.BIGINT, OperandTypes.BINARY, SqlTypeName.OTHER), + DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH("distinctCountRawIntegerSumTupleSketch", ReturnTypes.VARCHAR, + OperandTypes.BINARY, SqlTypeName.OTHER), + SUMVALUESINTEGERSUMTUPLESKETCH("sumValuesIntegerSumTupleSketch", ReturnTypes.BIGINT, OperandTypes.BINARY, + SqlTypeName.OTHER), + AVGVALUEINTEGERSUMTUPLESKETCH("avgValueIntegerSumTupleSketch", ReturnTypes.BIGINT, OperandTypes.BINARY, + SqlTypeName.OTHER), + DISTINCTCOUNTCPCSKETCH("distinctCountCPCSketch", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWCPCSKETCH("distinctCountRawCPCSketch", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.ANY), i -> i == 1), SqlTypeName.OTHER), + + PERCENTILE("percentile", ReturnTypes.ARG0, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), + PERCENTILEEST("percentileEst", ReturnTypes.BIGINT, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), + PERCENTILERAWEST("percentileRawEst", ReturnTypes.VARCHAR, OperandTypes.ANY_NUMERIC, SqlTypeName.OTHER), + PERCENTILETDIGEST("percentileTDigest", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILERAWTDIGEST("percentileRawTDigest", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILESMARTTDIGEST("percentileSmartTDigest", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILEKLL("percentileKLL", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILERAWKLL("percentileRawKLL", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + + IDSET("idSet", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), i -> i == 1), SqlTypeName.OTHER), + + HISTOGRAM("histogram", new ArrayReturnTypeInference(SqlTypeName.DOUBLE), OperandTypes.VARIADIC, SqlTypeName.OTHER), + + COVARPOP("covarPop", SqlTypeName.OTHER), + COVARSAMP("covarSamp", SqlTypeName.OTHER), + VARPOP("varPop", SqlTypeName.OTHER), + 
VARSAMP("varSamp", SqlTypeName.OTHER), + STDDEVPOP("stdDevPop", SqlTypeName.OTHER), + STDDEVSAMP("stdDevSamp", SqlTypeName.OTHER), + + SKEWNESS("skewness", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), + KURTOSIS("kurtosis", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), + FOURTHMOMENT("fourthMoment", ReturnTypes.DOUBLE, OperandTypes.ANY, SqlTypeName.OTHER), // Datasketches Frequent Items support - FREQUENTSTRINGSSKETCH("frequentStringsSketch"), - FREQUENTLONGSSKETCH("frequentLongsSketch"), - - // Datasketches CPC Sketch support - DISTINCTCOUNTCPCSKETCH("distinctCountCPCSketch", ImmutableList.of("DISTINCT_COUNT_CPC_SKETCH"), - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 0), - ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWCPCSKETCH("distinctCountRawCPCSketch", ImmutableList.of("DISTINCT_COUNT_RAW_CPC_SKETCH"), - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), ordinal -> ordinal > 0), - ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), + FREQUENTSTRINGSSKETCH("frequentStringsSketch", ReturnTypes.VARCHAR, OperandTypes.ANY, SqlTypeName.OTHER), + FREQUENTLONGSSKETCH("frequentLongsSketch", ReturnTypes.VARCHAR, OperandTypes.ANY, SqlTypeName.OTHER), // Geo aggregation functions - STUNION("STUnion", ImmutableList.of("ST_UNION"), SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.BINARY, ReturnTypes.explicit(SqlTypeName.VARBINARY), ReturnTypes.explicit(SqlTypeName.OTHER)), - - // Aggregation functions for multi-valued columns - COUNTMV("countMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.BIGINT), - ReturnTypes.explicit(SqlTypeName.BIGINT)), - MINMV("minMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), - MAXMV("maxMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), - SUMMV("sumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), - AVGMV("avgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - MINMAXRANGEMV("minMaxRangeMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.DOUBLE), - ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTMV("distinctCountMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTHLLMV("distinctCountHLLMV", null, SqlKind.OTHER_FUNCTION, 
SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTSUMMV("distinctSumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTAVGMV("distinctAvgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILEMV("percentileMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILEESTMV("percentileEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWESTMV("percentileRawEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILETDIGESTMV("percentileTDigestMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWTDIGESTMV("percentileRawTDigestMV", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILEKLLMV("percentileKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), - ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), - PERCENTILERAWKLLMV("percentileRawKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), - ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), - // hyper log log plus plus functions - DISTINCTCOUNTHLLPLUSMV("distinctCountHLLPlusMV", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, - ReturnTypes.explicit(SqlTypeName.OTHER)), - DISTINCTCOUNTRAWHLLPLUSMV("distinctCountRawHLLPlusMV", null, SqlKind.OTHER_FUNCTION, - SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.VARCHAR_2000, - ReturnTypes.explicit(SqlTypeName.OTHER)), + STUNION("STUnion", ReturnTypes.VARBINARY, OperandTypes.BINARY, SqlTypeName.OTHER), // boolean aggregate functions - BOOLAND("boolAnd", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BOOLEAN, - 
ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)), - BOOLOR("boolOr", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.BOOLEAN, - ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)), + BOOLAND("boolAnd", ReturnTypes.BOOLEAN, OperandTypes.BOOLEAN, SqlTypeName.INTEGER), + BOOLOR("boolOr", ReturnTypes.BOOLEAN, OperandTypes.BOOLEAN, SqlTypeName.INTEGER), // ExprMin and ExprMax // TODO: revisit support for ExprMin/Max count in V2, particularly plug query rewriter in the right place - EXPRMIN("exprMin", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY), ordinal -> ordinal > 1), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - EXPRMAX("exprMax", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY), ordinal -> ordinal > 1), ReturnTypes.ARG0, - ReturnTypes.explicit(SqlTypeName.OTHER)), - - PINOTPARENTAGGEXPRMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + EXPRMIN.getName(), null, - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.ANY), ordinal -> ordinal > 2), - ReturnTypes.explicit(SqlTypeName.OTHER), ReturnTypes.explicit(SqlTypeName.OTHER)), - PINOTPARENTAGGEXPRMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), null, - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.ANY), ordinal -> ordinal > 2), - ReturnTypes.explicit(SqlTypeName.OTHER), ReturnTypes.explicit(SqlTypeName.OTHER)), - - PINOTCHILDAGGEXPRMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMIN.getName(), null, - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.ANY), ordinal -> ordinal > 3), - ReturnTypes.ARG1, ReturnTypes.explicit(SqlTypeName.OTHER)), - PINOTCHILDAGGEXPRMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), null, - SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.INTEGER, SqlTypeFamily.ANY), ordinal -> ordinal > 3), - ReturnTypes.ARG1, ReturnTypes.explicit(SqlTypeName.OTHER)), + EXPRMIN("exprMin", ReturnTypes.ARG0, OperandTypes.VARIADIC, SqlTypeName.OTHER), + EXPRMAX("exprMax", ReturnTypes.ARG0, OperandTypes.VARIADIC, SqlTypeName.OTHER), + PINOTPARENTAGGEXPRMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + EXPRMIN.getName(), + ReturnTypes.explicit(SqlTypeName.OTHER), OperandTypes.VARIADIC, SqlTypeName.OTHER), + PINOTPARENTAGGEXPRMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), + ReturnTypes.explicit(SqlTypeName.OTHER), OperandTypes.VARIADIC, SqlTypeName.OTHER), + PINOTCHILDAGGEXPRMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMIN.getName(), + ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER), + PINOTCHILDAGGEXPRMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + EXPRMAX.getName(), + ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER), // Array aggregate functions - ARRAYAGG("arrayAgg", null, SqlKind.ARRAY_AGG, SqlFunctionCategory.USER_DEFINED_FUNCTION, - 
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.BOOLEAN), - ordinal -> ordinal > 1), ReturnTypes.TO_ARRAY, ReturnTypes.explicit(SqlTypeName.OTHER)), - LISTAGG("listAgg", null, SqlKind.LISTAGG, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(ImmutableList.of(SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.BOOLEAN), - ordinal -> ordinal > 1), ReturnTypes.VARCHAR, ReturnTypes.explicit(SqlTypeName.OTHER)), - - SUMARRAYLONG("sumArrayLong", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ArrayReturnTypeInference.forType(SqlTypeName.BIGINT), - ReturnTypes.explicit(SqlTypeName.OTHER)), - SUMARRAYDOUBLE("sumArrayDouble", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.family(SqlTypeFamily.ARRAY), ArrayReturnTypeInference.forType(SqlTypeName.DOUBLE), - ReturnTypes.explicit(SqlTypeName.OTHER)), + ARRAYAGG("arrayAgg", ReturnTypes.TO_ARRAY, + OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER, SqlTypeFamily.BOOLEAN), i -> i == 2), + SqlTypeName.OTHER), + LISTAGG("listAgg", SqlTypeName.OTHER), + + SUMARRAYLONG("sumArrayLong", new ArrayReturnTypeInference(SqlTypeName.BIGINT), OperandTypes.ARRAY, SqlTypeName.OTHER), + SUMARRAYDOUBLE("sumArrayDouble", new ArrayReturnTypeInference(SqlTypeName.DOUBLE), OperandTypes.ARRAY, + SqlTypeName.OTHER), // funnel aggregate functions - FUNNELMAXSTEP("funnelMaxStep", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.VARIADIC, ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - FUNNELCOMPLETECOUNT("funnelCompleteCount", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.VARIADIC, ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), - FUNNELMATCHSTEP("funnelMatchStep", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, - OperandTypes.VARIADIC, ArrayReturnTypeInference.INT_ARRAY_RETURN_TYPE_INFERENCE, - ReturnTypes.explicit(SqlTypeName.OTHER)), - // TODO: revisit support for funnel count in V2 - FUNNELCOUNT("funnelCount"); + FUNNELMAXSTEP("funnelMaxStep", ReturnTypes.INTEGER, OperandTypes.VARIADIC, SqlTypeName.OTHER), + FUNNELCOMPLETECOUNT("funnelCompleteCount", ReturnTypes.INTEGER, OperandTypes.VARIADIC, SqlTypeName.OTHER), + FUNNELMATCHSTEP("funnelMatchStep", new ArrayReturnTypeInference(SqlTypeName.INTEGER), OperandTypes.VARIADIC, + SqlTypeName.OTHER), + FUNNELCOUNT("funnelCount", new ArrayReturnTypeInference(SqlTypeName.BIGINT), OperandTypes.VARIADIC, + SqlTypeName.OTHER), + + // Aggregation functions for multi-valued columns + COUNTMV("countMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.BIGINT), + MINMV("minMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), + MAXMV("maxMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), + SUMMV("sumMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.DOUBLE), + AVGMV("avgMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), + MINMAXRANGEMV("minMaxRangeMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTCOUNTMV("distinctCountMV", ReturnTypes.BIGINT, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTSUMMV("distinctSumMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTAVGMV("distinctAvgMV", ReturnTypes.DOUBLE, OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", ReturnTypes.BIGINT, 
OperandTypes.ARRAY, SqlTypeName.OTHER), + DISTINCTCOUNTHLLMV("distinctCountHLLMV", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTHLLPLUSMV("distinctCountHLLPlusMV", ReturnTypes.BIGINT, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + DISTINCTCOUNTRAWHLLPLUSMV("distinctCountRawHLLPlusMV", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.INTEGER), i -> i == 1), SqlTypeName.OTHER), + PERCENTILEMV("percentileMV", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), + PERCENTILEESTMV("percentileEstMV", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), + PERCENTILERAWESTMV("percentileRawEstMV", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), SqlTypeName.OTHER), + PERCENTILETDIGESTMV("percentileTDigestMV", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILERAWTDIGESTMV("percentileRawTDigestMV", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILEKLLMV("percentileKLLMV", ReturnTypes.DOUBLE, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER), + PERCENTILERAWKLLMV("percentileRawKLLMV", ReturnTypes.VARCHAR, + OperandTypes.family(List.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.INTEGER), i -> i == 2), + SqlTypeName.OTHER); private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toLowerCase())) .collect(Collectors.toSet()); - // -------------------------------------------------------------------------- - // Function signature used by Calcite. - // -------------------------------------------------------------------------- private final String _name; - private final List<String> _alternativeNames; - - // Fields for registering the aggregation function with Calcite in multistage. These are typically used for the - // user facing aggregation functions and the return and operand types should reflect that which is user facing. - private final SqlKind _sqlKind; - private final SqlFunctionCategory _sqlFunctionCategory; - // override options for Pinot aggregate functions that expects different return type or operand type + // Fields used by multi-stage engine + // When returnTypeInference is provided, the function will be registered as a USER_DEFINED_FUNCTION private final SqlReturnTypeInference _returnTypeInference; private final SqlOperandTypeChecker _operandTypeChecker; // override options for Pinot aggregate rules to insert intermediate results that differ from the standard return type.
private final SqlReturnTypeInference _intermediateReturnTypeInference; - /** - * Constructor to use for aggregation functions which are only supported in v1 engine today - */ AggregationFunctionType(String name) { - this(name, null, null, null); + this(name, null, null, (SqlReturnTypeInference) null); } - /** - * Constructor to use for aggregation functions which are supported in both v1 and multistage engines. - *

- */ - AggregationFunctionType(String name, List alternativeNames, SqlKind sqlKind, - SqlFunctionCategory sqlFunctionCategory) { - this(name, alternativeNames, sqlKind, sqlFunctionCategory, null, null, null); + AggregationFunctionType(String name, SqlTypeName intermediateReturnType) { + this(name, null, null, ReturnTypes.explicit(intermediateReturnType)); } - /** - * Constructor to use for aggregation functions which are supported in both v1 and multistage engines with - * different behavior comparing to Calcite and requires literal operand inputs. - * - * @param name name of the agg function - * @param alternativeNames alternative name of the agg function. - * @param sqlKind sql kind indicator, used by Calcite - * @param sqlFunctionCategory function catalog, used by Calcite - * @param operandTypeChecker input operand type signature, used by Calcite - * @param finalReturnType final output type signature, used by Calcite - * @param intermediateReturnType intermediate output type signature, used by Pinot and Calcite - */ - AggregationFunctionType(String name, @Nullable List alternativeNames, @Nullable SqlKind sqlKind, - @Nullable SqlFunctionCategory sqlFunctionCategory, @Nullable SqlOperandTypeChecker operandTypeChecker, - @Nullable SqlReturnTypeInference finalReturnType, @Nullable SqlReturnTypeInference intermediateReturnType) { - _name = name; - if (alternativeNames == null) { - _alternativeNames = Collections.singletonList(getUnderscoreSplitAggregationFunctionName(_name)); - } else { - _alternativeNames = alternativeNames; - } - _sqlKind = sqlKind; - _sqlFunctionCategory = sqlFunctionCategory; + AggregationFunctionType(String name, SqlReturnTypeInference returnTypeInference, + SqlOperandTypeChecker operandTypeChecker, SqlTypeName intermediateReturnType) { + this(name, returnTypeInference, operandTypeChecker, ReturnTypes.explicit(intermediateReturnType)); + } - _returnTypeInference = finalReturnType; + AggregationFunctionType(String name, @Nullable SqlReturnTypeInference returnTypeInference, + @Nullable SqlOperandTypeChecker operandTypeChecker, + @Nullable SqlReturnTypeInference intermediateReturnTypeInference) { + _name = name; + _returnTypeInference = returnTypeInference; _operandTypeChecker = operandTypeChecker; - _intermediateReturnTypeInference = intermediateReturnType == null ? 
_returnTypeInference : intermediateReturnType; + _intermediateReturnTypeInference = intermediateReturnTypeInference; } public String getName() { return _name; } - public List getAlternativeNames() { - return _alternativeNames; - } - - public SqlKind getSqlKind() { - return _sqlKind; - } - - public SqlReturnTypeInference getIntermediateReturnTypeInference() { - return _intermediateReturnTypeInference; - } - + @Nullable public SqlReturnTypeInference getReturnTypeInference() { return _returnTypeInference; } + @Nullable public SqlOperandTypeChecker getOperandTypeChecker() { return _operandTypeChecker; } - public SqlFunctionCategory getSqlFunctionCategory() { - return _sqlFunctionCategory; + @Nullable + public SqlReturnTypeInference getIntermediateReturnTypeInference() { + return _intermediateReturnTypeInference; } public static boolean isAggregationFunction(String functionName) { @@ -465,11 +295,6 @@ public static String getNormalizedAggregationFunctionName(String functionName) { return StringUtils.remove(StringUtils.remove(functionName, '_').toUpperCase(), "$"); } - public static String getUnderscoreSplitAggregationFunctionName(String functionName) { - // Skip functions that have numbers for now and return their name as is - return functionName.matches(".*\\d.*") ? functionName : functionName.replaceAll("(.)(\\p{Upper}+|\\d+)", "$1_$2"); - } - /** * Returns the corresponding aggregation function type for the given function name. *

NOTE: Underscores in the function name are ignored. @@ -519,26 +344,18 @@ public static AggregationFunctionType getAggregationFunctionType(String function } } - static class ArrayReturnTypeInference implements SqlReturnTypeInference { - static final ArrayReturnTypeInference INT_ARRAY_RETURN_TYPE_INFERENCE = - ArrayReturnTypeInference.forType(SqlTypeName.INTEGER); - - private final SqlTypeName _sqlTypeName; + private static class ArrayReturnTypeInference implements SqlReturnTypeInference { + final SqlTypeName _sqlTypeName; ArrayReturnTypeInference(SqlTypeName sqlTypeName) { _sqlTypeName = sqlTypeName; } @Override - public RelDataType inferReturnType( - SqlOperatorBinding opBinding) { + public RelDataType inferReturnType(SqlOperatorBinding opBinding) { RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); RelDataType elementType = typeFactory.createSqlType(_sqlTypeName); return typeFactory.createArrayType(elementType, -1); } - - static ArrayReturnTypeInference forType(SqlTypeName sqlTypeName) { - return new ArrayReturnTypeInference(sqlTypeName); - } } } diff --git a/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/AggregationFunctionTypeTest.java b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/AggregationFunctionTypeTest.java new file mode 100644 index 000000000000..09de660eebdd --- /dev/null +++ b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/AggregationFunctionTypeTest.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.segment.spi; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + + +public class AggregationFunctionTypeTest { + + @Test + public void testIsAggFunc() { + assertTrue(AggregationFunctionType.isAggregationFunction("count")); + assertTrue(AggregationFunctionType.isAggregationFunction("percentileRawEstMV")); + assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILERAWESTMV")); + assertTrue(AggregationFunctionType.isAggregationFunction("percentilerawestmv")); + assertTrue(AggregationFunctionType.isAggregationFunction("percentile_raw_est_mv")); + assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILE_RAW_EST_MV")); + assertTrue(AggregationFunctionType.isAggregationFunction("PERCENTILEEST90")); + assertTrue(AggregationFunctionType.isAggregationFunction("percentileest90")); + assertFalse(AggregationFunctionType.isAggregationFunction("toEpochSeconds")); + } +} diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java index 0a647a879212..e80b6d0c3b48 100644 --- a/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java +++ b/pinot-spi/src/main/java/org/apache/pinot/spi/annotations/ScalarFunction.java @@ -41,7 +41,7 @@ * - byte[] */ @Retention(RetentionPolicy.RUNTIME) -@Target(ElementType.METHOD) +@Target({ElementType.METHOD, ElementType.TYPE}) public @interface ScalarFunction { boolean enabled() default true; @@ -54,14 +54,13 @@ /** * Whether the scalar function expects and can handle null arguments. - * */ boolean nullableParameters() default false; - boolean isPlaceholder() default false; - /** * Whether the scalar function takes a variable number of arguments. */ boolean isVarArg() default false; + + @Deprecated boolean isPlaceholder() default false; } diff --git a/pom.xml b/pom.xml index bce3ad239241..3a30523d56e0 100644 --- a/pom.xml +++ b/pom.xml @@ -154,6 +154,7 @@ 2.5.1 2.3.2 1.37.0 + <immutables.version>2.10.1</immutables.version> 9.11.1 0.10.2 0.17.0 @@ -1217,6 +1218,7 @@ json-smart ${jsonsmart.version} + org.apache.calcite calcite-core @@ -1245,6 +1247,13 @@ calcite-babel ${calcite.version} + + <dependency> + <groupId>org.immutables</groupId> + <artifactId>value-annotations</artifactId> + <version>${immutables.version}</version> + </dependency> + org.codehaus.janino janino
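
The AggregationFunctionTypeTest above depends on function-name matching being insensitive to case and underscores, which is also the rule the InbuiltFunctionEvaluator change applies through FunctionRegistry.canonicalize before looking up a FunctionInfo. Below is a minimal sketch of that normalization, mirroring the getNormalizedAggregationFunctionName body kept by this diff; the class and method names are illustrative, not the actual FunctionRegistry API:

public final class CanonicalNameSketch {

  private CanonicalNameSketch() {
  }

  // Same rule as getNormalizedAggregationFunctionName: drop '_', upper-case, then drop '$'.
  static String canonicalize(String functionName) {
    return functionName.replace("_", "").toUpperCase().replace("$", "");
  }

  public static void main(String[] args) {
    // All three spellings collapse to PERCENTILERAWESTMV and therefore resolve to the same
    // function, matching the assertions in the new test.
    System.out.println(canonicalize("percentileRawEstMV"));
    System.out.println(canonicalize("PERCENTILE_RAW_EST_MV"));
    System.out.println(canonicalize("percentile_raw_est_mv"));
  }
}
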