From b0958ad584c3c440db16400953287b4e98cfee69 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Thu, 10 Oct 2024 06:59:40 -0700 Subject: [PATCH 1/2] Rename some internal APIs in preparation for class specialization Rename 'variant' to 'specialization'. Rename 'BurstRewriter' to 'FunctionSpecializer'. Start adding new declarations in-place, rather than using a mix of side-effects and return values. Use exceptions to report syntax errors. --- .../burst/gradle/BurstGradlePluginTest.kt | 4 +- .../burst/kotlin/BurstKotlinPluginTest.kt | 28 +- .../kotlin/app/cash/burst/kotlin/Argument.kt | 63 ++++ .../kotlin/BurstIrGenerationExtension.kt | 28 +- ...urstRewriter.kt => FunctionSpecializer.kt} | 117 ++------ .../main/kotlin/app/cash/burst/kotlin/ir.kt | 275 +----------------- 6 files changed, 146 insertions(+), 369 deletions(-) create mode 100644 burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt rename burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/{BurstRewriter.kt => FunctionSpecializer.kt} (54%) diff --git a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt index 3506724..7d01bb1 100644 --- a/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt +++ b/burst-gradle-plugin/src/test/kotlin/app/cash/burst/gradle/BurstGradlePluginTest.kt @@ -57,8 +57,8 @@ class BurstGradlePluginTest { val originalTest = testSuite.testCases.single { it.name == "test[jvm]" } assertThat(originalTest.skipped).isTrue() - val sampleVariant = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" } - assertThat(sampleVariant.skipped).isFalse() + val sampleSpecialization = testSuite.testCases.single { it.name == "test_Decaf_Oat[jvm]" } + assertThat(sampleSpecialization.skipped).isFalse() } private fun createRunner( diff --git a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt index 30c599f..210ab63 100644 --- a/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt +++ b/burst-kotlin-plugin-tests/src/test/kotlin/app/cash/burst/kotlin/BurstKotlinPluginTest.kt @@ -16,6 +16,7 @@ package app.cash.burst.kotlin import assertk.assertThat +import assertk.assertions.contains import assertk.assertions.containsExactly import assertk.assertions.isFalse import assertk.assertions.isTrue @@ -36,8 +37,6 @@ class BurstKotlinPluginTest { sourceFile = SourceFile.kotlin( "CoffeeTest.kt", """ - package app.cash.burst.testing - import app.cash.burst.Burst import kotlin.test.Test @@ -58,7 +57,7 @@ class BurstKotlinPluginTest { ) assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages) - val adapterClass = result.classLoader.loadClass("app.cash.burst.testing.CoffeeTest") + val adapterClass = result.classLoader.loadClass("CoffeeTest") val adapterInstance = adapterClass.constructors.single().newInstance() val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*> @@ -91,6 +90,29 @@ class BurstKotlinPluginTest { "running Double Oat", ) } + + @Test + fun unexpectedArgumentType() { + val result = compile( + sourceFile = SourceFile.kotlin( + "CoffeeTest.kt", + """ + import app.cash.burst.Burst + import kotlin.test.Test + + @Burst + class CoffeeTest { + @Test + fun test(espresso: String) { + } + } + """, + ), + ) + assertEquals(KotlinCompilation.ExitCode.COMPILATION_ERROR, result.exitCode, result.messages) + assertThat(result.messages) + .contains("CoffeeTest.kt:7:12 Expected an enum for @Burst test parameter") + } } @ExperimentalCompilerApi diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt new file mode 100644 index 0000000..e93efbc --- /dev/null +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2024 Cash App + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package app.cash.burst.kotlin + +import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext +import org.jetbrains.kotlin.ir.IrElement +import org.jetbrains.kotlin.ir.declarations.IrEnumEntry +import org.jetbrains.kotlin.ir.declarations.IrValueParameter +import org.jetbrains.kotlin.ir.expressions.IrExpression +import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl +import org.jetbrains.kotlin.ir.types.IrType +import org.jetbrains.kotlin.ir.types.getClass +import org.jetbrains.kotlin.ir.util.classId +import org.jetbrains.kotlin.ir.util.isEnumClass + +internal class Argument( + private val original: IrElement, + private val type: IrType, + internal val value: IrEnumEntry, +) { + /** Returns an expression that looks up this argument. */ + fun get(): IrExpression { + return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol) + } +} + +/** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ +internal fun name( + prefix: String, + arguments: List +): String { + return arguments.joinToString( + prefix = prefix, + separator = "_", + ) { argument -> + argument.value.name.identifier + } +} + +/** Returns null if we can't compute all possible arguments for this parameter. */ +internal fun IrPluginContext.allPossibleArguments( + parameter: IrValueParameter +): List? { + val classId = parameter.type.getClass()?.classId ?: return null + val referenceClass = referenceClass(classId)?.owner ?: return null + if (!referenceClass.isEnumClass) return null + val enumEntries = referenceClass.declarations.filterIsInstance() + return enumEntries.map { Argument(parameter, parameter.type, it) } +} + diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt index 185e2b1..5d025b0 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstIrGenerationExtension.kt @@ -21,7 +21,6 @@ import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.ir.IrStatement import org.jetbrains.kotlin.ir.declarations.IrClass -import org.jetbrains.kotlin.ir.declarations.IrDeclaration import org.jetbrains.kotlin.ir.declarations.IrModuleFragment import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI import org.jetbrains.kotlin.ir.util.functions @@ -39,28 +38,31 @@ class BurstIrGenerationExtension( val classDeclaration = super.visitClassNew(declaration) as IrClass val classHasAtBurst = classDeclaration.hasAtBurst - val addedDeclarations = mutableListOf() + // Return early if there's no @Burst anywhere. + if (!classHasAtBurst && classDeclaration.functions.none { it.hasAtBurst }) { + return classDeclaration + } + + // Snapshot the original functions because the loop mutates them. + val originalFunctions = classDeclaration.functions.toList() - for (function in classDeclaration.functions) { + for (function in originalFunctions) { if (!function.hasAtTest) continue + if (!classHasAtBurst && !function.hasAtBurst) continue - if (classHasAtBurst || function.hasAtBurst) { - val rewriter = BurstRewriter( - messageCollector = messageCollector, + try { + val specializer = FunctionSpecializer( pluginContext = pluginContext, burstApis = burstApis, - file = currentFile, + originalParent = classDeclaration, original = function, ) - addedDeclarations += rewriter.rewrite() + specializer.generateSpecializations() + } catch (e: BurstCompilationException) { + messageCollector.report(e.severity, e.message, currentFile.locationOf(e.element)) } } - for (added in addedDeclarations) { - classDeclaration.declarations.add(added) - added.parent = classDeclaration - } - return classDeclaration } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt similarity index 54% rename from burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt rename to burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt index 4cf0541..cd88c41 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/BurstRewriter.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt @@ -17,95 +17,73 @@ package app.cash.burst.kotlin import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver -import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity -import org.jetbrains.kotlin.cli.common.messages.MessageCollector import org.jetbrains.kotlin.ir.builders.declarations.buildFun import org.jetbrains.kotlin.ir.builders.irCall import org.jetbrains.kotlin.ir.builders.irGet import org.jetbrains.kotlin.ir.builders.irTemporary +import org.jetbrains.kotlin.ir.declarations.IrClass import org.jetbrains.kotlin.ir.declarations.IrDeclaration import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin -import org.jetbrains.kotlin.ir.declarations.IrEnumEntry -import org.jetbrains.kotlin.ir.declarations.IrFile import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction import org.jetbrains.kotlin.ir.declarations.IrValueParameter -import org.jetbrains.kotlin.ir.expressions.IrConstructorCall -import org.jetbrains.kotlin.ir.expressions.IrExpression -import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl -import org.jetbrains.kotlin.ir.expressions.impl.IrGetEnumValueImpl -import org.jetbrains.kotlin.ir.symbols.IrClassSymbol import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI -import org.jetbrains.kotlin.ir.types.IrType import org.jetbrains.kotlin.ir.types.classFqName -import org.jetbrains.kotlin.ir.types.getClass import org.jetbrains.kotlin.ir.types.starProjectedType -import org.jetbrains.kotlin.ir.util.classId -import org.jetbrains.kotlin.ir.util.constructors import org.jetbrains.kotlin.name.Name @OptIn(UnsafeDuringIrConstructionAPI::class) -internal class BurstRewriter( - private val messageCollector: MessageCollector, +internal class FunctionSpecializer( private val pluginContext: IrPluginContext, private val burstApis: BurstApis, - private val file: IrFile, + private val originalParent: IrClass, private val original: IrSimpleFunction, ) { - /** Returns a list of additional declarations. */ - fun rewrite(): List { + fun generateSpecializations() { val originalValueParameters = original.valueParameters - if (originalValueParameters.isEmpty()) { - return listOf() - } + if (originalValueParameters.isEmpty()) return // Nothing to do. val originalDispatchReceiver = original.dispatchReceiverParameter - if (originalDispatchReceiver == null) { - messageCollector.report( - CompilerMessageSeverity.ERROR, - "Unexpected dispatch receiver", - file.locationOf(original), - ) - return listOf() - } + ?: throw BurstCompilationException("Unexpected dispatch receiver", original) val parameterArguments = mutableListOf>() for (parameter in originalValueParameters) { - val expanded = parameter.allPossibleArguments() - if (expanded == null) { - messageCollector.report( - CompilerMessageSeverity.ERROR, - "Expected an enum for @Burst test parameter", - file.locationOf(parameter), - ) - return listOf() - } + val expanded = pluginContext.allPossibleArguments(parameter) + ?: throw BurstCompilationException("Expected an enum for @Burst test parameter", parameter) parameterArguments += expanded } val cartesianProduct = parameterArguments.cartesianProduct() - val variants = cartesianProduct.map { variantArguments -> - createVariant(originalDispatchReceiver, variantArguments) + val specializations = cartesianProduct.map { arguments -> + createSpecialization(originalDispatchReceiver, arguments) } - // Side-effect: drop `@Test` from the original's annotations. + // Drop `@Test` from the original's annotations. original.annotations = original.annotations.filter { it.type.classFqName != burstApis.testClassSymbol.starProjectedType.classFqName } - val result = mutableListOf() - result += createFunctionThatCallsAllVariants(originalDispatchReceiver, variants) - result += variants - return result + // Add new declarations. + for (specialization in specializations) { + originalParent.addDeclaration(specialization) + } + originalParent.addDeclaration( + createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations) + ) + } + + private fun IrClass.addDeclaration(declaration: IrDeclaration) { + declarations.add(declaration) + declaration.parent = this } - private fun createVariant( + private fun createSpecialization( originalDispatchReceiver: IrValueParameter, arguments: List, ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) - name = Name.identifier(name(arguments)) + name = Name.identifier(name("${original.name.identifier}_", arguments)) returnType = original.returnType }.apply { addDispatchReceiver { @@ -141,10 +119,10 @@ internal class BurstRewriter( return result } - /** Creates a function with no arguments that calls each variant. */ - private fun createFunctionThatCallsAllVariants( + /** Creates an @Test @Ignore no-args function that calls each specialization. */ + private fun createFunctionThatCallsAllSpecializations( originalDispatchReceiver: IrValueParameter, - variants: List, + specializations: List, ): IrSimpleFunction { val result = original.factory.buildFun { initDefaults(original) @@ -172,9 +150,9 @@ internal class BurstRewriter( origin = IrDeclarationOrigin.DEFINED } - for (variant in variants) { + for (specialization in specializations) { +irCall( - callee = variant.symbol, + callee = specialization.symbol, ).apply { this.dispatchReceiver = irGet(receiverLocal) } @@ -183,39 +161,4 @@ internal class BurstRewriter( return result } - - private inner class Argument( - val type: IrType, - val value: IrEnumEntry, - ) - - /** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ - private fun name(arguments: List): String { - return arguments.joinToString( - prefix = "${original.name.identifier}_", - separator = "_", - ) { argument -> - argument.value.name.identifier - } - } - - /** Returns an expression that looks up this argument. */ - private fun Argument.get(): IrExpression { - return IrGetEnumValueImpl(original.startOffset, original.endOffset, type, value.symbol) - } - - /** Returns null if we can't compute all possible arguments for this parameter. */ - private fun IrValueParameter.allPossibleArguments(): List? { - val classId = type.getClass()?.classId ?: return null - val referenceClass = pluginContext.referenceClass(classId)?.owner ?: return null - val enumEntries = referenceClass.declarations.filterIsInstance() - return enumEntries.map { Argument(type, it) } - } - - private fun IrClassSymbol.asAnnotation(): IrConstructorCall { - return IrConstructorCallImpl.fromSymbolOwner( - type = starProjectedType, - constructorSymbol = constructors.single(), - ) - } } diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ir.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ir.kt index 66861c5..87b37dc 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ir.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/ir.kt @@ -15,73 +15,43 @@ */ package app.cash.burst.kotlin -import org.jetbrains.kotlin.backend.common.ScopeWithIr -import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder import org.jetbrains.kotlin.cli.common.messages.CompilerMessageLocationWithRange import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSeverity import org.jetbrains.kotlin.cli.common.messages.CompilerMessageSourceLocation -import org.jetbrains.kotlin.descriptors.ClassKind import org.jetbrains.kotlin.descriptors.DescriptorVisibilities import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.IrStatement import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET import org.jetbrains.kotlin.ir.builders.IrBlockBodyBuilder -import org.jetbrains.kotlin.ir.builders.IrBlockBuilder -import org.jetbrains.kotlin.ir.builders.IrBuilderWithScope import org.jetbrains.kotlin.ir.builders.IrGeneratorContext import org.jetbrains.kotlin.ir.builders.Scope import org.jetbrains.kotlin.ir.builders.declarations.IrClassBuilder import org.jetbrains.kotlin.ir.builders.declarations.IrFunctionBuilder import org.jetbrains.kotlin.ir.builders.declarations.IrValueParameterBuilder -import org.jetbrains.kotlin.ir.builders.declarations.addConstructor -import org.jetbrains.kotlin.ir.builders.declarations.buildClass -import org.jetbrains.kotlin.ir.builders.irCall -import org.jetbrains.kotlin.ir.builders.irGet -import org.jetbrains.kotlin.ir.builders.irGetField -import org.jetbrains.kotlin.ir.builders.irReturn -import org.jetbrains.kotlin.ir.declarations.IrClass import org.jetbrains.kotlin.ir.declarations.IrConstructor import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin import org.jetbrains.kotlin.ir.declarations.IrFile -import org.jetbrains.kotlin.ir.declarations.IrProperty import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction import org.jetbrains.kotlin.ir.declarations.createBlockBody +import org.jetbrains.kotlin.ir.expressions.IrConstructorCall import org.jetbrains.kotlin.ir.expressions.IrDelegatingConstructorCall -import org.jetbrains.kotlin.ir.expressions.IrExpression -import org.jetbrains.kotlin.ir.expressions.IrExpressionBody -import org.jetbrains.kotlin.ir.expressions.IrInstanceInitializerCall -import org.jetbrains.kotlin.ir.expressions.IrMemberAccessExpression -import org.jetbrains.kotlin.ir.expressions.IrReturn -import org.jetbrains.kotlin.ir.expressions.impl.IrClassReferenceImpl +import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl import org.jetbrains.kotlin.ir.expressions.impl.IrDelegatingConstructorCallImpl -import org.jetbrains.kotlin.ir.expressions.impl.IrInstanceInitializerCallImpl -import org.jetbrains.kotlin.ir.expressions.impl.IrReturnImpl import org.jetbrains.kotlin.ir.symbols.IrClassSymbol import org.jetbrains.kotlin.ir.symbols.IrConstructorSymbol -import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol -import org.jetbrains.kotlin.ir.symbols.IrPropertySymbol -import org.jetbrains.kotlin.ir.symbols.IrReturnTargetSymbol import org.jetbrains.kotlin.ir.symbols.IrSymbol -import org.jetbrains.kotlin.ir.symbols.impl.IrFieldSymbolImpl -import org.jetbrains.kotlin.ir.symbols.impl.IrPropertySymbolImpl import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl -import org.jetbrains.kotlin.ir.types.IrType -import org.jetbrains.kotlin.ir.types.typeWith +import org.jetbrains.kotlin.ir.types.starProjectedType import org.jetbrains.kotlin.ir.util.SYNTHETIC_OFFSET -import org.jetbrains.kotlin.ir.util.companionObject import org.jetbrains.kotlin.ir.util.constructors -import org.jetbrains.kotlin.ir.util.createDispatchReceiverParameter -import org.jetbrains.kotlin.ir.util.createImplicitParameterDeclarationWithWrappedDescriptor -import org.jetbrains.kotlin.ir.util.defaultType import org.jetbrains.kotlin.name.Name -import org.jetbrains.kotlin.name.StandardClassIds /** Thrown on invalid or unexpected input code. */ class BurstCompilationException( override val message: String, - val element: IrElement? = null, + val element: IrElement, val severity: CompilerMessageSeverity = CompilerMessageSeverity.ERROR, ) : Exception(message) @@ -101,21 +71,6 @@ fun IrFile.locationOf(irElement: IrElement?): CompilerMessageSourceLocation { )!! } -/** `return ...` */ -internal fun IrBuilderWithScope.irReturn( - value: IrExpression, - returnTargetSymbol: IrReturnTargetSymbol, - type: IrType = value.type, -): IrReturn { - return IrReturnImpl( - startOffset = startOffset, - endOffset = endOffset, - type = type, - returnTargetSymbol = returnTargetSymbol, - value = value, - ) -} - /** Set up reasonable defaults for a generated function or constructor. */ fun IrFunctionBuilder.initDefaults(original: IrElement) { this.startOffset = original.startOffset.toSyntheticIfUnknown() @@ -140,6 +95,13 @@ fun IrValueParameterBuilder.initDefaults(original: IrElement) { this.endOffset = original.endOffset.toSyntheticIfUnknown() } +fun IrClassSymbol.asAnnotation(): IrConstructorCall { + return IrConstructorCallImpl.fromSymbolOwner( + type = starProjectedType, + constructorSymbol = constructors.single(), + ) +} + /** * When we generate code based on classes outside of the current module unit we get elements that * use `UNDEFINED_OFFSET`. Make sure we don't propagate this further into generated code; that @@ -189,18 +151,6 @@ fun DeclarationIrBuilder.irDelegatingConstructorCall( return result } -fun DeclarationIrBuilder.irInstanceInitializerCall( - context: IrGeneratorContext, - classSymbol: IrClassSymbol, -): IrInstanceInitializerCall { - return IrInstanceInitializerCallImpl( - startOffset = startOffset, - endOffset = endOffset, - classSymbol = classSymbol, - type = context.irBuiltIns.unitType, - ) -} - fun IrSimpleFunction.irFunctionBody( context: IrGeneratorContext, scopeOwnerSymbol: IrSymbol, @@ -216,206 +166,3 @@ fun IrSimpleFunction.irFunctionBody( blockBody() } } - -/** Create a private val with a backing field and an accessor function. */ -fun irVal( - pluginContext: IrPluginContext, - propertyType: IrType, - declaringClass: IrClass, - propertyName: Name, - overriddenProperty: IrPropertySymbol? = null, - initializer: IrBlockBuilder.() -> IrExpressionBody, -): IrProperty { - val irFactory = pluginContext.irFactory - val result = irFactory.createProperty( - startOffset = declaringClass.startOffset, - endOffset = declaringClass.endOffset, - origin = IrDeclarationOrigin.DEFINED, - symbol = IrPropertySymbolImpl(), - name = propertyName, - visibility = overriddenProperty?.owner?.visibility ?: DescriptorVisibilities.PRIVATE, - modality = Modality.FINAL, - isVar = false, - isConst = false, - isLateinit = false, - isDelegated = false, - isExternal = false, - isExpect = false, - isFakeOverride = false, - containerSource = null, - ).apply { - overriddenSymbols = listOfNotNull(overriddenProperty) - parent = declaringClass - } - - result.backingField = irFactory.createField( - startOffset = declaringClass.startOffset, - endOffset = declaringClass.endOffset, - origin = IrDeclarationOrigin.PROPERTY_BACKING_FIELD, - symbol = IrFieldSymbolImpl(), - name = result.name, - type = propertyType, - visibility = DescriptorVisibilities.PRIVATE, - isFinal = true, - isExternal = false, - isStatic = false, - ).apply { - parent = declaringClass - correspondingPropertySymbol = result.symbol - val initializerBuilder = IrBlockBuilder( - startOffset = declaringClass.startOffset, - endOffset = declaringClass.endOffset, - context = pluginContext, - scope = Scope(symbol), - ) - this.initializer = initializerBuilder.initializer() - } - - result.getter = irFactory.createSimpleFunction( - startOffset = declaringClass.startOffset, - endOffset = declaringClass.endOffset, - origin = IrDeclarationOrigin.DEFAULT_PROPERTY_ACCESSOR, - name = Name.special(""), - visibility = overriddenProperty?.owner?.getter?.visibility ?: DescriptorVisibilities.PRIVATE, - isInline = false, - isExpect = false, - returnType = propertyType, - modality = Modality.FINAL, - symbol = IrSimpleFunctionSymbolImpl(), - isTailrec = false, - isSuspend = false, - isOperator = false, - isInfix = false, - isExternal = false, - containerSource = null, - isFakeOverride = false, - ).apply { - parent = declaringClass - correspondingPropertySymbol = result.symbol - overriddenSymbols = listOfNotNull(overriddenProperty?.owner?.getter?.symbol) - createDispatchReceiverParameter() - irFunctionBody( - context = pluginContext, - scopeOwnerSymbol = symbol, - ) { - +irReturn( - value = irGetField( - irGet(dispatchReceiverParameter!!), - result.backingField!!, - ), - ) - } - } - - return result -} - -internal fun IrBuilderWithScope.irKClass( - containerClass: IrClass, -): IrClassReferenceImpl { - return IrClassReferenceImpl( - startOffset = startOffset, - endOffset = endOffset, - type = context.irBuiltIns.kClassClass.typeWith(containerClass.defaultType), - symbol = containerClass.symbol, - classType = containerClass.defaultType, - ) -} - -fun irBlockBodyBuilder( - irPluginContext: IrGeneratorContext, - scopeWithIr: ScopeWithIr, - original: IrElement, -): IrBlockBodyBuilder { - return IrBlockBodyBuilder( - context = irPluginContext, - scope = scopeWithIr.scope, - startOffset = original.startOffset.toSyntheticIfUnknown(), - endOffset = original.endOffset.toSyntheticIfUnknown(), - ) -} - -fun irBlockBuilder( - irPluginContext: IrGeneratorContext, - scopeWithIr: ScopeWithIr, - original: IrElement, -): IrBlockBuilder { - return IrBlockBuilder( - context = irPluginContext, - scope = scopeWithIr.scope, - startOffset = original.startOffset.toSyntheticIfUnknown(), - endOffset = original.endOffset.toSyntheticIfUnknown(), - ) -} - -/** This creates `companion object` if it doesn't exist already. */ -fun getOrCreateCompanion( - enclosing: IrClass, - irPluginContext: IrPluginContext, -): IrClass { - val existing = enclosing.companionObject() - if (existing != null) return existing - - val irFactory = irPluginContext.irFactory - val anyType = irPluginContext.referenceClass(StandardClassIds.Any)!! - val companionClass = irFactory.buildClass { - initDefaults(enclosing) - name = Name.identifier("Companion") - visibility = DescriptorVisibilities.PUBLIC - kind = ClassKind.OBJECT - isCompanion = true - }.apply { - parent = enclosing - superTypes = listOf(irPluginContext.irBuiltIns.anyType) - createImplicitParameterDeclarationWithWrappedDescriptor() - } - - companionClass.addConstructor { - initDefaults(enclosing) - visibility = DescriptorVisibilities.PRIVATE - }.apply { - irConstructorBody(irPluginContext) { statements -> - statements += irDelegatingConstructorCall( - context = irPluginContext, - symbol = anyType.constructors.single(), - ) - statements += irInstanceInitializerCall( - context = irPluginContext, - classSymbol = companionClass.symbol, - ) - } - } - - enclosing.declarations.add(companionClass) - return companionClass -} - -// https://github.com/JetBrains/kotlin/blob/d625d9a988f3a7a344ce1687b085ff7c811e916c/plugins/kotlinx-serialization/kotlinx-serialization.backend/src/org/jetbrains/kotlinx/serialization/compiler/backend/ir/IrBuilderWithPluginContext.kt#L199-L211 -fun IrBuilderWithScope.irInvoke( - dispatchReceiver: IrExpression? = null, - callee: IrFunctionSymbol, - vararg args: IrExpression, - typeHint: IrType? = null, -): IrMemberAccessExpression<*> { - assert(callee.isBound) { "Symbol $callee expected to be bound" } - val returnType = typeHint ?: callee.owner.returnType - val call = irCall(callee, type = returnType) - call.dispatchReceiver = dispatchReceiver - args.forEachIndexed(call::putValueArgument) - return call -} - -// https://github.com/JetBrains/kotlin/blob/d625d9a988f3a7a344ce1687b085ff7c811e916c/plugins/kotlinx-serialization/kotlinx-serialization.backend/src/org/jetbrains/kotlinx/serialization/compiler/backend/ir/IrBuilderWithPluginContext.kt#L213-L225 -fun IrBuilderWithScope.irInvoke( - dispatchReceiver: IrExpression? = null, - callee: IrFunctionSymbol, - typeArguments: List, - valueArguments: List, - returnTypeHint: IrType? = null, -): IrMemberAccessExpression<*> = - irInvoke( - dispatchReceiver, - callee, - *valueArguments.toTypedArray(), - typeHint = returnTypeHint, - ).also { call -> typeArguments.forEachIndexed(call::putTypeArgument) } From 5dd10447be12bd81ac9e196ff7908dd7ec701f53 Mon Sep 17 00:00:00 2001 From: Jesse Wilson Date: Thu, 10 Oct 2024 07:05:19 -0700 Subject: [PATCH 2/2] Spotless --- .../src/main/kotlin/app/cash/burst/kotlin/Argument.kt | 5 ++--- .../main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt index e93efbc..07a87f5 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/Argument.kt @@ -40,7 +40,7 @@ internal class Argument( /** Returns a name like `orderCoffee_Decaf_Oat` with each argument value inline. */ internal fun name( prefix: String, - arguments: List + arguments: List, ): String { return arguments.joinToString( prefix = prefix, @@ -52,7 +52,7 @@ internal fun name( /** Returns null if we can't compute all possible arguments for this parameter. */ internal fun IrPluginContext.allPossibleArguments( - parameter: IrValueParameter + parameter: IrValueParameter, ): List? { val classId = parameter.type.getClass()?.classId ?: return null val referenceClass = referenceClass(classId)?.owner ?: return null @@ -60,4 +60,3 @@ internal fun IrPluginContext.allPossibleArguments( val enumEntries = referenceClass.declarations.filterIsInstance() return enumEntries.map { Argument(parameter, parameter.type, it) } } - diff --git a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt index cd88c41..441a5ee 100644 --- a/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt +++ b/burst-kotlin-plugin/src/main/kotlin/app/cash/burst/kotlin/FunctionSpecializer.kt @@ -68,7 +68,7 @@ internal class FunctionSpecializer( originalParent.addDeclaration(specialization) } originalParent.addDeclaration( - createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations) + createFunctionThatCallsAllSpecializations(originalDispatchReceiver, specializations), ) }