Skip to content

Commit

Permalink
Rename some internal APIs in preparation for class specialization
Browse files Browse the repository at this point in the history
Rename 'variant' to 'specialization'.

Rename 'BurstRewriter' to 'FunctionSpecializer'.

Start adding new declarations in-place, rather than using
a mix of side-effects and return values.

Use exceptions to report syntax errors.
  • Loading branch information
squarejesse committed Oct 10, 2024
1 parent 2719db0 commit b0958ad
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 369 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class BurstGradlePluginTest {
val originalTest = testSuite.testCases.single { it.name == "test[jvm]" }
assertThat(originalTest.skipped).isTrue()

val sampleVariant = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" }
assertThat(sampleVariant.skipped).isFalse()
val sampleSpecialization = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" }
assertThat(sampleSpecialization.skipped).isFalse()
}

private fun createRunner(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package app.cash.burst.kotlin

import assertk.assertThat
import assertk.assertions.contains
import assertk.assertions.containsExactly
import assertk.assertions.isFalse
import assertk.assertions.isTrue
Expand All @@ -36,8 +37,6 @@ class BurstKotlinPluginTest {
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
package app.cash.burst.testing
import app.cash.burst.Burst
import kotlin.test.Test
Expand All @@ -58,7 +57,7 @@ class BurstKotlinPluginTest {
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.CoffeeTest")
val adapterClass = result.classLoader.loadClass("CoffeeTest")
val adapterInstance = adapterClass.constructors.single().newInstance()
val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*>

Expand Down Expand Up @@ -91,6 +90,29 @@ class BurstKotlinPluginTest {
"running Double Oat",
)
}

@Test
fun unexpectedArgumentType() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
@Test
fun test(espresso: String) {
}
}
""",
),
)
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages)
.contains("CoffeeTest.kt:7:12 Expected an enum for @Burst test parameter")
}
}

@ExperimentalCompilerApi
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (C) 2024 Cash App
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.isEnumClass

internal class Argument(
private val original: IrElement,
private val type: IrType,
internal val value: IrEnumEntry,
) {
/** Returns an expression that looks up this argument. */
fun get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}
}

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
internal fun name(
prefix: String,
arguments: List<Argument>
): String {
return arguments.joinToString(
prefix = prefix,
separator = "_",
) { argument ->
argument.value.name.identifier
}
}

/** Returns null if we can't compute all possible arguments for this parameter. */
internal fun IrPluginContext.allPossibleArguments(
parameter: IrValueParameter
): List<Argument>? {
val classId = parameter.type.getClass()?.classId ?: return null
val referenceClass = referenceClass(classId)?.owner ?: return null
if (!referenceClass.isEnumClass) return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
return enumEntries.map { Argument(parameter, parameter.type, it) }
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.functions
Expand All @@ -39,28 +38,31 @@ class BurstIrGenerationExtension(
val classDeclaration = super.visitClassNew(declaration) as IrClass
val classHasAtBurst = classDeclaration.hasAtBurst

val addedDeclarations = mutableListOf<IrDeclaration>()
// Return early if there's no @Burst anywhere.
if (!classHasAtBurst && classDeclaration.functions.none { it.hasAtBurst }) {
return classDeclaration
}

// Snapshot the original functions because the loop mutates them.
val originalFunctions = classDeclaration.functions.toList()

for (function in classDeclaration.functions) {
for (function in originalFunctions) {
if (!function.hasAtTest) continue
if (!classHasAtBurst && !function.hasAtBurst) continue

if (classHasAtBurst || function.hasAtBurst) {
val rewriter = BurstRewriter(
messageCollector = messageCollector,
try {
val specializer = FunctionSpecializer(
pluginContext = pluginContext,
burstApis = burstApis,
file = currentFile,
originalParent = classDeclaration,
original = function,
)
addedDeclarations += rewriter.rewrite()
specializer.generateSpecializations()
} catch (e: BurstCompilationException) {
messageCollector.report(e.severity, e.message, currentFile.locationOf(e.element))
}
}

for (added in addedDeclarations) {
classDeclaration.declarations.add(added)
added.parent = classDeclaration
}

return classDeclaration
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,95 +17,73 @@ package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irTemporary
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.classFqName
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.types.starProjectedType
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.constructors
import org.jetbrains.kotlin.name.Name

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal class BurstRewriter(
private val messageCollector: MessageCollector,
internal class FunctionSpecializer(
private val pluginContext: IrPluginContext,
private val burstApis: BurstApis,
private val file: IrFile,
private val originalParent: IrClass,
private val original: IrSimpleFunction,
) {
/** Returns a list of additional declarations. */
fun rewrite(): List<IrDeclaration> {
fun generateSpecializations() {
val originalValueParameters = original.valueParameters
if (originalValueParameters.isEmpty()) {
return listOf()
}
if (originalValueParameters.isEmpty()) return // Nothing to do.

val originalDispatchReceiver = original.dispatchReceiverParameter
if (originalDispatchReceiver == null) {
messageCollector.report(
CompilerMessageSeverity.ERROR,
"Unexpected dispatch receiver",
file.locationOf(original),
)
return listOf()
}
?: throw BurstCompilationException("Unexpected dispatch receiver", original)

val parameterArguments = mutableListOf<List<Argument>>()
for (parameter in originalValueParameters) {
val expanded = parameter.allPossibleArguments()
if (expanded == null) {
messageCollector.report(
CompilerMessageSeverity.ERROR,
"Expected an enum for @Burst test parameter",
file.locationOf(parameter),
)
return listOf()
}
val expanded = pluginContext.allPossibleArguments(parameter)
?: throw BurstCompilationException("Expected an enum for @Burst test parameter", parameter)
parameterArguments += expanded
}

val cartesianProduct = parameterArguments.cartesianProduct()

val variants = cartesianProduct.map { variantArguments ->
createVariant(originalDispatchReceiver, variantArguments)
val specializations = cartesianProduct.map { arguments ->
createSpecialization(originalDispatchReceiver, arguments)
}

// Side-effect: drop `@Test` from the original's annotations.
// Drop `@Test` from the original's annotations.
original.annotations = original.annotations.filter {
it.type.classFqName != burstApis.testClassSymbol.starProjectedType.classFqName
}

val result = mutableListOf<IrDeclaration>()
result += createFunctionThatCallsAllVariants(originalDispatchReceiver, variants)
result += variants
return result
// Add new declarations.
for (specialization in specializations) {
originalParent.addDeclaration(specialization)
}
originalParent.addDeclaration(
createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations)
)
}

private fun IrClass.addDeclaration(declaration: IrDeclaration) {
declarations.add(declaration)
declaration.parent = this
}

private fun createVariant(
private fun createSpecialization(
originalDispatchReceiver: IrValueParameter,
arguments: List<Argument>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
name = Name.identifier(name(arguments))
name = Name.identifier(name("${original.name.identifier}_", arguments))
returnType = original.returnType
}.apply {
addDispatchReceiver {
Expand Down Expand Up @@ -141,10 +119,10 @@ internal class BurstRewriter(
return result
}

/** Creates a function with no arguments that calls each variant. */
private fun createFunctionThatCallsAllVariants(
/** Creates an @Test @Ignore no-args function that calls each specialization. */
private fun createFunctionThatCallsAllSpecializations(
originalDispatchReceiver: IrValueParameter,
variants: List<IrSimpleFunction>,
specializations: List<IrSimpleFunction>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
Expand Down Expand Up @@ -172,9 +150,9 @@ internal class BurstRewriter(
origin = IrDeclarationOrigin.DEFINED
}

for (variant in variants) {
for (specialization in specializations) {
+irCall(
callee = variant.symbol,
callee = specialization.symbol,
).apply {
this.dispatchReceiver = irGet(receiverLocal)
}
Expand All @@ -183,39 +161,4 @@ internal class BurstRewriter(

return result
}

private inner class Argument(
val type: IrType,
val value: IrEnumEntry,
)

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
private fun name(arguments: List<Argument>): String {
return arguments.joinToString(
prefix = "${original.name.identifier}_",
separator = "_",
) { argument ->
argument.value.name.identifier
}
}

/** Returns an expression that looks up this argument. */
private fun Argument.get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}

/** Returns null if we can't compute all possible arguments for this parameter. */
private fun IrValueParameter.allPossibleArguments(): List<Argument>? {
val classId = type.getClass()?.classId ?: return null
val referenceClass = pluginContext.referenceClass(classId)?.owner ?: return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
return enumEntries.map { Argument(type, it) }
}

private fun IrClassSymbol.asAnnotation(): IrConstructorCall {
return IrConstructorCallImpl.fromSymbolOwner(
type = starProjectedType,
constructorSymbol = constructors.single(),
)
}
}
Loading

0 comments on commit b0958ad

Please sign in to comment.