Skip to content

Commit

Permalink
Adds Fn, FnProvider, Agg, and AggProvider APIs (#1722)
Browse files Browse the repository at this point in the history
* Adds FnProvider, Fn, AggProvider, and Agg
* Adds a RoutineSignature and RoutineProviderSignature
* Adds scalar and aggregate function builders
* Updates existing function implementations to use new APIs
* Updates RelOpAggregate and ExprCallDynamic to use arrays
* Adds minor coercion logic for aggregations
  • Loading branch information
johnedquinn authored Jan 24, 2025
1 parent 5b1eca6 commit f3a65ff
Show file tree
Hide file tree
Showing 85 changed files with 1,656 additions and 1,196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.partiql.spi.catalog.Identifier
import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.Severity
import org.partiql.spi.function.Function
import org.partiql.spi.function.FnOverload
import org.partiql.spi.types.PType
import java.io.PrintWriter
import java.io.Writer
Expand Down Expand Up @@ -186,7 +186,7 @@ object ErrorMessageFormatter {
*/
private fun fnTypeMismatch(error: PError): String {
val functionName = error.getOrNull("FN_ID", Identifier::class.java)
val candidates = error.getListOrNull("CANDIDATES", Function::class.java)
val candidates = error.getListOrNull("CANDIDATES", FnOverload::class.java)
val args = error.getListOrNull("ARG_TYPES", PType::class.java)
val fnNameStr = prepare(functionName.toString(), " ", "")
val fnStr = when {
Expand All @@ -196,20 +196,6 @@ object ErrorMessageFormatter {
}
return buildString {
append("Undefined function$fnStr.")
if (!candidates.isNullOrEmpty()) {
appendLine(" Did you mean: ")
for (variant in candidates) {
variant as Function
append("- ")
append(variant.getName())
append(
variant.getParameters().joinToString(", ", "(", ")") {
"${it.getName()}: ${it.getType()}"
}
)
appendLine()
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ internal class StandardCompiler(strategies: List<Strategy>) : PartiQLCompiler {
// Compile the candidates
val candidates = Array(functions.size) {
val fn = functions[it]
val fnArity = fn.getParameters().size
val fnArity = fn.signature.arity
if (arity == -1) {
// set first
arity = fnArity
Expand All @@ -385,7 +385,7 @@ internal class StandardCompiler(strategies: List<Strategy>) : PartiQLCompiler {
override fun visitCall(rex: RexCall, ctx: Unit): ExprValue {
val func = rex.getFunction()
val args = rex.getArgs()
val catch = func.parameters.any { it.code() == PType.DYNAMIC }
val catch = func.signature.parameters.any { it.type.code() == PType.DYNAMIC }
return when (catch) {
true -> ExprCall(func, Array(args.size) { i -> compile(args[i], Unit).catch() })
else -> ExprCall(func, Array(args.size) { i -> compile(args[i], Unit) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.PRuntimeException
import org.partiql.spi.errors.Severity
import org.partiql.spi.function.Function
import org.partiql.spi.function.FnOverload
import org.partiql.spi.types.PType

internal object PErrors {
Expand All @@ -20,8 +20,8 @@ internal object PErrors {
/**
* Returns a PRuntimeException with code: [PError.FUNCTION_TYPE_MISMATCH].
*/
fun functionTypeMismatchException(name: String, actualTypes: List<PType>, candidates: List<Function>): PRuntimeException {
val pError = functionTypeMismatch(name, actualTypes, candidates)
fun functionTypeMismatchException(name: String, actualTypes: Array<PType>, candidates: List<FnOverload>): PRuntimeException {
val pError = functionTypeMismatch(name, actualTypes.toList(), candidates)
return PRuntimeException(pError)
}

Expand Down Expand Up @@ -158,7 +158,7 @@ internal object PErrors {
/**
* Returns a PError with code: [PError.FUNCTION_TYPE_MISMATCH].
*/
private fun functionTypeMismatch(name: String, actualTypes: List<PType>, candidates: List<Function>): PError {
private fun functionTypeMismatch(name: String, actualTypes: List<PType>, candidates: List<FnOverload>): PError {
return PError(
PError.FUNCTION_TYPE_MISMATCH,
Severity.ERROR(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package org.partiql.eval.internal.operator

import org.partiql.eval.ExprValue
import org.partiql.spi.function.Aggregation
import org.partiql.spi.function.Agg

/**
* Simple data class to hold a compile aggregation call.
*/
internal class Aggregate(
val agg: Aggregation,
val agg: Agg,
val args: List<ExprValue>,
val distinct: Boolean
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import org.partiql.eval.ExprValue
import org.partiql.eval.Row
import org.partiql.eval.internal.helpers.DatumArrayComparator
import org.partiql.eval.internal.operator.Aggregate
import org.partiql.spi.function.Aggregation
import org.partiql.spi.types.PType
import org.partiql.spi.function.Accumulator
import org.partiql.spi.value.Datum
import java.util.TreeMap
import java.util.TreeSet
Expand All @@ -23,12 +22,12 @@ internal class RelOpAggregate(
private val aggregationMap = TreeMap<Array<Datum>, List<AccumulatorWrapper>>(DatumArrayComparator)

/**
* Wraps an [Aggregation.Accumulator] to help with filtering distinct values.
* Wraps an [Accumulator] to help with filtering distinct values.
*
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
*/
class AccumulatorWrapper(
val delegate: Aggregation.Accumulator,
val delegate: Accumulator,
val args: List<ExprValue>,
val seen: TreeSet<Array<Datum>>?
)
Expand All @@ -47,13 +46,10 @@ internal class RelOpAggregate(
}
}

// TODO IT DOES NOT MATTER NOW, BUT SqlCompiler SHOULD HANDLE GET THE ARGUMENT TYPES FOR .getAccumulator
val args: Array<PType> = emptyArray()

val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
aggregates.map {
AccumulatorWrapper(
delegate = it.agg.getAccumulator(args),
delegate = it.agg.accumulator,
args = it.args,
seen = if (it.distinct) TreeSet(DatumArrayComparator) else null
)
Expand Down Expand Up @@ -82,19 +78,20 @@ internal class RelOpAggregate(

// No Aggregations Created
if (groups.isEmpty() && aggregationMap.isEmpty()) {
val record = mutableListOf<Datum>()
aggregates.forEach { function ->
val accumulator = function.agg.getAccumulator(args = emptyArray())
record.add(accumulator.value())
val record = Array<Datum?>(aggregates.size) {
val function = aggregates[it]
val accumulator = function.agg.accumulator
accumulator.value()
}
records = iterator { yield(Row(record.toTypedArray())) }
records = iterator { yield(Row(record)) }
return
}

records = iterator {
aggregationMap.forEach { (keysEvaluated, accumulators) ->
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated
yield(Row(recordValues.toTypedArray()))
val accumulatorValues = Array(accumulators.size) { i -> accumulators[i].delegate.value() }
val recordValues = accumulatorValues + keysEvaluated
yield(Row(recordValues))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.Environment
import org.partiql.eval.ExprValue
import org.partiql.spi.function.Function
import org.partiql.spi.function.Fn
import org.partiql.spi.value.Datum

/**
Expand All @@ -12,14 +12,14 @@ import org.partiql.spi.value.Datum
* @property args Input argument expressions.
*/
internal class ExprCall(
private var function: Function.Instance,
private var function: Fn,
private var args: Array<ExprValue>,
) : ExprValue {

private var isNullCall: Boolean = function.isNullCall
private var isMissingCall: Boolean = function.isMissingCall
private var nil = { Datum.nullValue(function.returns) }
private var missing = { Datum.missing(function.returns) }
private var isNullCall: Boolean = function.signature.isNullCall
private var isMissingCall: Boolean = function.signature.isMissingCall
private var nil = { Datum.nullValue(function.signature.returns) }
private var missing = { Datum.missing(function.signature.returns) }

override fun eval(env: Environment): Datum {
// Evaluate arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.Candidate
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.DYNAMIC
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.UNKNOWN
import org.partiql.spi.function.Function
import org.partiql.spi.function.Fn
import org.partiql.spi.function.FnOverload
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum

Expand All @@ -31,7 +32,7 @@ import org.partiql.spi.value.Datum
*/
internal class ExprCallDynamic(
private val name: String,
private val functions: Array<Function>,
private val functions: Array<FnOverload>,
private val args: Array<ExprValue>
) : ExprValue {

Expand All @@ -43,15 +44,32 @@ internal class ExprCallDynamic(
/**
* A memoization cache for the [match] function.
*/
private val candidates: MutableMap<List<PType>, Candidate> = mutableMapOf()
private val candidates: MutableMap<ParameterTypes, Candidate> = mutableMapOf()

/**
* Used as the keys of the hash map: [ExprCallDynamic.candidates].
*/
private class ParameterTypes(val types: Array<PType>) {

override fun equals(other: Any?): Boolean {
if (this === other) return true
other as ParameterTypes // We can immediately cast, since this is a private class only used for the cache.
return types.contentEquals(other.types)
}

override fun hashCode(): Int {
return types.contentHashCode()
}
}

override fun eval(env: Environment): Datum {
val actualArgs = args.map { it.eval(env).lowerSafe() }.toTypedArray()
val actualTypes = actualArgs.map { it.type }
var candidate = candidates[actualTypes]
val actualArgs = Array(args.size) { args[it].eval(env).lowerSafe() }
val actualTypes = Array(actualArgs.size) { actualArgs[it].type }
val paramTypes = ParameterTypes(actualTypes)
var candidate = candidates[paramTypes]
if (candidate == null) {
candidate = match(actualTypes) ?: throw PErrors.functionTypeMismatchException(name, actualTypes, functions.toList())
candidates[actualTypes] = candidate
candidates[paramTypes] = candidate
}
return candidate.eval(actualArgs)
}
Expand All @@ -63,19 +81,19 @@ internal class ExprCallDynamic(
*
* @return the index of the candidate to invoke; null if method cannot resolve.
*/
private fun match(args: List<PType>): Candidate? {
private fun match(args: Array<PType>): Candidate? {
var exactMatches: Int = -1
var currentMatch: Int? = null
val argFamilies = args.map { family(it.code()) }
functions.indices.forEach { candidateIndex ->
var currentExactMatches = 0
val params = functions[candidateIndex].getInstance(args.toTypedArray())?.parameters ?: return@forEach
val params = functions[candidateIndex].getInstance(args)?.signature?.parameters ?: return@forEach
for (paramIndex in paramIndices) {
val argType = args[paramIndex]
val paramType = params[paramIndex]
if (paramType.code() == argType.code()) { currentExactMatches++ } // TODO: Convert all functions to use the new modelling, or else we need to only check kinds
if (paramType.type.code() == argType.code()) { currentExactMatches++ } // TODO: Convert all functions to use the new modelling, or else we need to only check kinds
val argFamily = argFamilies[paramIndex]
val paramFamily = family(paramType.code())
val paramFamily = family(paramType.type.code())
if (paramFamily != argFamily && argFamily != UNKNOWN && paramFamily != DYNAMIC) { return@forEach }
}
if (currentExactMatches > exactMatches) {
Expand All @@ -84,7 +102,7 @@ internal class ExprCallDynamic(
}
}
return if (currentMatch == null) null else {
val instance = functions[currentMatch!!].getInstance(args.toTypedArray()) ?: return null
val instance = functions[currentMatch!!].getInstance(args) ?: return null
Candidate(instance)
}
}
Expand Down Expand Up @@ -160,28 +178,28 @@ internal class ExprCallDynamic(
*
* @see ExprCallDynamic
*/
private class Candidate(private var function: Function.Instance) {
private class Candidate(private var function: Fn) {

private var nil = { Datum.nullValue(function.returns) }
private var missing = { Datum.missing(function.returns) }
private var nil = { Datum.nullValue(function.signature.returns) }
private var missing = { Datum.missing(function.signature.returns) }

/**
* Function instance parameters (just types).
*/
fun eval(args: Array<Datum>): Datum {
val coerced = Array(args.size) { i ->
val arg = args[i]
if (function.isNullCall && arg.isNull) {
if (function.signature.isNullCall && arg.isNull) {
return nil.invoke()
}
if (function.isMissingCall && arg.isMissing) {
if (function.signature.isMissingCall && arg.isMissing) {
return missing.invoke()
}
val argType = arg.type
val paramType = function.parameters[i]
when (paramType == argType) {
val paramType = function.signature.parameters[i]
when (paramType.type == argType) {
true -> arg
false -> CastTable.cast(arg, paramType)
false -> CastTable.cast(arg, paramType.type)
}
}
return function.invoke(coerced)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.Mode
import org.partiql.eval.compiler.PartiQLCompiler
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
import org.partiql.value.PartiQLValue
import org.partiql.value.bagValue
Expand Down Expand Up @@ -1412,14 +1413,10 @@ class PartiQLEvaluatorTest {
fun developmentTest() {
val tc =
SuccessTestCase(
input = "SELECT DISTINCT VALUE t * 100 FROM <<0, 1, 2.0, 3.0>> AS t;",
expected = bagValue(
int32Value(0),
int32Value(100),
decimalValue(BigDecimal.valueOf(2000, 1)),
decimalValue(BigDecimal.valueOf(3000, 1)),
),
mode = Mode.STRICT()
input = """
non_existing_column = 1
""".trimIndent(),
expected = Datum.nullValue(PType.bool())
)
tc.run()
}
Expand Down
Loading

0 comments on commit f3a65ff

Please sign in to comment.