From 8d953c848c2ee42bbbabe852008ee3e74ce70d50 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Wed, 9 Oct 2024 11:19:42 -0700 Subject: [PATCH] It works --- .../burst/kotlin/BurstKotlinPluginTest.kt | 67 +++++++- .../kotlin/app/cash/burst/kotlin/BurstApis.kt | 33 ++-- .../app/cash/burst/kotlin/BurstRewriter.kt | 162 ++++++++++++++++-- .../app/cash/burst/kotlin/cartesianProduct.kt | 27 +++ 4 files changed, 253 insertions(+), 36 deletions(-) create mode 100644 burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/cartesianProduct.kt diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index 6e310f0..5f2ce68 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -16,10 +16,13 @@ package app.cash.burst.kotlin import assertk.assertThat -import assertk.assertions.isNotNull +import assertk.assertions.containsExactly +import assertk.assertions.isFalse +import assertk.assertions.isTrue import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation import com.tschuchort.compiletesting.SourceFile +import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar @@ -31,27 +34,77 @@ class BurstKotlinPluginTest { fun happyPath() { val result = compile( sourceFile = SourceFile.kotlin( - "LunchTest.kt", + "CoffeeTest.kt", """ package app.cash.burst.testing import app.cash.burst.Burst + import kotlin.test.Ignore import kotlin.test.Test @Burst - class LunchTest { + class CoffeeTest { + val log = mutableListOf() + + @Test + fun test(espresso: Espresso, dairy: Dairy) { + log += "running ${'$'}espresso ${'$'}dairy" + } + + // Generate this + @Test + @Ignore + fun x_test() { + x_test_Decaf_None() + } + + // Generate this @Test - fun test() { + fun x_test_Decaf_None() { + test(Espresso.Decaf, Dairy.None) } } + + enum class Espresso { Decaf, Regular, Double } + enum class Dairy { None, Milk, Oat } """, ), ) assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) - val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.LunchTest") - assertThat(adapterClass).isNotNull() - assertThat(adapterClass.getMethod("test_2")).isNotNull() + val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.CoffeeTest") + val adapterInstance = adapterClass.constructors.single().newInstance() + val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*> + + // Burst adds @Ignore to the original test. + val originalTest = adapterClass.methods.single { it.name == "test" && it.parameterCount == 2 } + assertThat(originalTest.isAnnotationPresent(Test::class.java)).isTrue() + assertThat(originalTest.isAnnotationPresent(Ignore::class.java)).isTrue() + + // Burst adds a variant for each combination of parameters. + val sampleVariant = adapterClass.getMethod("test_Decaf_None") + assertThat(sampleVariant.isAnnotationPresent(Test::class.java)).isTrue() + assertThat(sampleVariant.isAnnotationPresent(Ignore::class.java)).isFalse() + sampleVariant.invoke(adapterInstance) + assertThat(log).containsExactly("running Decaf None") + log.clear() + + // Burst adds a no-parameter function that calls each variant in sequence. + val noArgsTest = adapterClass.getMethod("test") + assertThat(noArgsTest.isAnnotationPresent(Test::class.java)).isTrue() + assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isTrue() + noArgsTest.invoke(adapterInstance) + assertThat(log).containsExactly( + "running Decaf None", + "running Decaf Milk", + "running Decaf Oat", + "running Regular None", + "running Regular Milk", + "running Regular Oat", + "running Double None", + "running Double Milk", + "running Double Oat", + ) } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt index 30d019e..e629f10 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstApis.kt @@ -23,32 +23,45 @@ import org.jetbrains.kotlin.ir.util.hasAnnotation /** Looks up APIs used by the code rewriters. */ internal class BurstApis private constructor( private val pluginContext: IrPluginContext, + private val testPackage: FqPackageName, ) { companion object { fun maybeCreate(pluginContext: IrPluginContext): BurstApis? { + // If we don't have @Burst, we don't have the runtime. Abort! if (pluginContext.referenceClass(burstAnnotationClassId) == null) { - // If we don't have @Burst, we don't have the runtime. Abort! return null } - return BurstApis(pluginContext) + + if (pluginContext.referenceClass(junitTestClassId) != null) { + return BurstApis(pluginContext, junitPackage) + } + + if (pluginContext.referenceClass(kotlinTestClassId) != null) { + return BurstApis(pluginContext, kotlinTestPackage) + } + + // No kotlin.test and no org.junit. No Burst for you. + return null } } - val burstAnnotationClassSymbol: IrClassSymbol - get() = pluginContext.referenceClass(burstAnnotationClassId)!! + val testClassSymbol: IrClassSymbol + get() = pluginContext.referenceClass(testPackage.classId("Test"))!! + + val ignoreClassSymbol: IrClassSymbol + get() = pluginContext.referenceClass(testPackage.classId("Ignore"))!! } private val burstFqPackage = FqPackageName("app.cash.burst") private val burstAnnotationClassId = burstFqPackage.classId("Burst") -private val kotlinTestFqPackage = FqPackageName("kotlin.test") -private val kotlinTestTestClassId = kotlinTestFqPackage.classId("Test") - -private val orgJunitFqPackage = FqPackageName("org.junit") -private val orgJunitTestClassId = orgJunitFqPackage.classId("Test") +val junitPackage = FqPackageName("org.junit") +val junitTestClassId = junitPackage.classId("Test") +val kotlinTestPackage = FqPackageName("kotlin.text") +val kotlinTestClassId = kotlinTestPackage.classId("Test") internal val IrAnnotationContainer.hasAtTest: Boolean - get() = hasAnnotation(orgJunitTestClassId) || hasAnnotation(kotlinTestTestClassId) + get() = hasAnnotation(junitTestClassId) || hasAnnotation(kotlinTestClassId) internal val IrAnnotationContainer.hasAtBurst: Boolean get() = hasAnnotation(burstAnnotationClassId) diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt index c9a66de..2f29a9b 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt @@ -22,14 +22,27 @@ import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.ir.builders.declarations.buildFun import org.jetbrains.kotlin.ir.builders.irCall import org.jetbrains.kotlin.ir.builders.irGet -import org.jetbrains.kotlin.ir.builders.irReturn import org.jetbrains.kotlin.ir.builders.irTemporary import org.jetbrains.kotlin.ir.declarations.IrDeclaration import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin +import org.jetbrains.kotlin.ir.declarations.IrEnumEntry import org.jetbrains.kotlin.ir.declarations.IrFile import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction +import org.jetbrains.kotlin.ir.declarations.IrValueParameter +import org.jetbrains.kotlin.ir.expressions.IrConstructorCall +import org.jetbrains.kotlin.ir.expressions.IrExpression +import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl +import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl +import org.jetbrains.kotlin.ir.symbols.IrClassSymbol +import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI +import org.jetbrains.kotlin.ir.types.IrType +import org.jetbrains.kotlin.ir.types.getClass +import org.jetbrains.kotlin.ir.types.starProjectedType +import org.jetbrains.kotlin.ir.util.classId +import org.jetbrains.kotlin.ir.util.constructors import org.jetbrains.kotlin.name.Name +@OptIn(UnsafeDuringIrConstructionAPI::class) internal class BurstRewriter( private val messageCollector: MessageCollector, private val pluginContext: IrPluginContext, @@ -39,6 +52,11 @@ internal class BurstRewriter( ) { /** Returns a list of additional declarations. */ fun rewrite(): List { + val originalValueParameters = original.valueParameters + if (originalValueParameters.isEmpty()) { + return listOf() + } + val originalDispatchReceiver = original.dispatchReceiverParameter if (originalDispatchReceiver == null) { messageCollector.report( @@ -49,47 +67,153 @@ internal class BurstRewriter( return listOf() } - val expansion = original.factory.buildFun { + val parameterArguments = mutableListOf>() + for (parameter in originalValueParameters) { + val expanded = parameter.allPossibleArguments() + if (expanded == null) { + messageCollector.report( + CompilerMessageSeverity.ERROR, + "Expected an enum for @Burst test parameter", + file.locationOf(parameter), + ) + return listOf() + } + parameterArguments += expanded + } + + val cartesianProduct = parameterArguments.cartesianProduct() + + val variants = cartesianProduct.map { variantArguments -> + createVariant(originalDispatchReceiver, variantArguments) + } + + // Side-effect: add `@Ignore` + // TODO: if its' absent! + original.annotations += burstApis.ignoreClassSymbol.asAnnotation() + + val result = mutableListOf() + result += createFunctionThatCallsAllVariants(originalDispatchReceiver, variants) + result += variants + return result + } + + private fun createVariant( + originalDispatchReceiver: IrValueParameter, + arguments: List, + ): IrSimpleFunction { + val result = original.factory.buildFun { initDefaults(original) - name = Name.identifier("${original.name.identifier}_2") + name = Name.identifier(name(arguments)) + returnType = original.returnType }.apply { addDispatchReceiver { initDefaults(originalDispatchReceiver) type = originalDispatchReceiver.type } + } + + result.annotations += burstApis.testClassSymbol.asAnnotation() + + result.irFunctionBody( + context = pluginContext, + scopeOwnerSymbol = original.symbol, + ) { + val receiverLocal = irTemporary( + value = irGet(result.dispatchReceiverParameter!!), + nameHint = "receiver", + isMutable = false, + ).apply { + origin = IrDeclarationOrigin.DEFINED + } + + +irCall( + callee = original.symbol, + ).apply { + this.dispatchReceiver = irGet(receiverLocal) + for ((index, argument) in arguments.withIndex()) { + putValueArgument(index, argument.get()) + } + } + } + + return result + } + + /** Creates a function with no arguments that calls each variant. */ + private fun createFunctionThatCallsAllVariants( + originalDispatchReceiver: IrValueParameter, + variants: List, + ): IrSimpleFunction { + val result = original.factory.buildFun { + initDefaults(original) + name = original.name returnType = original.returnType -// addValueParameter { -// initDefaults(original) -// name = Name.identifier("callHandler") -// type = ziplineApis.outboundCallHandler.defaultType -// } -// overriddenSymbols = listOf(ziplineApis.ziplineServiceAdapterOutboundService) + }.apply { + addDispatchReceiver { + initDefaults(originalDispatchReceiver) + type = originalDispatchReceiver.type + } } - expansion.irFunctionBody( + result.annotations += burstApis.testClassSymbol.asAnnotation() + result.annotations += burstApis.ignoreClassSymbol.asAnnotation() + + result.irFunctionBody( context = pluginContext, scopeOwnerSymbol = original.symbol, ) { val receiverLocal = irTemporary( - value = irGet(expansion.dispatchReceiverParameter!!), + value = irGet(result.dispatchReceiverParameter!!), nameHint = "receiver", isMutable = false, ).apply { origin = IrDeclarationOrigin.DEFINED } - +irReturn( - irCall( - callee = original.symbol, + for (variant in variants) { + +irCall( + callee = variant.symbol, ).apply { this.dispatchReceiver = irGet(receiverLocal) + } + } + } -// putValueArgument(0, irGet(outboundServiceFunction.valueParameters[0])) -// type = bridgedInterfaceT - }, - ) + return result + } + + private inner class Argument( + val type: IrType, + val value: IrEnumEntry, + ) + + /** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ + private fun name(arguments: List): String { + return arguments.joinToString( + prefix = "${original.name.identifier}_", + separator = "_", + ) { argument -> + argument.value.name.identifier } + } + + /** Returns an expression that looks up this argument. */ + private fun Argument.get(): IrExpression { + return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol) + } + + /** Returns null if we can't compute all possible arguments for this parameter. */ + private fun IrValueParameter.allPossibleArguments(): List? { + val classId = type.getClass()?.classId ?: return null + val referenceClass = pluginContext.referenceClass(classId)?.owner ?: return null + val enumEntries = referenceClass.declarations.filterIsInstance() + return enumEntries.map { Argument(type, it) } + } - return listOf(expansion) + private fun IrClassSymbol.asAnnotation(): IrConstructorCall { + return IrConstructorCallImpl.fromSymbolOwner( + type = starProjectedType, + constructorSymbol = constructors.single(), + ) } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/cartesianProduct.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/cartesianProduct.kt new file mode 100644 index 0000000..043215d --- /dev/null +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/cartesianProduct.kt @@ -0,0 +1,27 @@ +/* + * Copyright 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package app.cash.burst.kotlin + +/** + * Given an iterable lists like `[[A, B, C], [1, 2, 3]]`, return the cartesian product like + * `[[A, 1], [A, 2], [A, 3], [B, 1], [B, 2], [B, 3], [C, 1], [C, 2], [C, 3]]`. + */ +fun Iterable>.cartesianProduct(): List> { + return fold(listOf(listOf())) { partials, list -> + partials.flatMap { partial -> list.map { element -> partial + element } } + } +}