Skip to content

Commit

Permalink
wip; fix typing and overflow behavior of SUM and AVG; still have clea…
Browse files Browse the repository at this point in the history
…nup to do
  • Loading branch information
alancai98 committed Jan 24, 2025
1 parent f004d3e commit c841580
Show file tree
Hide file tree
Showing 25 changed files with 347 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3332,7 +3332,7 @@ internal class PlanTyperTestsPorted {
"a" to StaticType.INT4,
"_1" to StaticType.INT8,
"_2" to StaticType.INT8,
"_3" to StaticType.INT4,
"_3" to StaticType.INT8,
"_4" to StaticType.INT4,
"_5" to StaticType.INT4,
),
Expand All @@ -3354,7 +3354,7 @@ internal class PlanTyperTestsPorted {
"a" to StaticType.INT4,
"c_s" to StaticType.INT8,
"c" to StaticType.INT8,
"s" to StaticType.INT4,
"s" to StaticType.INT8,
"m" to StaticType.INT4,
),
contentClosed = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@ import org.partiql.spi.function.builtins.internal.AccumulatorAvg
import org.partiql.spi.types.PType

// TODO: This needs to be formalized. See https://github.com/partiql/partiql-lang-kotlin/issues/1659
private val AVG_DECIMAL = PType.decimal(38, 19)
private val AVG_DECIMAL = DefaultDecimal.DECIMAL

internal val Agg_AVG__INT8__INT8 = Aggregation.static(
name = "avg",
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.tinyint())),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__INT16__INT16 = Aggregation.static(
name = "avg",
returns = AVG_DECIMAL,
parameters = arrayOf(Parameter("value", PType.smallint())),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__INT32__INT32 = Aggregation.static(
Expand All @@ -32,7 +32,7 @@ internal val Agg_AVG__INT32__INT32 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.integer()),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__INT64__INT64 = Aggregation.static(
Expand All @@ -42,7 +42,7 @@ internal val Agg_AVG__INT64__INT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.bigint()),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__NUMERIC__NUMERIC = Aggregation.static(
Expand All @@ -52,7 +52,7 @@ internal val Agg_AVG__NUMERIC__NUMERIC = Aggregation.static(
parameters = arrayOf(
Parameter("value", DefaultNumeric.NUMERIC),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
Expand All @@ -62,7 +62,7 @@ internal val Agg_AVG__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
parameters = arrayOf(
Parameter("value", AVG_DECIMAL),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(DefaultDecimal.DECIMAL) },
)

internal val Agg_AVG__FLOAT32__FLOAT32 = Aggregation.static(
Expand All @@ -72,7 +72,7 @@ internal val Agg_AVG__FLOAT32__FLOAT32 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.real()),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(PType.doublePrecision()) },
)

internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
Expand All @@ -82,7 +82,7 @@ internal val Agg_AVG__FLOAT64__FLOAT64 = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.doublePrecision()),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(PType.doublePrecision()) },
)

internal val Agg_AVG__ANY__ANY = Aggregation.static(
Expand All @@ -92,5 +92,5 @@ internal val Agg_AVG__ANY__ANY = Aggregation.static(
parameters = arrayOf(
Parameter("value", PType.dynamic()),
),
accumulator = ::AccumulatorAvg,
accumulator = { AccumulatorAvg(PType.dynamic()) },
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,38 @@ import org.partiql.spi.types.PType

internal val Agg_SUM__INT8__INT8 = Aggregation.static(
name = "sum",
returns = PType.tinyint(),
returns = PType.bigint(),
parameters = arrayOf(
Parameter("value", PType.tinyint()),
),
accumulator = { AccumulatorSum(PType.tinyint()) },
accumulator = { AccumulatorSum(PType.bigint()) },
)

internal val Agg_SUM__INT16__INT16 = Aggregation.static(
name = "sum",
returns = PType.smallint(),
returns = PType.bigint(),
parameters = arrayOf(
Parameter("value", PType.smallint()),
),
accumulator = { AccumulatorSum(PType.smallint()) },
accumulator = { AccumulatorSum(PType.bigint()) },
)

internal val Agg_SUM__INT32__INT32 = Aggregation.static(
name = "sum",
returns = PType.integer(),
returns = PType.bigint(),
parameters = arrayOf(
Parameter("value", PType.integer()),
),
accumulator = { AccumulatorSum(PType.integer()) },
accumulator = { AccumulatorSum(PType.bigint()) },
)

internal val Agg_SUM__INT64__INT64 = Aggregation.static(
name = "sum",
returns = PType.bigint(),
returns = DefaultDecimal.DECIMAL,
parameters = arrayOf(
Parameter("value", PType.bigint())
),
accumulator = { AccumulatorSum(PType.bigint()) },
accumulator = { AccumulatorSum(DefaultDecimal.DECIMAL) },
)

internal val Agg_SUM__NUMERIC__NUMERIC = Aggregation.static(
Expand All @@ -55,11 +55,11 @@ internal val Agg_SUM__NUMERIC__NUMERIC = Aggregation.static(

internal val Agg_SUM__DECIMAL_ARBITRARY__DECIMAL_ARBITRARY = Aggregation.static(
name = "sum",
returns = PType.decimal(38, 19),
returns = DefaultDecimal.DECIMAL,
parameters = arrayOf(
Parameter("value", PType.decimal(38, 19)), // TODO: Rewrite aggregations using new function modeling.
),
accumulator = { AccumulatorSum(PType.decimal(38, 19)) },
accumulator = { AccumulatorSum(DefaultDecimal.DECIMAL) },
)

internal val Agg_SUM__FLOAT32__FLOAT32 = Aggregation.static(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.partiql.spi.function.builtins

import org.partiql.spi.types.PType

internal object DefaultDecimal {
// TODO: Once all functions are converted to use the new function modeling, this can be removed.
val DECIMAL: PType = PType.decimal(38, 19)
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package org.partiql.spi.function.builtins
import org.partiql.spi.function.Function
import org.partiql.spi.function.builtins.internal.PErrors
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.isZero
import org.partiql.spi.utils.NumberUtils.isZero
import org.partiql.spi.value.Datum
import java.math.RoundingMode

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Function
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.compareTo
import org.partiql.spi.utils.NumberUtils.compareTo
import org.partiql.spi.value.Datum

internal object FnGt : DiadicComparisonOperator("gt") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Function
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.compareTo
import org.partiql.spi.utils.NumberUtils.compareTo
import org.partiql.spi.value.Datum

internal object FnGte : DiadicComparisonOperator("gte") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package org.partiql.spi.function.builtins

import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.booleanValue
import org.partiql.spi.types.PType
import org.partiql.spi.utils.FunctionUtils
import org.partiql.spi.utils.FunctionUtils.booleanValue
import org.partiql.spi.value.Datum

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package org.partiql.spi.function.builtins

import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.booleanValue
import org.partiql.spi.types.PType
import org.partiql.spi.utils.FunctionUtils
import org.partiql.spi.utils.FunctionUtils.booleanValue
import org.partiql.spi.value.Datum

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Function
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.compareTo
import org.partiql.spi.utils.NumberUtils.compareTo
import org.partiql.spi.value.Datum

internal object FnLt : DiadicComparisonOperator("lt") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package org.partiql.spi.function.builtins

import org.partiql.spi.function.Function
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.compareTo
import org.partiql.spi.utils.NumberUtils.compareTo
import org.partiql.spi.value.Datum

internal object FnLte : DiadicComparisonOperator("lte") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.PErrors
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.byteOverflows
import org.partiql.spi.utils.NumberExtensions.shortOverflows
import org.partiql.spi.utils.NumberUtils.byteOverflows
import org.partiql.spi.utils.NumberUtils.shortOverflows
import org.partiql.spi.value.Datum

internal object FnMinus : DiadicArithmeticOperator("minus") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package org.partiql.spi.function.builtins
import org.partiql.spi.function.Function
import org.partiql.spi.function.builtins.internal.PErrors
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.isZero
import org.partiql.spi.utils.NumberUtils.isZero
import org.partiql.spi.value.Datum

internal object FnModulo : DiadicArithmeticOperator("mod", false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import org.partiql.spi.function.Function
import org.partiql.spi.function.Parameter
import org.partiql.spi.function.builtins.internal.PErrors
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.byteOverflows
import org.partiql.spi.utils.NumberExtensions.shortOverflows
import org.partiql.spi.utils.NumberUtils.byteOverflows
import org.partiql.spi.utils.NumberUtils.shortOverflows
import org.partiql.spi.value.Datum

internal object FnPlus : DiadicArithmeticOperator("plus") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ package org.partiql.spi.function.builtins
import org.partiql.spi.function.Function
import org.partiql.spi.function.builtins.internal.PErrors
import org.partiql.spi.types.PType
import org.partiql.spi.utils.NumberExtensions.byteOverflows
import org.partiql.spi.utils.NumberExtensions.shortOverflows
import org.partiql.spi.utils.NumberUtils.byteOverflows
import org.partiql.spi.utils.NumberUtils.shortOverflows
import org.partiql.spi.value.Datum

internal object FnTimes : DiadicArithmeticOperator("times") {
Expand Down
Loading

0 comments on commit c841580

Please sign in to comment.