From 3bd30d42c95998898f89e810153e37eb5a1b0681 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Tue, 15 Oct 2024 18:56:04 -0400 Subject: [PATCH] Prioritize the default specialization See https://github.com/cashapp/burst/issues/18 This PR doesn't choose the default specialization by inspecting parameter values. Instead it just takes the first specialization arbitrarily. --- .../burst/gradle/BurstGradlePluginTest.kt | 53 +++++++++++++------ .../burst/kotlin/BurstKotlinPluginTest.kt | 49 +++++++++-------- .../app/cash/burst/kotlin/ClassSpecializer.kt | 34 ++++++++---- .../cash/burst/kotlin/FunctionSpecializer.kt | 46 +++++++++------- 4 files changed, 112 insertions(+), 70 deletions(-) diff --git a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt index b56dcff..32d0580 100644 --- a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt +++ b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt @@ -19,6 +19,7 @@ package app.cash.burst.gradle import assertk.assertThat import assertk.assertions.contains import assertk.assertions.containsExactlyInAnyOrder +import assertk.assertions.isEmpty import assertk.assertions.isEqualTo import assertk.assertions.isFalse import assertk.assertions.isTrue @@ -44,21 +45,24 @@ class BurstGradlePluginTest { assertThat(testSuite.testCases.map { it.name }).containsExactlyInAnyOrder( "test[jvm]", - "test_Decaf_Oat[jvm]", - "test_Regular_Milk[jvm]", - "test_Regular_None[jvm]", "test_Decaf_Milk[jvm]", "test_Decaf_None[jvm]", + "test_Decaf_Oat[jvm]", "test_Double_Milk[jvm]", "test_Double_None[jvm]", - "test_Regular_Oat[jvm]", "test_Double_Oat[jvm]", + "test_Regular_Milk[jvm]", + "test_Regular_None[jvm]", + "test_Regular_Oat[jvm]", ) val originalTest = testSuite.testCases.single { it.name == "test[jvm]" } - assertThat(originalTest.skipped).isTrue() + assertThat(originalTest.skipped).isFalse() + + val defaultSpecialization = testSuite.testCases.single { it.name == "test_Decaf_None[jvm]" } + assertThat(defaultSpecialization.skipped).isTrue() - val sampleSpecialization = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" } + val sampleSpecialization = testSuite.testCases.single { it.name == "test_Regular_Milk[jvm]" } assertThat(sampleSpecialization.skipped).isFalse() } @@ -78,22 +82,25 @@ class BurstGradlePluginTest { assertThat(testSuite.testCases.map { it.name }).containsExactlyInAnyOrder( "test", - "test_Decaf_Oat", - "test_Regular_Milk", - "test_Regular_None", "test_Decaf_Milk", "test_Decaf_None", + "test_Decaf_Oat", "test_Double_Milk", "test_Double_None", - "test_Regular_Oat", "test_Double_Oat", + "test_Regular_Milk", + "test_Regular_None", + "test_Regular_Oat", ) val originalTest = testSuite.testCases.single { it.name == "test" } - assertThat(originalTest.skipped).isTrue() + assertThat(originalTest.skipped).isFalse() + + val defaultSpecialization = testSuite.testCases.single { it.name == "test_Decaf_None" } + assertThat(defaultSpecialization.skipped).isTrue() - val sampleVariant = testSuite.testCases.single { it.name == "test_Decaf_Oat" } - assertThat(sampleVariant.skipped).isFalse() + val sampleSpecialization = testSuite.testCases.single { it.name == "test_Regular_Milk" } + assertThat(sampleSpecialization.skipped).isFalse() } @Test @@ -110,16 +117,28 @@ class BurstGradlePluginTest { val coffeeTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest.xml")) val coffeeTestTest = coffeeTest.testCases.single() assertThat(coffeeTestTest.name).isEqualTo("test") - assertThat(coffeeTestTest.skipped).isTrue() + assertThat(coffeeTest.systemOut).isEqualTo( + """ + |set up Decaf None + |running Decaf None + | + """.trimMargin(), + ) - val sampleTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Decaf_None.xml")) + val defaultTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Decaf_None.xml")) + val defaultTestTest = defaultTest.testCases.single() + assertThat(defaultTestTest.name).isEqualTo("test") + assertThat(defaultTestTest.skipped).isTrue() + assertThat(defaultTest.systemOut).isEmpty() + + val sampleTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Regular_Milk.xml")) val sampleTestTest = sampleTest.testCases.single() assertThat(sampleTestTest.name).isEqualTo("test") assertThat(sampleTestTest.skipped).isFalse() assertThat(sampleTest.systemOut).isEqualTo( """ - |set up Decaf None - |running Decaf None + |set up Regular Milk + |running Regular Milk | """.trimMargin(), ) diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index 6f1f14a..fc687f2 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -33,7 +33,7 @@ import org.jetbrains.kotlin.compiler.plugin.ExperimentalCompilerApi @OptIn(ExperimentalCompilerApi::class) class BurstKotlinPluginTest { @Test - fun happyPath() { + fun functionParameters() { val result = compile( sourceFile = SourceFile.kotlin( "CoffeeTest.kt", @@ -71,29 +71,27 @@ class BurstKotlinPluginTest { assertThat(originalTest.isAnnotationPresent(Test::class.java)).isFalse() // Burst adds a specialization for each combination of parameters. - val sampleFunction = testClass.getMethod("test_Decaf_None") - assertThat(sampleFunction.isAnnotationPresent(Test::class.java)).isTrue() - assertThat(sampleFunction.isAnnotationPresent(Ignore::class.java)).isFalse() - sampleFunction.invoke(adapterInstance) + val sampleSpecialization = testClass.getMethod("test_Regular_Milk") + assertThat(sampleSpecialization.isAnnotationPresent(Test::class.java)).isTrue() + assertThat(sampleSpecialization.isAnnotationPresent(Ignore::class.java)).isFalse() + sampleSpecialization.invoke(adapterInstance) + assertThat(log).containsExactly("running Regular Milk") + log.clear() + + // The first specialization is also annotated `@Ignore`. + val firstSpecialization = testClass.getMethod("test_Decaf_None") + assertThat(firstSpecialization.isAnnotationPresent(Test::class.java)).isTrue() + assertThat(firstSpecialization.isAnnotationPresent(Ignore::class.java)).isTrue() + firstSpecialization.invoke(adapterInstance) assertThat(log).containsExactly("running Decaf None") log.clear() - // Burst adds a no-parameter function that calls each specialization in sequence. + // Burst adds a no-parameter function that calls the first specialization. val noArgsTest = testClass.getMethod("test") assertThat(noArgsTest.isAnnotationPresent(Test::class.java)).isTrue() - assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isTrue() + assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isFalse() noArgsTest.invoke(adapterInstance) - assertThat(log).containsExactly( - "running Decaf None", - "running Decaf Milk", - "running Decaf Oat", - "running Regular None", - "running Regular Milk", - "running Regular Oat", - "running Double None", - "running Double Milk", - "running Double Oat", - ) + assertThat(log).containsExactly("running Decaf None") } @Test @@ -156,9 +154,9 @@ class BurstKotlinPluginTest { val baseClass = result.classLoader.loadClass("CoffeeTest") - // Burst opens the class because it needs to subclass it. And it marks the entire class @Ignore. + // Burst opens the class because it needs to subclass it. assertThat(Modifier.isFinal(baseClass.modifiers)).isFalse() - assertThat(baseClass.isAnnotationPresent(Ignore::class.java)).isTrue() + assertThat(baseClass.isAnnotationPresent(Ignore::class.java)).isFalse() // Burst adds a no-args constructor that binds the first enum value. val baseConstructor = baseClass.constructors.single { it.parameterCount == 0 } @@ -176,7 +174,8 @@ class BurstKotlinPluginTest { baseLog.clear() // It generates a subclass for each specialization. - val sampleClass = result.classLoader.loadClass("CoffeeTest_Regular_Oat") + val sampleClass = result.classLoader.loadClass("CoffeeTest_Regular_Milk") + assertThat(sampleClass.isAnnotationPresent(Ignore::class.java)).isFalse() val sampleConstructor = sampleClass.getConstructor() val sampleInstance = sampleConstructor.newInstance() val sampleLog = sampleClass.getMethod("getLog") @@ -184,10 +183,14 @@ class BurstKotlinPluginTest { sampleClass.getMethod("setUp").invoke(sampleInstance) sampleClass.getMethod("test").invoke(sampleInstance) assertThat(sampleLog).containsExactly( - "set up Regular Oat", - "running Regular Oat", + "set up Regular Milk", + "running Regular Milk", ) sampleLog.clear() + + // The default specialization is annotated `@Ignore`. + val defaultClass = result.classLoader.loadClass("CoffeeTest_Decaf_None") + assertThat(defaultClass.isAnnotationPresent(Ignore::class.java)).isTrue() } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt index 3296025..23d1c31 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt @@ -16,6 +16,7 @@ package app.cash.burst.kotlin import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext +import org.jetbrains.kotlin.descriptors.DescriptorVisibilities import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.ir.builders.declarations.addConstructor import org.jetbrains.kotlin.ir.builders.declarations.buildClass @@ -43,13 +44,12 @@ import org.jetbrains.kotlin.name.Name * } * ``` * - * This opens the class, adds `@Ignore`, and adds a default constructor that calls the first - * specialization: + * This opens the class, makes that constructor protected, and adds a default constructor that calls + * the first specialization: * * ``` * @Burst - * @Ignore - * open class CoffeeTest( + * open class CoffeeTest protected constructor( * private val espresso: Espresso, * private val dairy: Dairy, * ) { @@ -58,10 +58,11 @@ import org.jetbrains.kotlin.name.Name * } * ``` * - * And it generates a new test class for each specialization: + * And it generates a new test class for each specialization. The default specialization is also + * annotated `@Ignore`. * * ``` - * class CoffeeTest_Decaf_None : CoffeeTest(Espresso.Decaf, Dairy.None) + * @Ignore class CoffeeTest_Decaf_None : CoffeeTest(Espresso.Decaf, Dairy.None) * class CoffeeTest_Decaf_Milk : CoffeeTest(Espresso.Decaf, Dairy.Milk) * class CoffeeTest_Decaf_Oat : CoffeeTest(Espresso.Decaf, Dairy.Oat) * class CoffeeTest_Regular_None : CoffeeTest(Espresso.Regular, Dairy.None) @@ -90,24 +91,33 @@ internal class ClassSpecializer( } val cartesianProduct = parameterArguments.cartesianProduct() + val defaultSpecialization = cartesianProduct.first() // Add @Ignore and open the class // TODO: don't double-add @Ignore - original.annotations += burstApis.ignoreClassSymbol.asAnnotation() original.modality = Modality.OPEN + onlyConstructor.visibility = DescriptorVisibilities.PROTECTED - // Add a no-args constructor that calls the only constructor. - createNoArgsConstructor(onlyConstructor, cartesianProduct.first()) + // Add a no-args constructor that calls the only constructor as the default specialization. + createNoArgsConstructor( + superConstructor = onlyConstructor, + arguments = defaultSpecialization, + ) // Add a subclass for each specialization. cartesianProduct.map { arguments -> - createSpecialization(onlyConstructor, arguments) + createSpecialization( + superConstructor = onlyConstructor, + arguments = arguments, + isDefaultSpecialization = arguments == defaultSpecialization, + ) } } private fun createSpecialization( superConstructor: IrConstructor, arguments: List, + isDefaultSpecialization: Boolean, ) { val specialization = original.factory.buildClass { initDefaults(original) @@ -117,6 +127,10 @@ internal class ClassSpecializer( createImplicitParameterDeclarationWithWrappedDescriptor() } + if (isDefaultSpecialization) { + specialization.annotations += burstApis.ignoreClassSymbol.asAnnotation() + } + specialization.addConstructor { initDefaults(original) }.apply { diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt index a40a027..cc81d31 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt @@ -42,10 +42,11 @@ import org.jetbrains.kotlin.name.Name * * This drops `@Test` from that test. * - * It generates a new function for each specialization: + * It generates a new function for each specialization. The default specialization is also annotated + * `@Ignore`. * * ``` - * @Test fun test_Decaf_None() { test(Espresso.Decaf, Dairy.None) } + * @Test @Ignore fun test_Decaf_None() { test(Espresso.Decaf, Dairy.None) } * @Test fun test_Decaf_Milk() { test(Espresso.Decaf, Dairy.Milk) } * @Test fun test_Decaf_Oat() { test(Espresso.Decaf, Dairy.Oat) } * @Test fun test_Regular_Oat() { test(Espresso.Regular, Dairy.Oat) } @@ -53,19 +54,16 @@ import org.jetbrains.kotlin.name.Name * @Test fun test_Regular_None() { test(Espresso.Regular, Dairy.None) } * ``` * - * And it adds a new function that calls each specialization. + * And it adds a new function that calls that default specialization. * * ``` * @Test - * @Ignore * fun test() { * test_Decaf_None() - * test_Decaf_Milk() - * test_Decaf_Oat() - * test_Regular_Oat() - * ... * } * ``` + * + * This way, the default specialization is executed when you run the test in the IDE. */ @OptIn(UnsafeDuringIrConstructionAPI::class) internal class FunctionSpecializer( @@ -89,7 +87,11 @@ internal class FunctionSpecializer( val cartesianProduct = parameterArguments.cartesianProduct() val specializations = cartesianProduct.map { arguments -> - createSpecialization(originalDispatchReceiver, arguments) + createSpecialization( + originalDispatchReceiver = originalDispatchReceiver, + arguments = arguments, + isDefaultSpecialization = arguments == cartesianProduct.first(), + ) } // Drop `@Test` from the original's annotations. @@ -102,13 +104,17 @@ internal class FunctionSpecializer( originalParent.addDeclaration(specialization) } originalParent.addDeclaration( - createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations), + createFunctionThatCallsDefaultSpecialization( + originalDispatchReceiver = originalDispatchReceiver, + defaultSpecialization = specializations.first() + ), ) } private fun createSpecialization( originalDispatchReceiver: IrValueParameter, arguments: List, + isDefaultSpecialization: Boolean, ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) @@ -122,6 +128,9 @@ internal class FunctionSpecializer( } result.annotations += burstApis.testClassSymbol.asAnnotation() + if (isDefaultSpecialization) { + result.annotations += burstApis.ignoreClassSymbol.asAnnotation() + } result.irFunctionBody( context = pluginContext, @@ -148,10 +157,10 @@ internal class FunctionSpecializer( return result } - /** Creates an @Test @Ignore no-args function that calls each specialization. */ - private fun createFunctionThatCallsAllSpecializations( + /** Creates an @Test @Ignore no-args function that calls the default specialization. */ + private fun createFunctionThatCallsDefaultSpecialization( originalDispatchReceiver: IrValueParameter, - specializations: List, + defaultSpecialization: IrSimpleFunction, ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) @@ -165,7 +174,6 @@ internal class FunctionSpecializer( } result.annotations += burstApis.testClassSymbol.asAnnotation() - result.annotations += burstApis.ignoreClassSymbol.asAnnotation() result.irFunctionBody( context = pluginContext, @@ -179,12 +187,10 @@ internal class FunctionSpecializer( origin = IrDeclarationOrigin.DEFINED } - for (specialization in specializations) { - +irCall( - callee = specialization.symbol, - ).apply { - this.dispatchReceiver = irGet(receiverLocal) - } + +irCall( + callee = defaultSpecialization.symbol, + ).apply { + this.dispatchReceiver = irGet(receiverLocal) } }