diff --git a/CHANGELOG.md b/CHANGELOG.md index 98b7a3a..6388c55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ ## [Unreleased] [Unreleased]: https://github.com/cashapp/burst/compare/0.5.0...HEAD +**Added** + + * New: Use default parameter values to configure which specialization runs in the IDE. + ## [0.5.0] *(2024-10-17)* [0.5.0]: https://github.com/cashapp/burst/releases/tag/0.4.0 diff --git a/README.md b/README.md index e955d76..34de3dc 100644 --- a/README.md +++ b/README.md @@ -47,19 +47,20 @@ Annotate your test class with `@Burst`, and accept an enum as a constructor para ```kotlin @Burst class DrinkSodaTest( - val soda: Soda, + val soda: Soda = Soda.Pepsi, ) { ... } ``` -Burst will specialize the test class for each value in the enum. +Burst will specialize the test class for each value in the enum. If you specified a default value +for the parameter, it'll be used when you run the test in the IDE. Burst can also specialize individual test functions: ```kotlin @Test -fun drinkFavoriteSodas(soda: Soda) { +fun drinkFavoriteSodas(soda: Soda = Soda.Pepsi) { ... } ``` @@ -68,7 +69,10 @@ Use multiple enums for the combination of their variations. ```kotlin @Test -fun collectSodas(soda: Soda, collectionsFactory: CollectionFactory) { +fun collectSodas( + soda: Soda = Soda.Pepsi, + collectionsFactory: CollectionFactory = CollectionFactory.MutableSetOf, +) { ... } ``` diff --git a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt index d679c51..74253aa 100644 --- a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt +++ b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt @@ -17,13 +17,10 @@ package app.cash.burst.gradle import assertk.assertThat -import assertk.assertions.contains import assertk.assertions.containsExactlyInAnyOrder -import assertk.assertions.isEmpty import assertk.assertions.isEqualTo import assertk.assertions.isFalse import assertk.assertions.isIn -import assertk.assertions.isTrue import java.io.File import org.gradle.testkit.runner.GradleRunner import org.gradle.testkit.runner.TaskOutcome @@ -70,59 +67,17 @@ class BurstGradlePluginTest { val testResults = projectDir.resolve("lib/build/test-results") - // The original test class runs the default specialization. - with(readTestSuite(testResults.resolve("$testTaskName/TEST-CoffeeTest.xml"))) { - assertThat(testCases.map { it.name }).containsExactlyInAnyOrder( - "test[$platformName]", - "test_Milk[$platformName]", - "test_None[$platformName]", - "test_Oat[$platformName]", - ) - - val defaultFunction = testCases.single { it.name == "test[$platformName]" } - assertThat(defaultFunction.skipped).isFalse() - - val defaultSpecialization = testCases.single { it.name == "test_None[$platformName]" } - assertThat(defaultSpecialization.skipped).isTrue() - - val sampleSpecialization = testCases.single { it.name == "test_Milk[$platformName]" } - assertThat(sampleSpecialization.skipped).isFalse() - } - - // The default test class is completely skipped. - with(readTestSuite(testResults.resolve("$testTaskName/TEST-CoffeeTest_Decaf.xml"))) { - assertThat(testCases.map { it.name }).containsExactlyInAnyOrder( - "test[$platformName]", - "test_Milk[$platformName]", - "test_None[$platformName]", - "test_Oat[$platformName]", - ) + // There's no default specialization. + assertThat(testResults.resolve("$testTaskName/TEST-CoffeeTest.xml").exists()).isFalse() - val defaultFunction = testCases.single { it.name == "test[$platformName]" } - assertThat(defaultFunction.skipped).isTrue() - - val defaultSpecialization = testCases.single { it.name == "test_None[$platformName]" } - assertThat(defaultSpecialization.skipped).isTrue() - - val sampleSpecialization = testCases.single { it.name == "test_Milk[$platformName]" } - assertThat(sampleSpecialization.skipped).isTrue() - } - - // Another test class is executed normally with nothing skipped. + // Each test class is executed normally with nothing skipped. with(readTestSuite(testResults.resolve("$testTaskName/TEST-CoffeeTest_Regular.xml"))) { assertThat(testCases.map { it.name }).containsExactlyInAnyOrder( - "test[$platformName]", "test_Milk[$platformName]", "test_None[$platformName]", "test_Oat[$platformName]", ) - val defaultFunction = testCases.single { it.name == "test[$platformName]" } - assertThat(defaultFunction.skipped).isFalse() - - val defaultSpecialization = testCases.single { it.name == "test_None[$platformName]" } - assertThat(defaultSpecialization.skipped).isTrue() - val sampleSpecialization = testCases.single { it.name == "test_Milk[$platformName]" } assertThat(sampleSpecialization.skipped).isFalse() } @@ -142,7 +97,6 @@ class BurstGradlePluginTest { val testSuite = readTestSuite(testXmlFile) assertThat(testSuite.testCases.map { it.name }).containsExactlyInAnyOrder( - "test", "test_Decaf_Milk", "test_Decaf_None", "test_Decaf_Oat", @@ -154,12 +108,6 @@ class BurstGradlePluginTest { "test_Regular_Oat", ) - val originalTest = testSuite.testCases.single { it.name == "test" } - assertThat(originalTest.skipped).isFalse() - - val defaultSpecialization = testSuite.testCases.single { it.name == "test_Decaf_None" } - assertThat(defaultSpecialization.skipped).isTrue() - val sampleSpecialization = testSuite.testCases.single { it.name == "test_Regular_Milk" } assertThat(sampleSpecialization.skipped).isFalse() } @@ -174,22 +122,8 @@ class BurstGradlePluginTest { val testResults = projectDir.resolve("lib/build/test-results") - val coffeeTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest.xml")) - val coffeeTestTest = coffeeTest.testCases.single() - assertThat(coffeeTestTest.name).isEqualTo("test") - assertThat(coffeeTest.systemOut).isEqualTo( - """ - |set up Decaf None - |running Decaf None - | - """.trimMargin(), - ) - - val defaultTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Decaf_None.xml")) - val defaultTestTest = defaultTest.testCases.single() - assertThat(defaultTestTest.name).isEqualTo("test") - assertThat(defaultTestTest.skipped).isTrue() - assertThat(defaultTest.systemOut).isEmpty() + // There's no default specialization. + assertThat(testResults.resolve("test/TEST-CoffeeTest.xml").exists()).isFalse() val sampleTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Regular_Milk.xml")) val sampleTestTest = sampleTest.testCases.single() @@ -215,6 +149,49 @@ class BurstGradlePluginTest { assertThat(result.task(androidTestTaskName)!!.outcome).isIn(*SUCCESS_OUTCOMES) } + @Test + fun defaultArguments() { + val projectDir = File("src/test/projects/defaultArguments") + + val result = createRunner(projectDir, "clean", ":lib:test").build() + assertThat(result.task(":lib:test")!!.outcome).isIn(*SUCCESS_OUTCOMES) + + val testResults = projectDir.resolve("lib/build/test-results") + + // The original test class runs the default specialization. + with(readTestSuite(testResults.resolve("test/TEST-CoffeeTest.xml"))) { + assertThat(testCases.map { it.name }).containsExactlyInAnyOrder( + "test", + "test_None", + "test_Oat", + ) + + val defaultFunction = testCases.single { it.name == "test" } + assertThat(defaultFunction.skipped).isFalse() + + val sampleSpecialization = testCases.single { it.name == "test_Oat" } + assertThat(sampleSpecialization.skipped).isFalse() + } + + // No subclass is generated for the default specialization. + assertThat(testResults.resolve("test/TEST-CoffeeTest_Regular.xml").exists()).isFalse() + + // Another test class is executed normally with nothing skipped. + with(readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Double.xml"))) { + assertThat(testCases.map { it.name }).containsExactlyInAnyOrder( + "test", + "test_None", + "test_Oat", + ) + + val defaultFunction = testCases.single { it.name == "test" } + assertThat(defaultFunction.skipped).isFalse() + + val sampleSpecialization = testCases.single { it.name == "test_Oat" } + assertThat(sampleSpecialization.skipped).isFalse() + } + } + private fun createRunner( projectDir: File, vararg taskNames: String, diff --git a/burst-gradle-plugin/src/test/projects/defaultArguments/build.gradle.kts b/burst-gradle-plugin/src/test/projects/defaultArguments/build.gradle.kts new file mode 100644 index 0000000..66fc2bb --- /dev/null +++ b/burst-gradle-plugin/src/test/projects/defaultArguments/build.gradle.kts @@ -0,0 +1,37 @@ +import org.jetbrains.kotlin.gradle.dsl.JvmTarget +import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile + +buildscript { + repositories { + maven { + url = file("$rootDir/../../../../../build/testMaven").toURI() + } + mavenCentral() + google() + } + dependencies { + classpath("app.cash.burst:burst-gradle-plugin:${project.property("burstVersion")}") + classpath(libs.kotlin.gradlePlugin) + } +} + +allprojects { + repositories { + maven { + url = file("$rootDir/../../../../../build/testMaven").toURI() + } + mavenCentral() + google() + } + + tasks.withType(JavaCompile::class.java).configureEach { + sourceCompatibility = "1.8" + targetCompatibility = "1.8" + } + + tasks.withType(KotlinJvmCompile::class.java).configureEach { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_1_8) + } + } +} diff --git a/burst-gradle-plugin/src/test/projects/defaultArguments/lib/build.gradle.kts b/burst-gradle-plugin/src/test/projects/defaultArguments/lib/build.gradle.kts new file mode 100644 index 0000000..36cb092 --- /dev/null +++ b/burst-gradle-plugin/src/test/projects/defaultArguments/lib/build.gradle.kts @@ -0,0 +1,8 @@ +plugins { + kotlin("jvm") + id("app.cash.burst") +} + +dependencies { + testImplementation(kotlin("test")) +} diff --git a/burst-gradle-plugin/src/test/projects/defaultArguments/lib/src/test/kotlin/CoffeeTest.kt b/burst-gradle-plugin/src/test/projects/defaultArguments/lib/src/test/kotlin/CoffeeTest.kt new file mode 100644 index 0000000..55dc023 --- /dev/null +++ b/burst-gradle-plugin/src/test/projects/defaultArguments/lib/src/test/kotlin/CoffeeTest.kt @@ -0,0 +1,22 @@ +import app.cash.burst.Burst +import kotlin.test.BeforeTest +import kotlin.test.Test + +@Burst +class CoffeeTest( + private val espresso: Espresso = Espresso.Regular, +) { + @BeforeTest + fun setUp() { + println("set up $espresso") + } + + @Test + fun test(dairy: Dairy = Dairy.Milk) { + println("running $espresso $dairy") + } +} + +enum class Espresso { Decaf, Regular, Double } + +enum class Dairy { None, Milk, Oat } diff --git a/burst-gradle-plugin/src/test/projects/defaultArguments/settings.gradle.kts b/burst-gradle-plugin/src/test/projects/defaultArguments/settings.gradle.kts new file mode 100644 index 0000000..78eeed4 --- /dev/null +++ b/burst-gradle-plugin/src/test/projects/defaultArguments/settings.gradle.kts @@ -0,0 +1,9 @@ +dependencyResolutionManagement { + versionCatalogs { + create("libs") { + from(files("../../../../../gradle/libs.versions.toml")) + } + } +} + +include(":lib") diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index fc687f2..02c4e07 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -24,9 +24,9 @@ import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation import com.tschuchort.compiletesting.SourceFile import java.lang.reflect.Modifier -import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar import org.jetbrains.kotlin.compiler.plugin.ExperimentalCompilerApi @@ -66,32 +66,21 @@ class BurstKotlinPluginTest { val adapterInstance = testClass.constructors.single().newInstance() val log = testClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*> - // Burst drops @Test from the original test. + // Burst drops @Test from the original test function. val originalTest = testClass.methods.single { it.name == "test" && it.parameterCount == 2 } assertThat(originalTest.isAnnotationPresent(Test::class.java)).isFalse() // Burst adds a specialization for each combination of parameters. val sampleSpecialization = testClass.getMethod("test_Regular_Milk") assertThat(sampleSpecialization.isAnnotationPresent(Test::class.java)).isTrue() - assertThat(sampleSpecialization.isAnnotationPresent(Ignore::class.java)).isFalse() sampleSpecialization.invoke(adapterInstance) assertThat(log).containsExactly("running Regular Milk") log.clear() - // The first specialization is also annotated `@Ignore`. - val firstSpecialization = testClass.getMethod("test_Decaf_None") - assertThat(firstSpecialization.isAnnotationPresent(Test::class.java)).isTrue() - assertThat(firstSpecialization.isAnnotationPresent(Ignore::class.java)).isTrue() - firstSpecialization.invoke(adapterInstance) - assertThat(log).containsExactly("running Decaf None") - log.clear() - - // Burst adds a no-parameter function that calls the first specialization. - val noArgsTest = testClass.getMethod("test") - assertThat(noArgsTest.isAnnotationPresent(Test::class.java)).isTrue() - assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isFalse() - noArgsTest.invoke(adapterInstance) - assertThat(log).containsExactly("running Decaf None") + // Burst doesn't add a no-parameter function because there's no default specialization. + assertFailsWith { + testClass.getMethod("test") + } } @Test @@ -156,26 +145,14 @@ class BurstKotlinPluginTest { // Burst opens the class because it needs to subclass it. assertThat(Modifier.isFinal(baseClass.modifiers)).isFalse() - assertThat(baseClass.isAnnotationPresent(Ignore::class.java)).isFalse() - - // Burst adds a no-args constructor that binds the first enum value. - val baseConstructor = baseClass.constructors.single { it.parameterCount == 0 } - val baseInstance = baseConstructor.newInstance() - val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> - - // The setUp function gets the first value of each parameter. - baseClass.getMethod("setUp").invoke(baseInstance) - assertThat(baseLog).containsExactly("set up Decaf None") - baseLog.clear() - // The test function gets the same. - baseClass.getMethod("test").invoke(baseInstance) - assertThat(baseLog).containsExactly("running Decaf None") - baseLog.clear() + // Burst doesn't add a no-arg constructor because there's no default specialization. + assertFailsWith { + baseClass.getConstructor() + } // It generates a subclass for each specialization. val sampleClass = result.classLoader.loadClass("CoffeeTest_Regular_Milk") - assertThat(sampleClass.isAnnotationPresent(Ignore::class.java)).isFalse() val sampleConstructor = sampleClass.getConstructor() val sampleInstance = sampleConstructor.newInstance() val sampleLog = sampleClass.getMethod("getLog") @@ -187,10 +164,72 @@ class BurstKotlinPluginTest { "running Regular Milk", ) sampleLog.clear() + } + + @Test + fun defaultArgumentsHonored() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.BeforeTest + import kotlin.test.Test + + @Burst + class CoffeeTest( + private val espresso: Espresso = Espresso.Regular, + ) { + val log = mutableListOf() + + @BeforeTest + fun setUp() { + log += "set up ${'$'}espresso" + } + + @Test + fun test(dairy: Dairy = Dairy.Milk) { + log += "running ${'$'}espresso ${'$'}dairy" + } + } + + enum class Espresso { Decaf, Regular, Double } + enum class Dairy { None, Milk, Oat } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + val baseConstructor = baseClass.constructors.single { it.parameterCount == 0 } + val baseInstance = baseConstructor.newInstance() + val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*> + + // The setUp function gets the default parameter value. + baseClass.getMethod("setUp").invoke(baseInstance) + assertThat(baseLog).containsExactly("set up Regular") + baseLog.clear() + + // The test function gets its default parameter value. + baseClass.getMethod("test").invoke(baseInstance) + assertThat(baseLog).containsExactly("running Regular Milk") + baseLog.clear() + + // The default specialization's subclass is not generated. + assertFailsWith { + result.classLoader.loadClass("CoffeeTest_Regular") + } + + // Other subclasses are available. + result.classLoader.loadClass("CoffeeTest_Double") + + // The default test function is also not generated. + assertFailsWith { + baseClass.getMethod("test_Milk") + } - // The default specialization is annotated `@Ignore`. - val defaultClass = result.classLoader.loadClass("CoffeeTest_Decaf_None") - assertThat(defaultClass.isAnnotationPresent(Ignore::class.java)).isTrue() + // Other test functions are available. + baseClass.getMethod("test_Oat") } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt index 07a87f5..1913043 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -20,6 +20,7 @@ import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.declarations.IrEnumEntry import org.jetbrains.kotlin.ir.declarations.IrValueParameter import org.jetbrains.kotlin.ir.expressions.IrExpression +import org.jetbrains.kotlin.ir.expressions.IrGetEnumValue import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl import org.jetbrains.kotlin.ir.types.IrType import org.jetbrains.kotlin.ir.types.getClass @@ -29,6 +30,8 @@ import org.jetbrains.kotlin.ir.util.isEnumClass internal class Argument( private val original: IrElement, private val type: IrType, + /** True if this argument matches the default parameter value. */ + internal val isDefault: Boolean, internal val value: IrEnumEntry, ) { /** Returns an expression that looks up this argument. */ @@ -58,5 +61,13 @@ internal fun IrPluginContext.allPossibleArguments( val referenceClass = referenceClass(classId)?.owner ?: return null if (!referenceClass.isEnumClass) return null val enumEntries = referenceClass.declarations.filterIsInstance() - return enumEntries.map { Argument(parameter, parameter.type, it) } + val defaultValueSymbol = (parameter.defaultValue?.expression as? IrGetEnumValue)?.symbol + return enumEntries.map { + Argument( + original = parameter, + type = parameter.type, + isDefault = it.symbol == defaultValueSymbol, + value = it, + ) + } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt index 6cf53e5..efd7b3b 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt @@ -47,9 +47,6 @@ internal class BurstApis private constructor( val testClassSymbol: IrClassSymbol get() = pluginContext.referenceClass(testPackage.classId("Test"))!! - - val ignoreClassSymbol: IrClassSymbol - get() = pluginContext.referenceClass(testPackage.classId("Ignore"))!! } private val burstFqPackage = FqPackageName("app.cash.burst") diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt index 737742c..0c6afc2 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt @@ -46,7 +46,6 @@ class BurstIrGenerationExtension( if (classHasAtBurst) { ClassSpecializer( pluginContext = pluginContext, - burstApis = burstApis, originalParent = currentFile, original = classDeclaration, ).generateSpecializations() diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt index fe55b8f..84ee93b 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt @@ -38,15 +38,16 @@ import org.jetbrains.kotlin.name.Name * ``` * @Burst * class CoffeeTest( - * private val espresso: Espresso, - * private val dairy: Dairy, + * private val espresso: Espresso = Espresso.Regular, + * private val dairy: Dairy = Dairy.Milk, * ) { * ... * } * ``` * - * This opens the class, makes that constructor protected, and adds a default constructor that calls - * the first specialization: + * This opens the class, makes that constructor protected, and removes the default arguments. + * + * If there's a default specialization, it adds a no-args constructor that calls it: * * ``` * @Burst @@ -54,26 +55,26 @@ import org.jetbrains.kotlin.name.Name * private val espresso: Espresso, * private val dairy: Dairy, * ) { - * constructor() : this(Espresso.Decaf, Dairy.None) + * constructor() : this(Espresso.Regular, Dairy.Milk) * ... * } * ``` * - * And it generates a new test class for each specialization. The default specialization is also - * annotated `@Ignore`. + * If there is no default specialization this makes the test class abstract. + * + * And it generates a new test class for each non-default specialization. * * ``` - * @Ignore class CoffeeTest_Decaf_None : CoffeeTest(Espresso.Decaf, Dairy.None) + * class CoffeeTest_Decaf_None : CoffeeTest(Espresso.Decaf, Dairy.None) * class CoffeeTest_Decaf_Milk : CoffeeTest(Espresso.Decaf, Dairy.Milk) * class CoffeeTest_Decaf_Oat : CoffeeTest(Espresso.Decaf, Dairy.Oat) * class CoffeeTest_Regular_None : CoffeeTest(Espresso.Regular, Dairy.None) - * ... + * class CoffeeTest_Regular_Oat : CoffeeTest(Espresso.Regular, Dairy.Oat) * ``` */ @OptIn(UnsafeDuringIrConstructionAPI::class) internal class ClassSpecializer( private val pluginContext: IrPluginContext, - private val burstApis: BurstApis, private val originalParent: IrFile, private val original: IrClass, ) { @@ -92,25 +93,39 @@ internal class ClassSpecializer( } val cartesianProduct = parameterArguments.cartesianProduct() - val defaultSpecialization = cartesianProduct.first() - // Add @Ignore and open the class - // TODO: don't double-add @Ignore - original.modality = Modality.OPEN + val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments -> + arguments.all { it.isDefault } + } + + // Make sure the constructor we're using is accessible. Drop the default arguments to prevent + // JUnit from using it. onlyConstructor.visibility = PROTECTED + for (valueParameter in onlyConstructor.valueParameters) { + valueParameter.defaultValue = null + } - // Add a no-args constructor that calls the only constructor as the default specialization. - createNoArgsConstructor( - superConstructor = onlyConstructor, - arguments = defaultSpecialization, - ) + if (indexOfDefaultSpecialization != -1) { + original.modality = Modality.OPEN + + // Add a no-args constructor that calls the only constructor as the default specialization. + createNoArgsConstructor( + superConstructor = onlyConstructor, + arguments = cartesianProduct[indexOfDefaultSpecialization], + ) + } else { + // There's no default specialization. Make the class abstract so JUnit skips it. + original.modality = Modality.ABSTRACT + } // Add a subclass for each specialization. - cartesianProduct.map { arguments -> + cartesianProduct.mapIndexed { index, arguments -> + // Don't generate code for the default specialization; we only want to run it once. + if (index == indexOfDefaultSpecialization) return@mapIndexed + createSpecialization( superConstructor = onlyConstructor, arguments = arguments, - isDefaultSpecialization = arguments == defaultSpecialization, ) } } @@ -118,7 +133,6 @@ internal class ClassSpecializer( private fun createSpecialization( superConstructor: IrConstructor, arguments: List, - isDefaultSpecialization: Boolean, ) { val specialization = original.factory.buildClass { initDefaults(original) @@ -129,10 +143,6 @@ internal class ClassSpecializer( createImplicitParameterDeclarationWithWrappedDescriptor() } - if (isDefaultSpecialization) { - specialization.annotations += burstApis.ignoreClassSymbol.asAnnotation() - } - specialization.addConstructor { initDefaults(original) }.apply { diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt index 1a83432..860fb89 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt @@ -35,34 +35,25 @@ import org.jetbrains.kotlin.name.Name * * ``` * @Test - * fun test(espresso: Espresso, dairy: Dairy) { + * fun test(espresso: Espresso = Espresso.Regular, dairy: Dairy = Dairy.Milk) { * ... * } * ``` * * This drops `@Test` from that test. * - * It generates a new function for each specialization. The default specialization is also annotated - * `@Ignore`. + * It generates a new function for each specialization. The default specialization gets the same + * name as the original test. * * ``` - * @Test @Ignore fun test_Decaf_None() { test(Espresso.Decaf, Dairy.None) } + * @Test fun test_Decaf_None() { test(Espresso.Decaf, Dairy.None) } * @Test fun test_Decaf_Milk() { test(Espresso.Decaf, Dairy.Milk) } * @Test fun test_Decaf_Oat() { test(Espresso.Decaf, Dairy.Oat) } * @Test fun test_Regular_Oat() { test(Espresso.Regular, Dairy.Oat) } - * @Test fun test_Regular_Milk() { test(Espresso.Regular, Dairy.Milk) } + * @Test fun test() { test(Espresso.Regular, Dairy.Milk) } * @Test fun test_Regular_None() { test(Espresso.Regular, Dairy.None) } * ``` * - * And it adds a new function that calls that default specialization. - * - * ``` - * @Test - * fun test() { - * test_Decaf_None() - * } - * ``` - * * This way, the default specialization is executed when you run the test in the IDE. */ @OptIn(UnsafeDuringIrConstructionAPI::class) @@ -86,11 +77,15 @@ internal class FunctionSpecializer( val cartesianProduct = parameterArguments.cartesianProduct() - val specializations = cartesianProduct.map { arguments -> + val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments -> + arguments.all { it.isDefault } + } + + val specializations = cartesianProduct.mapIndexed { index, arguments -> createSpecialization( originalDispatchReceiver = originalDispatchReceiver, arguments = arguments, - isDefaultSpecialization = arguments == cartesianProduct.first(), + isDefaultSpecialization = index == indexOfDefaultSpecialization, ) } @@ -103,12 +98,6 @@ internal class FunctionSpecializer( for (specialization in specializations) { originalParent.addDeclaration(specialization) } - originalParent.addDeclaration( - createFunctionThatCallsDefaultSpecialization( - originalDispatchReceiver = originalDispatchReceiver, - defaultSpecialization = specializations.first(), - ), - ) } private fun createSpecialization( @@ -118,7 +107,10 @@ internal class FunctionSpecializer( ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) - name = Name.identifier(name("${original.name.identifier}_", arguments)) + name = when { + isDefaultSpecialization -> original.name + else -> Name.identifier(name("${original.name.identifier}_", arguments)) + } returnType = original.returnType }.apply { addDispatchReceiver { @@ -128,9 +120,6 @@ internal class FunctionSpecializer( } result.annotations += burstApis.testClassSymbol.asAnnotation() - if (isDefaultSpecialization) { - result.annotations += burstApis.ignoreClassSymbol.asAnnotation() - } result.irFunctionBody( context = pluginContext, @@ -156,44 +145,4 @@ internal class FunctionSpecializer( return result } - - /** Creates an @Test @Ignore no-args function that calls the default specialization. */ - private fun createFunctionThatCallsDefaultSpecialization( - originalDispatchReceiver: IrValueParameter, - defaultSpecialization: IrSimpleFunction, - ): IrSimpleFunction { - val result = original.factory.buildFun { - initDefaults(original) - name = original.name - returnType = original.returnType - }.apply { - addDispatchReceiver { - initDefaults(originalDispatchReceiver) - type = originalDispatchReceiver.type - } - } - - result.annotations += burstApis.testClassSymbol.asAnnotation() - - result.irFunctionBody( - context = pluginContext, - scopeOwnerSymbol = original.symbol, - ) { - val receiverLocal = irTemporary( - value = irGet(result.dispatchReceiverParameter!!), - nameHint = "receiver", - isMutable = false, - ).apply { - origin = IrDeclarationOrigin.DEFINED - } - - +irCall( - callee = defaultSpecialization.symbol, - ).apply { - this.dispatchReceiver = irGet(receiverLocal) - } - } - - return result - } }