From e1c29095b580481bd1c362432d2536a0edfb6831 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Mon, 28 Oct 2024 15:38:00 -0400 Subject: [PATCH] Avoid collisions in generated symbols (#48) If no names collide, this does nothing. Otherwise each generated symbol is prefixed with its index. Co-authored-by: Jesse Wilson --- .../burst/kotlin/BurstKotlinPluginTest.kt | 44 ++++++++++++++ .../kotlin/app/cash/burst/kotlin/Argument.kt | 6 -- .../app/cash/burst/kotlin/ClassSpecializer.kt | 47 +++++++-------- .../cash/burst/kotlin/FunctionSpecializer.kt | 29 ++++------ .../app/cash/burst/kotlin/Specialization.kt | 58 +++++++++++++++++++ 5 files changed, 133 insertions(+), 51 deletions(-) create mode 100644 burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Specialization.kt diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index 908a2b2..d5a1ade 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -464,6 +464,50 @@ class BurstKotlinPluginTest { ) } + @Test + fun burstValuesWithNameCollisions() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import app.cash.burst.burstValues + import kotlin.test.Test + + @Burst + class CoffeeTest { + @Test + fun test( + content: Any? = burstValues( + 3, // No name is generated for the first value. + "1", + 1, + 1L, + "CASE_INSENSITIVE_ORDER", + String.CASE_INSENSITIVE_ORDER, + true, + "true" + ) + ) { + } + } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) + + val baseClass = result.classLoader.loadClass("CoffeeTest") + assertThat(baseClass.testSuffixes).containsExactlyInAnyOrder( + "1_1", + "2_1", + "3_1", + "4_CASE_INSENSITIVE_ORDER", + "5_CASE_INSENSITIVE_ORDER", + "6_true", + "7_true", + ) + } + private val Class<*>.testSuffixes: List get() = methods.mapNotNull { when { diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt index b03cd96..57218b0 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -70,12 +70,6 @@ private class BurstValuesArgument( override fun expression() = value.deepCopyWithSymbols(declarationParent) } -/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ -internal fun name( - prefix: String, - arguments: List, -): String = arguments.joinToString(prefix = prefix, separator = "_", transform = Argument::name) - /** * Returns all arguments for [parameter]. * diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt index 1393a14..cbd7655 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ClassSpecializer.kt @@ -88,15 +88,8 @@ internal class ClassSpecializer( val valueParameters = onlyConstructor.valueParameters if (valueParameters.isEmpty()) return // Nothing to do. - val parameterArguments = valueParameters.map { parameter -> - pluginContext.allPossibleArguments(parameter, burstApis) - } - - val cartesianProduct = parameterArguments.cartesianProduct() - - val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments -> - arguments.all { it.isDefault } - } + val specializations = specializations(pluginContext, burstApis, valueParameters) + val indexOfDefaultSpecialization = specializations.indexOfFirst { it.isDefault } // Make sure the constructor we're using is accessible. Drop the default arguments to prevent // JUnit from using it. @@ -111,7 +104,7 @@ internal class ClassSpecializer( // Add a no-args constructor that calls the only constructor as the default specialization. createNoArgsConstructor( superConstructor = onlyConstructor, - arguments = cartesianProduct[indexOfDefaultSpecialization], + specialization = specializations[indexOfDefaultSpecialization], ) } else { // There's no default specialization. Make the class abstract so JUnit skips it. @@ -119,57 +112,57 @@ internal class ClassSpecializer( } // Add a subclass for each specialization. - cartesianProduct.mapIndexed { index, arguments -> + for ((index, specialization) in specializations.withIndex()) { // Don't generate code for the default specialization; we only want to run it once. - if (index == indexOfDefaultSpecialization) return@mapIndexed + if (index == indexOfDefaultSpecialization) continue - createSpecialization( + createSubclass( superConstructor = onlyConstructor, - arguments = arguments, + specialization = specialization, ) } } - private fun createSpecialization( + private fun createSubclass( superConstructor: IrConstructor, - arguments: List, + specialization: Specialization, ) { - val specialization = original.factory.buildClass { + val created = original.factory.buildClass { initDefaults(original) visibility = PUBLIC - name = Name.identifier(name("${original.name.identifier}_", arguments)) + name = Name.identifier("${original.name.identifier}_${specialization.name}") }.apply { superTypes = listOf(original.defaultType) createImplicitParameterDeclarationWithWrappedDescriptor() } - specialization.addConstructor { + created.addConstructor { initDefaults(original) }.apply { irConstructorBody(pluginContext) { statements -> statements += irDelegatingConstructorCall( context = pluginContext, symbol = superConstructor.symbol, - valueArgumentsCount = arguments.size, + valueArgumentsCount = specialization.arguments.size, ) { - for ((index, argument) in arguments.withIndex()) { + for ((index, argument) in specialization.arguments.withIndex()) { putValueArgument(index, argument.expression()) } } statements += irInstanceInitializerCall( context = pluginContext, - classSymbol = specialization.symbol, + classSymbol = created.symbol, ) } } - originalParent.addDeclaration(specialization) - specialization.addFakeOverrides(irTypeSystemContext) + originalParent.addDeclaration(created) + created.addFakeOverrides(irTypeSystemContext) } private fun createNoArgsConstructor( superConstructor: IrConstructor, - arguments: List, + specialization: Specialization, ) { original.addConstructor { initDefaults(original) @@ -179,9 +172,9 @@ internal class ClassSpecializer( statements += irDelegatingConstructorCall( context = pluginContext, symbol = superConstructor.symbol, - valueArgumentsCount = arguments.size, + valueArgumentsCount = specialization.arguments.size, ) { - for ((index, argument) in arguments.withIndex()) { + for ((index, argument) in specialization.arguments.withIndex()) { putValueArgument(index, argument.expression()) } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt index 67547e9..b1a885c 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt @@ -70,20 +70,13 @@ internal class FunctionSpecializer( val originalDispatchReceiver = original.dispatchReceiverParameter ?: throw BurstCompilationException("Unexpected dispatch receiver", original) - val parameterArguments = valueParameters.map { parameter -> - pluginContext.allPossibleArguments(parameter, burstApis) - } - - val cartesianProduct = parameterArguments.cartesianProduct() - - val indexOfDefaultSpecialization = cartesianProduct.indexOfFirst { arguments -> - arguments.all { it.isDefault } - } + val specializations = specializations(pluginContext, burstApis, valueParameters) + val indexOfDefaultSpecialization = specializations.indexOfFirst { it.isDefault } - val specializations = cartesianProduct.mapIndexed { index, arguments -> - createSpecialization( + val functions = specializations.mapIndexed { index, specialization -> + createFunction( originalDispatchReceiver = originalDispatchReceiver, - arguments = arguments, + specialization = specialization, isDefaultSpecialization = index == indexOfDefaultSpecialization, ) } @@ -94,21 +87,21 @@ internal class FunctionSpecializer( } // Add new declarations. - for (specialization in specializations) { - originalParent.addDeclaration(specialization) + for (function in functions) { + originalParent.addDeclaration(function) } } - private fun createSpecialization( + private fun createFunction( originalDispatchReceiver: IrValueParameter, - arguments: List, + specialization: Specialization, isDefaultSpecialization: Boolean, ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) name = when { isDefaultSpecialization -> original.name - else -> Name.identifier(name("${original.name.identifier}_", arguments)) + else -> Name.identifier("${original.name.identifier}_${specialization.name}") } returnType = original.returnType }.apply { @@ -136,7 +129,7 @@ internal class FunctionSpecializer( callee = original.symbol, ).apply { this.dispatchReceiver = irGet(receiverLocal) - for ((index, argument) in arguments.withIndex()) { + for ((index, argument) in specialization.arguments.withIndex()) { putValueArgument(index, argument.expression()) } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Specialization.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Specialization.kt new file mode 100644 index 0000000..f116dc2 --- /dev/null +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Specialization.kt @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2024 Cash App + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package app.cash.burst.kotlin + +import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext +import org.jetbrains.kotlin.ir.declarations.IrValueParameter +import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI + +internal class Specialization( + /** The argument values for this specialization. */ + val arguments: List, + + /** A string like `Decaf_Oat` with each argument value named. */ + val name: String, +) { + val isDefault: Boolean = arguments.all { it.isDefault } +} + +@UnsafeDuringIrConstructionAPI +internal fun specializations( + pluginContext: IrPluginContext, + burstApis: BurstApis, + parameters: List, +): List { + val parameterArguments = parameters.map { parameter -> + pluginContext.allPossibleArguments(parameter, burstApis) + } + + val specializations = parameterArguments.cartesianProduct().map { arguments -> + Specialization( + arguments = arguments, + name = arguments.joinToString(separator = "_", transform = Argument::name), + ) + } + + // If all elements already have distinct names, we're done. + if (specializations.distinctBy { it.name }.size == specializations.size) { + return specializations + } + + // Otherwise, prefix each with its index. + return specializations.mapIndexed { index, specialization -> + Specialization(specialization.arguments, "${index}_${specialization.name}") + } +}