Skip to content

Commit

Permalink
Merge pull request #16 from cashapp/jwilson.1010.rename_stuff
Browse files Browse the repository at this point in the history
Rename some internal APIs in preparation for class specialization
  • Loading branch information
squarejesse authored Oct 10, 2024
2 parents ec8b43c + 5dd1044 commit ff2d4fc
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 369 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class BurstGradlePluginTest {
val originalTest = testSuite.testCases.single { it.name == "test[jvm]" }
assertThat(originalTest.skipped).isTrue()

val sampleVariant = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" }
assertThat(sampleVariant.skipped).isFalse()
val sampleSpecialization = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" }
assertThat(sampleSpecialization.skipped).isFalse()
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package app.cash.burst.kotlin

import assertk.assertThat
import assertk.assertions.contains
import assertk.assertions.containsExactly
import assertk.assertions.isFalse
import assertk.assertions.isTrue
Expand All @@ -36,8 +37,6 @@ class BurstKotlinPluginTest {
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
package app.cash.burst.testing
import app.cash.burst.Burst
import kotlin.test.Test
Expand All @@ -58,7 +57,7 @@ class BurstKotlinPluginTest {
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.CoffeeTest")
val adapterClass = result.classLoader.loadClass("CoffeeTest")
val adapterInstance = adapterClass.constructors.single().newInstance()
val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*>

Expand Down Expand Up @@ -91,6 +90,29 @@ class BurstKotlinPluginTest {
"running Double Oat",
)
}

@Test
fun unexpectedArgumentType() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.Test
@Burst
class CoffeeTest {
@Test
fun test(espresso: String) {
}
}
""",
),
)
assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages)
assertThat(result.messages)
.contains("CoffeeTest.kt:7:12 Expected an enum for @Burst test parameter")
}
}

@ExperimentalCompilerApi
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright (C) 2024 Cash App
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.isEnumClass

internal class Argument(
private val original: IrElement,
private val type: IrType,
internal val value: IrEnumEntry,
) {
/** Returns an expression that looks up this argument. */
fun get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}
}

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
internal fun name(
prefix: String,
arguments: List<Argument>,
): String {
return arguments.joinToString(
prefix = prefix,
separator = "_",
) { argument ->
argument.value.name.identifier
}
}

/** Returns null if we can't compute all possible arguments for this parameter. */
internal fun IrPluginContext.allPossibleArguments(
parameter: IrValueParameter,
): List<Argument>? {
val classId = parameter.type.getClass()?.classId ?: return null
val referenceClass = referenceClass(classId)?.owner ?: return null
if (!referenceClass.isEnumClass) return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
return enumEntries.map { Argument(parameter, parameter.type, it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.IrStatement
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.functions
Expand All @@ -39,28 +38,31 @@ class BurstIrGenerationExtension(
val classDeclaration = super.visitClassNew(declaration) as IrClass
val classHasAtBurst = classDeclaration.hasAtBurst

val addedDeclarations = mutableListOf<IrDeclaration>()
// Return early if there's no @Burst anywhere.
if (!classHasAtBurst && classDeclaration.functions.none { it.hasAtBurst }) {
return classDeclaration
}

// Snapshot the original functions because the loop mutates them.
val originalFunctions = classDeclaration.functions.toList()

for (function in classDeclaration.functions) {
for (function in originalFunctions) {
if (!function.hasAtTest) continue
if (!classHasAtBurst && !function.hasAtBurst) continue

if (classHasAtBurst || function.hasAtBurst) {
val rewriter = BurstRewriter(
messageCollector = messageCollector,
try {
val specializer = FunctionSpecializer(
pluginContext = pluginContext,
burstApis = burstApis,
file = currentFile,
originalParent = classDeclaration,
original = function,
)
addedDeclarations += rewriter.rewrite()
specializer.generateSpecializations()
} catch (e: BurstCompilationException) {
messageCollector.report(e.severity, e.message, currentFile.locationOf(e.element))
}
}

for (added in addedDeclarations) {
classDeclaration.declarations.add(added)
added.parent = classDeclaration
}

return classDeclaration
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,95 +17,73 @@ package app.cash.burst.kotlin

import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver
import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.builders.irCall
import org.jetbrains.kotlin.ir.builders.irGet
import org.jetbrains.kotlin.ir.builders.irTemporary
import org.jetbrains.kotlin.ir.declarations.IrClass
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrEnumEntry
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.declarations.IrValueParameter
import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl
import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.classFqName
import org.jetbrains.kotlin.ir.types.getClass
import org.jetbrains.kotlin.ir.types.starProjectedType
import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.constructors
import org.jetbrains.kotlin.name.Name

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal class BurstRewriter(
private val messageCollector: MessageCollector,
internal class FunctionSpecializer(
private val pluginContext: IrPluginContext,
private val burstApis: BurstApis,
private val file: IrFile,
private val originalParent: IrClass,
private val original: IrSimpleFunction,
) {
/** Returns a list of additional declarations. */
fun rewrite(): List<IrDeclaration> {
fun generateSpecializations() {
val originalValueParameters = original.valueParameters
if (originalValueParameters.isEmpty()) {
return listOf()
}
if (originalValueParameters.isEmpty()) return // Nothing to do.

val originalDispatchReceiver = original.dispatchReceiverParameter
if (originalDispatchReceiver == null) {
messageCollector.report(
CompilerMessageSeverity.ERROR,
"Unexpected dispatch receiver",
file.locationOf(original),
)
return listOf()
}
?: throw BurstCompilationException("Unexpected dispatch receiver", original)

val parameterArguments = mutableListOf<List<Argument>>()
for (parameter in originalValueParameters) {
val expanded = parameter.allPossibleArguments()
if (expanded == null) {
messageCollector.report(
CompilerMessageSeverity.ERROR,
"Expected an enum for @Burst test parameter",
file.locationOf(parameter),
)
return listOf()
}
val expanded = pluginContext.allPossibleArguments(parameter)
?: throw BurstCompilationException("Expected an enum for @Burst test parameter", parameter)
parameterArguments += expanded
}

val cartesianProduct = parameterArguments.cartesianProduct()

val variants = cartesianProduct.map { variantArguments ->
createVariant(originalDispatchReceiver, variantArguments)
val specializations = cartesianProduct.map { arguments ->
createSpecialization(originalDispatchReceiver, arguments)
}

// Side-effect: drop `@Test` from the original's annotations.
// Drop `@Test` from the original's annotations.
original.annotations = original.annotations.filter {
it.type.classFqName != burstApis.testClassSymbol.starProjectedType.classFqName
}

val result = mutableListOf<IrDeclaration>()
result += createFunctionThatCallsAllVariants(originalDispatchReceiver, variants)
result += variants
return result
// Add new declarations.
for (specialization in specializations) {
originalParent.addDeclaration(specialization)
}
originalParent.addDeclaration(
createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations),
)
}

private fun IrClass.addDeclaration(declaration: IrDeclaration) {
declarations.add(declaration)
declaration.parent = this
}

private fun createVariant(
private fun createSpecialization(
originalDispatchReceiver: IrValueParameter,
arguments: List<Argument>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
name = Name.identifier(name(arguments))
name = Name.identifier(name("${original.name.identifier}_", arguments))
returnType = original.returnType
}.apply {
addDispatchReceiver {
Expand Down Expand Up @@ -141,10 +119,10 @@ internal class BurstRewriter(
return result
}

/** Creates a function with no arguments that calls each variant. */
private fun createFunctionThatCallsAllVariants(
/** Creates an @Test @Ignore no-args function that calls each specialization. */
private fun createFunctionThatCallsAllSpecializations(
originalDispatchReceiver: IrValueParameter,
variants: List<IrSimpleFunction>,
specializations: List<IrSimpleFunction>,
): IrSimpleFunction {
val result = original.factory.buildFun {
initDefaults(original)
Expand Down Expand Up @@ -172,9 +150,9 @@ internal class BurstRewriter(
origin = IrDeclarationOrigin.DEFINED
}

for (variant in variants) {
for (specialization in specializations) {
+irCall(
callee = variant.symbol,
callee = specialization.symbol,
).apply {
this.dispatchReceiver = irGet(receiverLocal)
}
Expand All @@ -183,39 +161,4 @@ internal class BurstRewriter(

return result
}

private inner class Argument(
val type: IrType,
val value: IrEnumEntry,
)

/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */
private fun name(arguments: List<Argument>): String {
return arguments.joinToString(
prefix = "${original.name.identifier}_",
separator = "_",
) { argument ->
argument.value.name.identifier
}
}

/** Returns an expression that looks up this argument. */
private fun Argument.get(): IrExpression {
return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol)
}

/** Returns null if we can't compute all possible arguments for this parameter. */
private fun IrValueParameter.allPossibleArguments(): List<Argument>? {
val classId = type.getClass()?.classId ?: return null
val referenceClass = pluginContext.referenceClass(classId)?.owner ?: return null
val enumEntries = referenceClass.declarations.filterIsInstance<IrEnumEntry>()
return enumEntries.map { Argument(type, it) }
}

private fun IrClassSymbol.asAnnotation(): IrConstructorCall {
return IrConstructorCallImpl.fromSymbolOwner(
type = starProjectedType,
constructorSymbol = constructors.single(),
)
}
}
Loading

0 comments on commit ff2d4fc

Please sign in to comment.