Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v1] Change integer types to error on overflow; give explicit data exception on div/mod by 0 #1715

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package org.partiql.eval.internal

import org.junit.jupiter.api.parallel.Execution
import org.junit.jupiter.api.parallel.ExecutionMode
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource

/**
* E2E evaluation tests that give a data exception.
*/
class DataExceptionTest {

@ParameterizedTest
@MethodSource("plusOverflowTests")
@Execution(ExecutionMode.CONCURRENT)
fun plusOverflow(tc: FailureTestCase) = tc.run()

@ParameterizedTest
@MethodSource("minusOverflowTests")
@Execution(ExecutionMode.CONCURRENT)
fun minusOverflow(tc: FailureTestCase) = tc.run()

@ParameterizedTest
@MethodSource("timesOverflowTests")
@Execution(ExecutionMode.CONCURRENT)
fun timesOverflow(tc: FailureTestCase) = tc.run()

@ParameterizedTest
@MethodSource("divideTests")
@Execution(ExecutionMode.CONCURRENT)
fun divideOverflow(tc: FailureTestCase) = tc.run()

@ParameterizedTest
@MethodSource("divideByZeroTests")
fun divideByZero(tc: FailureTestCase) = tc.run()

companion object {
@JvmStatic
fun plusOverflowTests() = listOf(
// TINYINT
// TODO add parsing and planning support for TINYINT
// FailureTestCase(
// input = "CAST(${Byte.MAX_VALUE} AS TINYINT) + CAST(1 AS TINYINT);"
// ),
// FailureTestCase(
// input = "CAST(${Byte.MIN_VALUE} AS TINYINT) + CAST(-1 AS TINYINT);"
// ),
// SMALLINT
FailureTestCase(
input = "CAST(${Short.MAX_VALUE} AS SMALLINT) + CAST(1 AS SMALLINT);"
),
FailureTestCase(
input = "CAST(${Short.MIN_VALUE} AS SMALLINT) + CAST(-1 AS SMALLINT);"
),
// INT
FailureTestCase(
input = "CAST(${Integer.MAX_VALUE} AS INT) + CAST(1 AS INT);"
),
FailureTestCase(
input = "CAST(${Integer.MIN_VALUE} AS INT) + CAST(-1 AS INT);"
),
// BIGINT
FailureTestCase(
input = "CAST(${Long.MAX_VALUE} AS BIGINT) + CAST(1 AS BIGINT);"
),
FailureTestCase(
input = "CAST(${Long.MIN_VALUE} AS BIGINT) + CAST(-1 AS BIGINT);"
)
)

@JvmStatic
fun minusOverflowTests() = listOf(
// TINYINT
// TODO add parsing and planning support for TINYINT
// FailureTestCase(
// input = "CAST(${Byte.MAX_VALUE} AS TINYINT) - CAST(-1 AS TINYINT);"
// ),
// FailureTestCase(
// input = "CAST(${Byte.MIN_VALUE} AS TINYINT) - CAST(1 AS TINYINT);"
// ),
// SMALLINT
FailureTestCase(
input = "CAST(${Short.MAX_VALUE} AS SMALLINT) - CAST(-1 AS SMALLINT);"
),
FailureTestCase(
input = "CAST(${Short.MIN_VALUE} AS SMALLINT) - CAST(1 AS SMALLINT);"
),
// INT
FailureTestCase(
input = "CAST(${Integer.MAX_VALUE} AS INT) - CAST(-1 AS INT);"
),
FailureTestCase(
input = "CAST(${Integer.MIN_VALUE} AS INT) - CAST(1 AS INT);"
),
// BIGINT
FailureTestCase(
input = "CAST(${Long.MAX_VALUE} AS BIGINT) - CAST(-1 AS BIGINT);"
),
FailureTestCase(
input = "CAST(${Long.MIN_VALUE} AS BIGINT) - CAST(1 AS BIGINT);"
)
)

@JvmStatic
fun timesOverflowTests() = listOf(
// TINYINT
// TODO add parsing and planning support for TINYINT
// FailureTestCase(
// input = "CAST(${Byte.MAX_VALUE} AS TINYINT) * CAST(2 AS TINYINT);"
// ),
// FailureTestCase(
// input = "CAST(${Byte.MIN_VALUE} AS TINYINT) * CAST(2 AS TINYINT);"
// ),
// SMALLINT
FailureTestCase(
input = "CAST(${Short.MAX_VALUE} AS SMALLINT) * CAST(2 AS SMALLINT);"
),
FailureTestCase(
input = "CAST(${Short.MIN_VALUE} AS SMALLINT) * CAST(2 AS SMALLINT);"
),
// INT
FailureTestCase(
input = "CAST(${Integer.MAX_VALUE} AS INT) * CAST(2 AS INT);"
),
FailureTestCase(
input = "CAST(${Integer.MIN_VALUE} AS INT) * CAST(2 AS INT);"
),
// BIGINT
FailureTestCase(
input = "CAST(${Long.MAX_VALUE} AS BIGINT) * CAST(2 AS BIGINT);"
),
FailureTestCase(
input = "CAST(${Long.MIN_VALUE} AS BIGINT) * CAST(2 AS BIGINT);"
)
)

@JvmStatic
fun divideTests() = listOf(
// TINYINT
// TODO add parsing and planning support for TINYINT
// FailureTestCase(
// input = "CAST(${Byte.MIN_VALUE} AS TINYINT) / CAST(-1 AS TINYINT)"
// ),
// SMALLINT
FailureTestCase(
input = "CAST(${Short.MIN_VALUE} AS SMALLINT) / CAST(-1 AS SMALLINT)"
),
// INT
FailureTestCase(
input = "CAST(${Integer.MIN_VALUE} AS INT) / CAST(-1 AS INT)"
),
// BIGINT
FailureTestCase(
input = "CAST(${Long.MIN_VALUE} AS BIGINT) / CAST(-1 AS BIGINT)"
)
)

@JvmStatic
fun divideByZeroTests() = listOf(
// TINYINT
// TODO add parsing and planning support for TINYINT
// FailureTestCase(
// input = "CAST(1 AS TINYINT) / CAST(0 AS TINYINT)"
// ),
// SMALLINT
FailureTestCase(
input = "CAST(1 AS SMALLINT) / CAST(0 AS SMALLINT)"
),
// INT
FailureTestCase(
input = "CAST(1 AS INT) / CAST(0 AS INT)"
),
// BIGINT
FailureTestCase(
input = "CAST(1 AS BIGINT) / CAST(0 AS BIGINT)"
)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class PartiQLEvaluatorTest {
int64Value(2),
),
globals = listOf(
SuccessTestCase.Global(
Global(
name = "t",
value = """
[
Expand All @@ -150,7 +150,7 @@ class PartiQLEvaluatorTest {
int64Value(2),
),
globals = listOf(
SuccessTestCase.Global(
Global(
name = "t",
value = """
[
Expand All @@ -177,7 +177,7 @@ class PartiQLEvaluatorTest {
),
),
globals = listOf(
SuccessTestCase.Global(
Global(
name = "customers",
value = """
[{id:1, name: "Mary"},
Expand All @@ -186,7 +186,7 @@ class PartiQLEvaluatorTest {
]
"""
),
SuccessTestCase.Global(
Global(
name = "orders",
value = """
[{custId:1, name: "foo"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class PlusTest {
mode = Mode.STRICT(),
expected = Datum.decimal(BigDecimal.valueOf(457023), 14, 7),
globals = listOf(
SuccessTestCase.Global(
Global(
"dynamic_decimal",
"456789.0000000"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ import org.partiql.types.fromStaticType
import org.partiql.value.PartiQLValue
import kotlin.test.assertEquals

/**
* @property value is a serialized Ion value.
*/
class Global(
val name: String,
val value: String,
val type: StaticType = StaticType.ANY,
)

public class SuccessTestCase(
val input: String,
val expected: Datum,
Expand All @@ -37,15 +46,6 @@ public class SuccessTestCase(
private val parser = PartiQLParser.standard()
private val planner = PartiQLPlanner.standard()

/**
* @property value is a serialized Ion value.
*/
class Global(
val name: String,
val value: String,
val type: StaticType = StaticType.ANY,
)

override fun run() {
val parseResult = parser.parse(input)
assertEquals(1, parseResult.statements.size)
Expand Down Expand Up @@ -92,3 +92,51 @@ public class SuccessTestCase(
return input
}
}

public class FailureTestCase(
val input: String,
val mode: Mode = Mode.STRICT(), // default to run in STRICT mode
val globals: List<Global> = emptyList(),
) : PTestCase {
private val compiler = PartiQLCompiler.standard()
private val parser = PartiQLParser.standard()
private val planner = PartiQLPlanner.standard()

override fun run() {
val parseResult = parser.parse(input)
assertEquals(1, parseResult.statements.size)
val statement = parseResult.statements[0]
val catalog = Catalog.builder()
.name("memory")
.apply {
globals.forEach {
val table = Table.standard(
name = Name.of(it.name),
schema = fromStaticType(it.type),
datum = DatumReader.ion(it.value.byteInputStream()).next()!!
)
define(table)
}
}
.build()
val session = Session.builder()
.catalog("memory")
.catalogs(catalog)
.build()
var thrown: Throwable? = null
val plan = planner.plan(statement, session).plan
val actual: Datum = try {
DatumMaterialize.materialize(compiler.prepare(plan, mode).execute())
} catch (t: Throwable) {
thrown = t
Datum.nullValue()
}
if (thrown == null) {
val message = buildString {
appendLine("Expected error to be thrown but none was thrown.")
appendLine("Actual Result: $actual")
}
error(message)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package org.partiql.spi.function.builtins

import org.partiql.spi.errors.DataException
import org.partiql.spi.function.Function
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
Expand All @@ -15,8 +16,13 @@ internal object FnDivide : DiadicArithmeticOperator("divide") {

override fun getTinyIntInstance(tinyIntLhs: PType, tinyIntRhs: PType): Function.Instance {
return basic(PType.tinyint()) { args ->
@Suppress("DEPRECATION") val arg0 = args[0].byte
@Suppress("DEPRECATION") val arg1 = args[1].byte
val arg0 = args[0].byte
val arg1 = args[1].byte
if (arg1 == 0.toByte()) {
throw DataException("Division by zero for TINYINT: $arg0 / $arg1")
} else if (arg0 == Byte.MIN_VALUE && arg1.toInt() == -1) {
throw DataException("Resulting value out of range for: $arg0 / $arg1")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the following 3, you removed the type name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah my bad. I will fix in the target branch.

}
Datum.tinyint((arg0 / arg1).toByte())
}
}
Expand All @@ -25,6 +31,11 @@ internal object FnDivide : DiadicArithmeticOperator("divide") {
return basic(PType.smallint()) { args ->
val arg0 = args[0].short
val arg1 = args[1].short
if (arg1 == 0.toShort()) {
throw DataException("Division by zero for SMALLINT: $arg0 / $arg1")
} else if (arg0 == Short.MIN_VALUE && arg1.toInt() == -1) {
throw DataException("Resulting value out of range for: $arg0 / $arg1")
}
Datum.smallint((arg0 / arg1).toShort())
}
}
Expand All @@ -33,6 +44,11 @@ internal object FnDivide : DiadicArithmeticOperator("divide") {
return basic(PType.integer()) { args ->
val arg0 = args[0].int
val arg1 = args[1].int
if (arg1 == 0) {
throw DataException("Division by zero for INT: $arg0 / $arg1")
} else if (arg0 == Int.MIN_VALUE && arg1 == -1) {
throw DataException("Resulting value out of range for INT: $arg0 / $arg1")
}
Datum.integer(arg0 / arg1)
}
}
Expand All @@ -41,6 +57,11 @@ internal object FnDivide : DiadicArithmeticOperator("divide") {
return basic(PType.bigint()) { args ->
val arg0 = args[0].long
val arg1 = args[1].long
if (arg1 == 0L) {
throw DataException("Division by zero for BIGINT: $arg0 / $arg1")
} else if (arg0 == Long.MIN_VALUE && arg1 == -1L) {
throw DataException("Resulting value out of range for BIGINT: $arg0 / $arg1")
}
Datum.bigint(arg0 / arg1)
}
}
Expand Down
Loading
Loading