diff --git a/README.md b/README.md index b3f75bb..1bc1ffd 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,8 @@ class DrinkSodaTest( } ``` +If the parameter is nullable, Burst will also test with null. + ### Multiple Parameters Use multiple parameters to test all variations. diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index dbb0642..e2efeb7 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -106,7 +106,7 @@ class BurstKotlinPluginTest { assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages) assertThat(result.messages).contains( "CoffeeTest.kt:7:12 " + - "@Burst parameter must be a boolean, enum, or have a burstValues() default value", + "@Burst parameter must be a boolean, an enum, or have a burstValues() default value", ) } @@ -135,7 +135,7 @@ class BurstKotlinPluginTest { assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages) assertThat(result.messages).contains( "CoffeeTest.kt:9:12 " + - "@Burst parameter default value must be burstValues(), a constant, or absent", + "@Burst parameter default must be burstValues(), a constant, null, or absent", ) } @@ -681,6 +681,166 @@ class BurstKotlinPluginTest { ) } + @Test + fun nullableEnumNoDefault() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test(espresso: Espresso?) { + log += "running ${'$'}espresso" + } + } + + enum class Espresso { Decaf, Regular, Double } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test_Decaf").invoke(baseInstance) + baseClass.getMethod("test_Regular").invoke(baseInstance) + baseClass.getMethod("test_Double").invoke(baseInstance) + baseClass.getMethod("test_null").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "running Decaf", + "running Regular", + "running Double", + "running null", + ) + } + + @Test + fun nullableBooleanNoDefault() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test(iced: Boolean?) { + log += "running ${'$'}iced" + } + } + + enum class Espresso { Decaf, Regular, Double } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test_false").invoke(baseInstance) + baseClass.getMethod("test_true").invoke(baseInstance) + baseClass.getMethod("test_null").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "running false", + "running true", + "running null", + ) + } + + @Test + fun nullableEnumAsDefault() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test(espresso: Espresso? = null) { + log += "running ${'$'}espresso" + } + } + + enum class Espresso { Decaf, Regular, Double } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test_Decaf").invoke(baseInstance) + baseClass.getMethod("test_Regular").invoke(baseInstance) + baseClass.getMethod("test_Double").invoke(baseInstance) + baseClass.getMethod("test").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "running Decaf", + "running Regular", + "running Double", + "running null", + ) + } + + @Test + fun nullableBooleanAsDefault() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.Test + + @Burst + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test(iced: Boolean? = null) { + log += "running ${'$'}iced" + } + } + + enum class Espresso { Decaf, Regular, Double } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseInstance = baseClass.constructors.single().newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + baseClass.getMethod("test_false").invoke(baseInstance) + baseClass.getMethod("test_true").invoke(baseInstance) + baseClass.getMethod("test").invoke(baseInstance) + assertThat(baseLog).containsExactly( + "running false", + "running true", + "running null", + ) + } + private val Class<*>.testSuffixes: List get() = methods.mapNotNull { when { diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt index 3057617..3aaeb2a 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -33,6 +33,7 @@ import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI import org.jetbrains.kotlin.ir.types.IrType import org.jetbrains.kotlin.ir.types.classFqName import org.jetbrains.kotlin.ir.types.getClass +import org.jetbrains.kotlin.ir.types.isNullable import org.jetbrains.kotlin.ir.util.classId import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols import org.jetbrains.kotlin.ir.util.defaultType @@ -86,6 +87,20 @@ private class BooleanArgument( } } +private class NullArgument( + private val original: IrElement, + private val type: IrType, + override val isDefault: Boolean, +) : Argument { + override val name = "null" + + override fun expression() = IrConstImpl.constNull(original.startOffset, original.endOffset, type) + + override fun accept(visitor: IrElementVisitor, data: D): R { + return original.accept(visitor, data) + } +} + @UnsafeDuringIrConstructionAPI private class BurstValuesArgument( private val parameter: IrValueParameter, @@ -192,49 +207,86 @@ private fun IrExpression.suggestedName(): String? { private fun enumValueArguments( referenceClass: IrClass, parameter: IrValueParameter, -): List { +): List { val enumEntries = referenceClass.declarations.filterIsInstance() - val defaultValueSymbol = parameter.defaultValue?.let { defaultValue -> - (defaultValue.expression as? IrGetEnumValue)?.symbol ?: unexpectedDefaultValue(parameter) + val hasDefaultValue = parameter.defaultValue != null + val defaultEnumSymbol = parameter.defaultValue?.let { defaultValue -> + val expression = defaultValue.expression + when { + expression is IrGetEnumValue -> expression.symbol + expression is IrConst<*> && expression.value == null -> null + else -> unexpectedDefaultValue(parameter) + } } - return enumEntries.map { - EnumValueArgument( - original = parameter, - type = parameter.type, - isDefault = it.symbol == defaultValueSymbol, - value = it, - ) + return buildList { + for (enumEntry in enumEntries) { + add( + EnumValueArgument( + original = parameter, + type = parameter.type, + isDefault = hasDefaultValue && enumEntry.symbol == defaultEnumSymbol, + value = enumEntry, + ), + ) + } + if (parameter.type.isNullable()) { + add( + NullArgument( + original = parameter, + type = parameter.type, + isDefault = hasDefaultValue && defaultEnumSymbol == null, + ), + ) + } } } private fun IrPluginContext.booleanArguments( parameter: IrValueParameter, -): List { +): List { + val hasDefaultValue = parameter.defaultValue != null val defaultValue = parameter.defaultValue?.let { defaultValue -> - (defaultValue.expression as? IrConst<*>)?.value ?: unexpectedDefaultValue(parameter) + val expression = defaultValue.expression + when { + expression is IrConst<*> -> expression.value + else -> unexpectedDefaultValue(parameter) + } } - return listOf(false, true).map { - BooleanArgument( - original = parameter, - booleanType = irBuiltIns.booleanType, - isDefault = defaultValue == it, - value = it, - ) + return buildList { + for (b in listOf(false, true)) { + add( + BooleanArgument( + original = parameter, + booleanType = irBuiltIns.booleanType, + isDefault = hasDefaultValue && defaultValue == b, + value = b, + ), + ) + } + if (parameter.type.isNullable()) { + add( + NullArgument( + original = parameter, + type = parameter.type, + isDefault = hasDefaultValue && defaultValue == null, + ), + ) + } } } private fun unexpectedParameter(parameter: IrValueParameter): Nothing { throw BurstCompilationException( - "@Burst parameter must be a boolean, enum, or have a burstValues() default value", + "@Burst parameter must be a boolean, an enum, or have a burstValues() default value", parameter, ) } private fun unexpectedDefaultValue(parameter: IrValueParameter): Nothing { throw BurstCompilationException( - "@Burst parameter default value must be burstValues(), a constant, or absent", + "@Burst parameter default must be burstValues(), a constant, null, or absent", parameter, ) }