Skip to content

Commit

Permalink
It works
Browse files Browse the repository at this point in the history
  • Loading branch information
squarejesse committed Oct 9, 2024
1 parent 2f6634a commit 8d953c8
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package app.cash.burst.kotlin

import assertk.assertThat
import assertk.assertions.isNotNull
import assertk.assertions.containsExactly
import assertk.assertions.isFalse
import assertk.assertions.isTrue
import com.tschuchort.compiletesting.JvmCompilationResult
import com.tschuchort.compiletesting.KotlinCompilation
import com.tschuchort.compiletesting.SourceFile
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar
Expand All @@ -31,27 +34,77 @@ class BurstKotlinPluginTest {
fun happyPath() {
val result = compile(
sourceFile = SourceFile.kotlin(
"LunchTest.kt",
"CoffeeTest.kt",
"""
package app.cash.burst.testing
import app.cash.burst.Burst
import kotlin.test.Ignore
import kotlin.test.Test
@Burst
class LunchTest {
class CoffeeTest {
val log = mutableListOf<String>()
@Test
fun test(espresso: Espresso, dairy: Dairy) {
log += "running ${'$'}espresso ${'$'}dairy"
}
// Generate this
@Test
@Ignore
fun x_test() {
x_test_Decaf_None()
}
// Generate this
@Test
fun test() {
fun x_test_Decaf_None() {
test(Espresso.Decaf, Dairy.None)
}
}
enum class Espresso { Decaf, Regular, Double }
enum class Dairy { None, Milk, Oat }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.LunchTest")
assertThat(adapterClass).isNotNull()
assertThat(adapterClass.getMethod("test_2")).isNotNull()
val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.CoffeeTest")
val adapterInstance = adapterClass.constructors.single().newInstance()
val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*>

// Burst adds @Ignore to the original test.
val originalTest = adapterClass.methods.single { it.name == "test" && it.parameterCount == 2 }
assertThat(originalTest.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(originalTest.isAnnotationPresent(Ignore::class.java)).isTrue()

// Burst adds a variant for each combination of parameters.
val sampleVariant = adapterClass.getMethod("test_Decaf_None")
assertThat(sampleVariant.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(sampleVariant.isAnnotationPresent(Ignore::class.java)).isFalse()
sampleVariant.invoke(adapterInstance)
assertThat(log).containsExactly("running Decaf None")
log.clear()

// Burst adds a no-parameter function that calls each variant in sequence.
val noArgsTest = adapterClass.getMethod("test")
assertThat(noArgsTest.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isTrue()
noArgsTest.invoke(adapterInstance)
assertThat(log).containsExactly(
"running Decaf None",
"running Decaf Milk",
"running Decaf Oat",
"running Regular None",
"running Regular Milk",
"running Regular Oat",
"running Double None",
"running Double Milk",
"running Double Oat",
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,45 @@ import org.jetbrains.kotlin.ir.util.hasAnnotation
/** Looks up APIs used by the code rewriters. */
internal class BurstApis private constructor(
private val pluginContext: IrPluginContext,
private val testPackage: FqPackageName,
) {
companion object {
fun maybeCreate(pluginContext: IrPluginContext): BurstApis? {
// If we don't have @Burst, we don't have the runtime. Abort!
if (pluginContext.referenceClass(burstAnnotationClassId) == null) {
// If we don't have @Burst, we don't have the runtime. Abort!
return null
}
return BurstApis(pluginContext)

if (pluginContext.referenceClass(junitTestClassId) != null) {
return BurstApis(pluginContext, junitPackage)
}

if (pluginContext.referenceClass(kotlinTestClassId) != null) {
return BurstApis(pluginContext, kotlinTestPackage)
}

// No kotlin.test and no org.junit. No Burst for you.
return null
}
}

val burstAnnotationClassSymbol: IrClassSymbol
get() = pluginContext.referenceClass(burstAnnotationClassId)!!
val testClassSymbol: IrClassSymbol
get() = pluginContext.referenceClass(testPackage.classId("Test"))!!

val ignoreClassSymbol: IrClassSymbol
get() = pluginContext.referenceClass(testPackage.classId("Ignore"))!!
}

private val burstFqPackage = FqPackageName("app.cash.burst")
private val burstAnnotationClassId = burstFqPackage.classId("Burst")

private val kotlinTestFqPackage = FqPackageName("kotlin.test")
private val kotlinTestTestClassId = kotlinTestFqPackage.classId("Test")

private val orgJunitFqPackage = FqPackageName("org.junit")
private val orgJunitTestClassId = orgJunitFqPackage.classId("Test")
val junitPackage = FqPackageName("org.junit")
val junitTestClassId = junitPackage.classId("Test")
val kotlinTestPackage = FqPackageName("kotlin.text")
val kotlinTestClassId = kotlinTestPackage.classId("Test")

internal val IrAnnotationContainer.hasAtTest: Boolean
get() = hasAnnotation(orgJunitTestClassId) || hasAnnotation(kotlinTestTestClassId)
get() = hasAnnotation(junitTestClassId) || hasAnnotation(kotlinTestClassId)

internal val IrAnnotationContainer.hasAtBurst: Boolean
get() = hasAnnotation(burstAnnotationClassId)
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,27 @@ import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irReturn
import org.jetbrains.kotlin.ir.builders.irTemporary
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.types.starProjectedType
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.constructors
import org.jetbrains.kotlin.name.Name

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal class BurstRewriter(
private val messageCollector: MessageCollector,
private val pluginContext: IrPluginContext,
Expand All @@ -39,6 +52,11 @@ internal class BurstRewriter(
) {
/** Returns a list of additional declarations. */
fun rewrite(): List<IrDeclaration> {
val originalValueParameters = original.valueParameters
if (originalValueParameters.isEmpty()) {
return listOf()
}

val originalDispatchReceiver = original.dispatchReceiverParameter
if (originalDispatchReceiver == null) {
messageCollector.report(
Expand All @@ -49,47 +67,153 @@ internal class BurstRewriter(
return listOf()
}

val expansion = original.factory.buildFun {
val parameterArguments = mutableListOf<List<Argument>>()
for (parameter in originalValueParameters) {
val expanded = parameter.allPossibleArguments()
if (expanded == null) {
messageCollector.report(
CompilerMessageSeverity.ERROR,
"Expected an enum for @Burst test parameter",
file.locationOf(parameter),
)
return listOf()
}
parameterArguments += expanded
}

val cartesianProduct = parameterArguments.cartesianProduct()

val variants = cartesianProduct.map { variantArguments ->
createVariant(originalDispatchReceiver, variantArguments)
}

// Side-effect: add `@Ignore`
// TODO: if its' absent!
original.annotations += burstApis.ignoreClassSymbol.asAnnotation()

val result = mutableListOf<IrDeclaration>()
result += createFunctionThatCallsAllVariants(originalDispatchReceiver, variants)
result += variants
return result
}

private fun createVariant(
originalDispatchReceiver: IrValueParameter,
arguments: List<Argument>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
name = Name.identifier("${original.name.identifier}_2")
name = Name.identifier(name(arguments))
returnType = original.returnType
}.apply {
addDispatchReceiver {
initDefaults(originalDispatchReceiver)
type = originalDispatchReceiver.type
}
}

result.annotations += burstApis.testClassSymbol.asAnnotation()

result.irFunctionBody(
context = pluginContext,
scopeOwnerSymbol = original.symbol,
) {
val receiverLocal = irTemporary(
value = irGet(result.dispatchReceiverParameter!!),
nameHint = "receiver",
isMutable = false,
).apply {
origin = IrDeclarationOrigin.DEFINED
}

+irCall(
callee = original.symbol,
).apply {
this.dispatchReceiver = irGet(receiverLocal)
for ((index, argument) in arguments.withIndex()) {
putValueArgument(index, argument.get())
}
}
}

return result
}

/** Creates a function with no arguments that calls each variant. */
private fun createFunctionThatCallsAllVariants(
originalDispatchReceiver: IrValueParameter,
variants: List<IrSimpleFunction>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
name = original.name
returnType = original.returnType
// addValueParameter {
// initDefaults(original)
// name = Name.identifier("callHandler")
// type = ziplineApis.outboundCallHandler.defaultType
// }
// overriddenSymbols = listOf(ziplineApis.ziplineServiceAdapterOutboundService)
}.apply {
addDispatchReceiver {
initDefaults(originalDispatchReceiver)
type = originalDispatchReceiver.type
}
}

expansion.irFunctionBody(
result.annotations += burstApis.testClassSymbol.asAnnotation()
result.annotations += burstApis.ignoreClassSymbol.asAnnotation()

result.irFunctionBody(
context = pluginContext,
scopeOwnerSymbol = original.symbol,
) {
val receiverLocal = irTemporary(
value = irGet(expansion.dispatchReceiverParameter!!),
value = irGet(result.dispatchReceiverParameter!!),
nameHint = "receiver",
isMutable = false,
).apply {
origin = IrDeclarationOrigin.DEFINED
}

+irReturn(
irCall(
callee = original.symbol,
for (variant in variants) {
+irCall(
callee = variant.symbol,
).apply {
this.dispatchReceiver = irGet(receiverLocal)
}
}
}

// putValueArgument(0, irGet(outboundServiceFunction.valueParameters[0]))
// type = bridgedInterfaceT
},
)
return result
}

private inner class Argument(
val type: IrType,
val value: IrEnumEntry,
)

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
private fun name(arguments: List<Argument>): String {
return arguments.joinToString(
prefix = "${original.name.identifier}_",
separator = "_",
) { argument ->
argument.value.name.identifier
}
}

/** Returns an expression that looks up this argument. */
private fun Argument.get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}

/** Returns null if we can't compute all possible arguments for this parameter. */
private fun IrValueParameter.allPossibleArguments(): List<Argument>? {
val classId = type.getClass()?.classId ?: return null
val referenceClass = pluginContext.referenceClass(classId)?.owner ?: return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
return enumEntries.map { Argument(type, it) }
}

return listOf(expansion)
private fun IrClassSymbol.asAnnotation(): IrConstructorCall {
return IrConstructorCallImpl.fromSymbolOwner(
type = starProjectedType,
constructorSymbol = constructors.single(),
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package app.cash.burst.kotlin

/**
* Given an iterable lists like `[[A, B, C], [1, 2, 3]]`, return the cartesian product like
* `[[A, 1], [A, 2], [A, 3], [B, 1], [B, 2], [B, 3], [C, 1], [C, 2], [C, 3]]`.
*/
fun <T> Iterable<List<T>>.cartesianProduct(): List<List<T>> {
return fold(listOf(listOf())) { partials, list ->
partials.flatMap { partial -> list.map { element -> partial + element } }
}
}

0 comments on commit 8d953c8

Please sign in to comment.