Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial, limited implementation of burstValues() #45

Merged
merged 5 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ class BurstKotlinPluginTest {
),
)
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages)
.contains("CoffeeTest.kt:7:12 Expected an enum for @Burst test parameter")
assertThat(result.messages).contains(
"CoffeeTest.kt:7:12 " +
"@Burst parameter must be an enum or have a burstValues() default value",
)
}

@Test
Expand All @@ -129,8 +131,10 @@ class BurstKotlinPluginTest {
),
)
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages)
.contains("CoffeeTest.kt:9:12 @Burst default parameter must be an enum constant (or absent)")
assertThat(result.messages).contains(
"CoffeeTest.kt:9:12 " +
"@Burst parameter default value must be burstValues(), an enum constant, or absent",
)
}

@Test
Expand Down Expand Up @@ -258,6 +262,50 @@ class BurstKotlinPluginTest {
// Other test functions are available.
baseClass.getMethod("test_Oat")
}

@Test
fun burstValues() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import app.cash.burst.burstValues
import kotlin.test.Test

@Burst
class CoffeeTest {
val log = mutableListOf<String>()

@Test
fun test(volume: Int = burstValues(12, 16, 20)) {
log += "running ${'$'}volume"
}
}
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")
val baseInstance = baseClass.constructors.single().newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

// The test function gets its default parameter value.
baseClass.getMethod("test").invoke(baseInstance)
assertThat(baseLog).containsExactly("running 12")
baseLog.clear()

// The default test function is not generated.
assertFailsWith<NoSuchMethodException> {
baseClass.getMethod("test_12")
}

// Other test functions are available, named by the literal values.
baseClass.getMethod("test_16").invoke(baseInstance)
assertThat(baseLog).containsExactly("running 16")
baseLog.clear()
}
}

@ExperimentalCompilerApi
Expand Down
144 changes: 114 additions & 30 deletions burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,68 +17,152 @@ package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.backend.js.utils.valueArguments
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrConst
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrGetEnumValue
import org.jetbrains.kotlin.ir.expressions.IrVararg
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.deepCopyWithSymbols
import org.jetbrains.kotlin.ir.util.isEnumClass

internal class Argument(
private val original: IrElement,
private val type: IrType,
internal sealed interface Argument {
/** True if this argument matches the default parameter value. */
internal val isDefault: Boolean,
internal val value: IrEnumEntry,
) {
val isDefault: Boolean

/** A string that's safe to use in a declaration name. */
val name: String

/** Returns an expression that looks up this argument. */
fun get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}
fun expression(): IrExpression
}

private class EnumValueArgument(
private val original: IrElement,
private val type: IrType,
override val isDefault: Boolean,
private val value: IrEnumEntry,
) : Argument {
override val name = value.name.identifier

override fun expression() =
IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}

private class BurstValuesArgument(
override val isDefault: Boolean,
private val value: IrExpression,
private val index: Int,
) : Argument {
override val name: String
get() {
return when {
value is IrConst<*> -> value.value.toString()
else -> index.toString()
}
}

override fun expression() = value.deepCopyWithSymbols()
}

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
internal fun name(
prefix: String,
arguments: List<Argument>,
): String {
return arguments.joinToString(
prefix = prefix,
separator = "_",
) { argument ->
argument.value.name.identifier
): String = arguments.joinToString(prefix = prefix, separator = "_", transform = Argument::name)

/**
* Returns all arguments for [parameter].
*
* If the parameter's default value is an immediate call to `burstValues()`, this returns an
* argument for each value.
*
* If the parameter's type is an enum, this returns each enum constant for that type.
*
* @throws BurstCompilationException if we can't compute all possible arguments for this parameter.
*/
@UnsafeDuringIrConstructionAPI
internal fun IrPluginContext.allPossibleArguments(
parameter: IrValueParameter,
burstApis: BurstApis,
): List<Argument> {
val burstApisCall = parameter.defaultValue?.expression as? IrCall
if (burstApisCall?.symbol == burstApis.burstValues) {
return burstValuesArguments(parameter, burstApisCall)
}

val classId = parameter.type.getClass()?.classId ?: unexpectedParameter(parameter)
val referenceClass = referenceClass(classId)?.owner ?: unexpectedParameter(parameter)
if (referenceClass.isEnumClass) {
return enumValueArguments(referenceClass, parameter)
}

unexpectedParameter(parameter)
}

/** Returns null if we can't compute all possible arguments for this parameter. */
internal fun IrPluginContext.allPossibleArguments(
private fun burstValuesArguments(
parameter: IrValueParameter,
): List<Argument>? {
val classId = parameter.type.getClass()?.classId ?: return null
val referenceClass = referenceClass(classId)?.owner ?: return null
if (!referenceClass.isEnumClass) return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
burstApisCall: IrCall,
): List<Argument> {
return buildList {
add(
BurstValuesArgument(
isDefault = true,
value = burstApisCall.valueArguments[0] ?: unexpectedParameter(parameter),
index = 0,
),
)

val defaultValueSymbol = parameter.defaultValue?.let { defaultValue ->
val expression = defaultValue.expression
if (expression !is IrGetEnumValue) {
throw BurstCompilationException(
"@Burst default parameter must be an enum constant (or absent)",
parameter,
for ((index, element) in (burstApisCall.valueArguments[1] as IrVararg).elements.withIndex()) {
add(
BurstValuesArgument(
isDefault = false,
value = element as? IrExpression ?: unexpectedParameter(parameter),
index = index + 1,
),
)
}
expression.symbol
}
}

@UnsafeDuringIrConstructionAPI
private fun enumValueArguments(
referenceClass: IrClass,
parameter: IrValueParameter,
): List<EnumValueArgument> {
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
val defaultValueSymbol = parameter.defaultValue?.let { defaultValue ->
(defaultValue.expression as? IrGetEnumValue)?.symbol ?: unexpectedDefaultValue(parameter)
}

return enumEntries.map {
Argument(
EnumValueArgument(
original = parameter,
type = parameter.type,
isDefault = it.symbol == defaultValueSymbol,
value = it,
)
}
}

private fun unexpectedParameter(parameter: IrValueParameter): Nothing {
throw BurstCompilationException(
"@Burst parameter must be an enum or have a burstValues() default value",
parameter,
)
}

private fun unexpectedDefaultValue(parameter: IrValueParameter): Nothing {
throw BurstCompilationException(
"@Burst parameter default value must be burstValues(), an enum constant, or absent",
parameter,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@ package app.cash.burst.kotlin
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.declarations.IrAnnotationContainer
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
import org.jetbrains.kotlin.ir.util.hasAnnotation

/** Looks up APIs used by the code rewriters. */
internal class BurstApis private constructor(
private val pluginContext: IrPluginContext,
private val testPackage: FqPackageName,
pluginContext: IrPluginContext,
testPackage: FqPackageName,
) {
companion object {
fun maybeCreate(pluginContext: IrPluginContext): BurstApis? {
// If we don't have @Burst, we don't have the runtime. Abort!
if (pluginContext.referenceClass(burstAnnotationClassId) == null) {
if (pluginContext.referenceClass(burstAnnotationId) == null) {
return null
}

Expand All @@ -45,20 +46,21 @@ internal class BurstApis private constructor(
}
}

val testClassSymbol: IrClassSymbol
get() = pluginContext.referenceClass(testPackage.classId("Test"))!!
val testClassSymbol: IrClassSymbol = pluginContext.referenceClass(testPackage.classId("Test"))!!
val burstValues: IrFunctionSymbol = pluginContext.referenceFunctions(burstValuesId).single()
}

private val burstFqPackage = FqPackageName("app.cash.burst")
private val burstAnnotationClassId = burstFqPackage.classId("Burst")
private val burstAnnotationId = burstFqPackage.classId("Burst")
private val burstValuesId = burstFqPackage.callableId("burstValues")

val junitPackage = FqPackageName("org.junit")
val junitTestClassId = junitPackage.classId("Test")
val kotlinTestPackage = FqPackageName("kotlin.test")
val kotlinTestClassId = kotlinTestPackage.classId("Test")
private val junitPackage = FqPackageName("org.junit")
private val junitTestClassId = junitPackage.classId("Test")
private val kotlinTestPackage = FqPackageName("kotlin.test")
private val kotlinTestClassId = kotlinTestPackage.classId("Test")

internal val IrAnnotationContainer.hasAtTest: Boolean
get() = hasAnnotation(junitTestClassId) || hasAnnotation(kotlinTestClassId)

internal val IrAnnotationContainer.hasAtBurst: Boolean
get() = hasAnnotation(burstAnnotationClassId)
get() = hasAnnotation(burstAnnotationId)
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class BurstIrGenerationExtension(
if (classHasAtBurst) {
ClassSpecializer(
pluginContext = pluginContext,
burstApis = burstApis,
originalParent = currentFile,
original = classDeclaration,
).generateSpecializations()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ import org.jetbrains.kotlin.name.Name
@OptIn(UnsafeDuringIrConstructionAPI::class)
internal class ClassSpecializer(
private val pluginContext: IrPluginContext,
private val burstApis: BurstApis,
private val originalParent: IrFile,
private val original: IrClass,
) {
Expand All @@ -88,8 +89,7 @@ internal class ClassSpecializer(
if (valueParameters.isEmpty()) return // Nothing to do.

val parameterArguments = valueParameters.map { parameter ->
pluginContext.allPossibleArguments(parameter)
?: throw BurstCompilationException("Expected an enum for @Burst test parameter", parameter)
pluginContext.allPossibleArguments(parameter, burstApis)
}

val cartesianProduct = parameterArguments.cartesianProduct()
Expand Down Expand Up @@ -153,7 +153,7 @@ internal class ClassSpecializer(
valueArgumentsCount = arguments.size,
) {
for ((index, argument) in arguments.withIndex()) {
putValueArgument(index, argument.get())
putValueArgument(index, argument.expression())
}
}
statements += irInstanceInitializerCall(
Expand Down Expand Up @@ -182,7 +182,7 @@ internal class ClassSpecializer(
valueArgumentsCount = arguments.size,
) {
for ((index, argument) in arguments.withIndex()) {
putValueArgument(index, argument.get())
putValueArgument(index, argument.expression())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ internal class FunctionSpecializer(
?: throw BurstCompilationException("Unexpected dispatch receiver", original)

val parameterArguments = valueParameters.map { parameter ->
pluginContext.allPossibleArguments(parameter)
?: throw BurstCompilationException("Expected an enum for @Burst test parameter", parameter)
pluginContext.allPossibleArguments(parameter, burstApis)
}

val cartesianProduct = parameterArguments.cartesianProduct()
Expand Down Expand Up @@ -138,7 +137,7 @@ internal class FunctionSpecializer(
).apply {
this.dispatchReceiver = irGet(receiverLocal)
for ((index, argument) in arguments.withIndex()) {
putValueArgument(index, argument.get())
putValueArgument(index, argument.expression())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package app.cash.burst.kotlin

import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
Expand All @@ -11,3 +12,5 @@ value class FqPackageName(val fqName: FqName)
fun FqPackageName(name: String) = FqPackageName(FqName(name))

fun FqPackageName.classId(name: String) = ClassId(fqName, Name.identifier(name))

fun FqPackageName.callableId(name: String) = CallableId(fqName, Name.identifier(name))
4 changes: 4 additions & 0 deletions burst/api/burst.api
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
public abstract interface annotation class app/cash/burst/Burst : java/lang/annotation/Annotation {
}

public final class app/cash/burst/BurstKt {
public static final fun burstValues (Ljava/lang/Object;[Ljava/lang/Object;)Ljava/lang/Object;
}

Loading