Skip to content

Commit

Permalink
Support nullable parameters
Browse files Browse the repository at this point in the history
Closes: #42
  • Loading branch information
squarejesse committed Oct 30, 2024
1 parent e50a772 commit fcf0fdc
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 23 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class DrinkSodaTest(
}
```

If the parameter is nullable, Burst will also test with null.

### Multiple Parameters

Use multiple parameters to test all variations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class BurstKotlinPluginTest {
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages).contains(
"CoffeeTest.kt:7:12 " +
"@Burst parameter must be a boolean, enum, or have a burstValues() default value",
"@Burst parameter must be a boolean, an enum, or have a burstValues() default value",
)
}

Expand Down Expand Up @@ -135,7 +135,7 @@ class BurstKotlinPluginTest {
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages).contains(
"CoffeeTest.kt:9:12 " +
"@Burst parameter default value must be burstValues(), a constant, or absent",
"@Burst parameter default must be burstValues(), a constant, null, or absent",
)
}

Expand Down Expand Up @@ -681,6 +681,166 @@ class BurstKotlinPluginTest {
)
}

@Test
fun nullableEnumNoDefault() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
val log = mutableListOf<String>()
@Test
fun test(espresso: Espresso?) {
log += "running ${'$'}espresso"
}
}
enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_Decaf").invoke(baseInstance)
baseClass.getMethod("test_Regular").invoke(baseInstance)
baseClass.getMethod("test_Double").invoke(baseInstance)
baseClass.getMethod("test_null").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running Decaf",
"running Regular",
"running Double",
"running null",
)
}

@Test
fun nullableBooleanNoDefault() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
val log = mutableListOf<String>()
@Test
fun test(iced: Boolean?) {
log += "running ${'$'}iced"
}
}
enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_false").invoke(baseInstance)
baseClass.getMethod("test_true").invoke(baseInstance)
baseClass.getMethod("test_null").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running false",
"running true",
"running null",
)
}

@Test
fun nullableEnumAsDefault() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
val log = mutableListOf<String>()
@Test
fun test(espresso: Espresso? = null) {
log += "running ${'$'}espresso"
}
}
enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_Decaf").invoke(baseInstance)
baseClass.getMethod("test_Regular").invoke(baseInstance)
baseClass.getMethod("test_Double").invoke(baseInstance)
baseClass.getMethod("test").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running Decaf",
"running Regular",
"running Double",
"running null",
)
}

@Test
fun nullableBooleanAsDefault() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
val log = mutableListOf<String>()
@Test
fun test(iced: Boolean? = null) {
log += "running ${'$'}iced"
}
}
enum class Espresso { Decaf, Regular, Double }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

baseClass.getMethod("test_false").invoke(baseInstance)
baseClass.getMethod("test_true").invoke(baseInstance)
baseClass.getMethod("test").invoke(baseInstance)
assertThat(baseLog).containsExactly(
"running false",
"running true",
"running null",
)
}

private val Class<*>.testSuffixes: List<String>
get() = methods.mapNotNull {
when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.classFqName
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.types.isNullable
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols
import org.jetbrains.kotlin.ir.util.defaultType
Expand Down Expand Up @@ -86,6 +87,20 @@ private class BooleanArgument(
}
}

private class NullArgument(
private val original: IrElement,
private val type: IrType,
override val isDefault: Boolean,
) : Argument {
override val name = "null"

override fun expression() = IrConstImpl.constNull(original.startOffset, original.endOffset, type)

override fun <R, D> accept(visitor: IrElementVisitor<R, D>, data: D): R {
return original.accept(visitor, data)
}
}

@UnsafeDuringIrConstructionAPI
private class BurstValuesArgument(
private val parameter: IrValueParameter,
Expand Down Expand Up @@ -192,49 +207,86 @@ private fun IrExpression.suggestedName(): String? {
private fun enumValueArguments(
referenceClass: IrClass,
parameter: IrValueParameter,
): List<EnumValueArgument> {
): List<Argument> {
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
val defaultValueSymbol = parameter.defaultValue?.let { defaultValue ->
(defaultValue.expression as? IrGetEnumValue)?.symbol ?: unexpectedDefaultValue(parameter)
val hasDefaultValue = parameter.defaultValue != null
val defaultEnumSymbol = parameter.defaultValue?.let { defaultValue ->
val expression = defaultValue.expression
when {
expression is IrGetEnumValue -> expression.symbol
expression is IrConst<*> && expression.value == null -> null
else -> unexpectedDefaultValue(parameter)
}
}

return enumEntries.map {
EnumValueArgument(
original = parameter,
type = parameter.type,
isDefault = it.symbol == defaultValueSymbol,
value = it,
)
return buildList {
for (enumEntry in enumEntries) {
add(
EnumValueArgument(
original = parameter,
type = parameter.type,
isDefault = hasDefaultValue && enumEntry.symbol == defaultEnumSymbol,
value = enumEntry,
),
)
}
if (parameter.type.isNullable()) {
add(
NullArgument(
original = parameter,
type = parameter.type,
isDefault = hasDefaultValue && defaultEnumSymbol == null,
),
)
}
}
}

private fun IrPluginContext.booleanArguments(
parameter: IrValueParameter,
): List<BooleanArgument> {
): List<Argument> {
val hasDefaultValue = parameter.defaultValue != null
val defaultValue = parameter.defaultValue?.let { defaultValue ->
(defaultValue.expression as? IrConst<*>)?.value ?: unexpectedDefaultValue(parameter)
val expression = defaultValue.expression
when {
expression is IrConst<*> -> expression.value
else -> unexpectedDefaultValue(parameter)
}
}

return listOf(false, true).map {
BooleanArgument(
original = parameter,
booleanType = irBuiltIns.booleanType,
isDefault = defaultValue == it,
value = it,
)
return buildList {
for (b in listOf(false, true)) {
add(
BooleanArgument(
original = parameter,
booleanType = irBuiltIns.booleanType,
isDefault = hasDefaultValue && defaultValue == b,
value = b,
),
)
}
if (parameter.type.isNullable()) {
add(
NullArgument(
original = parameter,
type = parameter.type,
isDefault = hasDefaultValue && defaultValue == null,
),
)
}
}
}

private fun unexpectedParameter(parameter: IrValueParameter): Nothing {
throw BurstCompilationException(
"@Burst parameter must be a boolean, enum, or have a burstValues() default value",
"@Burst parameter must be a boolean, an enum, or have a burstValues() default value",
parameter,
)
}

private fun unexpectedDefaultValue(parameter: IrValueParameter): Nothing {
throw BurstCompilationException(
"@Burst parameter default value must be burstValues(), a constant, or absent",
"@Burst parameter default must be burstValues(), a constant, null, or absent",
parameter,
)
}

0 comments on commit fcf0fdc

Please sign in to comment.