Skip to content

Commit

Permalink
Merge pull request #17 from cashapp/jwilson.1010.specialize_classes
Browse files Browse the repository at this point in the history
Implement class-level specialization
  • Loading branch information
squarejesse authored Oct 10, 2024
2 parents ff2d4fc + a69e3f1 commit d37647f
Show file tree
Hide file tree
Showing 13 changed files with 430 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package app.cash.burst.gradle
import assertk.assertThat
import assertk.assertions.contains
import assertk.assertions.containsExactlyInAnyOrder
import assertk.assertions.isEqualTo
import assertk.assertions.isFalse
import assertk.assertions.isTrue
import java.io.File
Expand Down Expand Up @@ -95,6 +96,35 @@ class BurstGradlePluginTest {
assertThat(sampleVariant.skipped).isFalse()
}

@Test
fun classParameters() {
val projectDir = File("src/test/projects/classParameters")

val taskName = ":lib:test"
val result = createRunner(projectDir, "clean", taskName).build()
assertThat(SUCCESS_OUTCOMES)
.contains(result.task(taskName)!!.outcome)

val testResults = projectDir.resolve("lib/build/test-results")

val coffeeTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest.xml"))
val coffeeTestTest = coffeeTest.testCases.single()
assertThat(coffeeTestTest.name).isEqualTo("test")
assertThat(coffeeTestTest.skipped).isTrue()

val sampleTest = readTestSuite(testResults.resolve("test/TEST-CoffeeTest_Decaf_None.xml"))
val sampleTestTest = sampleTest.testCases.single()
assertThat(sampleTestTest.name).isEqualTo("test")
assertThat(sampleTestTest.skipped).isFalse()
assertThat(sampleTest.systemOut).isEqualTo(
"""
|set up Decaf None
|running Decaf None
|
""".trimMargin(),
)
}

private fun createRunner(
projectDir: File,
vararg taskNames: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@ fun readTestSuite(xmlFile: File): TestSuite {

internal fun Element.toTestSuite(): TestSuite {
val testCases = mutableListOf<TestCase>()
val systemOut = StringBuilder()
for (i in 0 until childNodes.length) {
val item = childNodes.item(i)
if (item !is Element || item.tagName != "testcase") continue
testCases += item.toTestCase()
if (item is Element && item.tagName == "testcase") {
testCases += item.toTestCase()
}
if (item is Element && item.tagName == "system-out") {
systemOut.append(item.textContent)
}
}

return TestSuite(
name = getAttribute("name"),
testCases = testCases,
systemOut = systemOut.toString(),
)
}

Expand All @@ -57,6 +63,7 @@ internal fun Element.toTestCase(): TestCase {
class TestSuite(
val name: String,
val testCases: List<TestCase>,
val systemOut: String,
)

class TestCase(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
buildscript {
repositories {
maven {
url = file("$rootDir/../../../../../build/testMaven").toURI()
}
mavenCentral()
google()
}
dependencies {
classpath("app.cash.burst:burst-gradle-plugin:${project.property("burstVersion")}")
classpath(libs.kotlin.gradle.plugin)
}
}

allprojects {
repositories {
maven {
url = file("$rootDir/../../../../../build/testMaven").toURI()
}
mavenCentral()
google()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
plugins {
kotlin("jvm")
id("app.cash.burst")
}

dependencies {
testImplementation(kotlin("test"))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import app.cash.burst.Burst
import kotlin.test.BeforeTest
import kotlin.test.Test

@Burst
class CoffeeTest(
private val espresso: Espresso,
private val dairy: Dairy,
) {
@BeforeTest
fun setUp() {
println("set up $espresso $dairy")
}

@Test
fun test() {
println("running $espresso $dairy")
}
}

enum class Espresso { Decaf, Regular, Double }

enum class Dairy { None, Milk, Oat }
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
dependencyResolutionManagement {
versionCatalogs {
create("libs") {
from(files("../../../../../gradle/libs.versions.toml"))
}
}
}

include(":lib")
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ import kotlin.test.Test

@Burst
class CoffeeTest {
val log = mutableListOf<String>()

@Test
fun test(espresso: Espresso, dairy: Dairy) {
log += "running $espresso $dairy"
println("running $espresso $dairy")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ import kotlin.test.Test

@Burst
class CoffeeTest {
val log = mutableListOf<String>()

@Test
fun test(espresso: Espresso, dairy: Dairy) {
log += "running $espresso $dairy"
println("running $espresso $dairy")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import assertk.assertions.isTrue
import com.tschuchort.compiletesting.JvmCompilationResult
import com.tschuchort.compiletesting.KotlinCompilation
import com.tschuchort.compiletesting.SourceFile
import java.lang.reflect.Modifier
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertEquals
Expand Down Expand Up @@ -57,24 +58,28 @@ class BurstKotlinPluginTest {
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val adapterClass = result.classLoader.loadClass("CoffeeTest")
val adapterInstance = adapterClass.constructors.single().newInstance()
val log = adapterClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*>
val testClass = result.classLoader.loadClass("CoffeeTest")

// Burst doesn't make the class non-final as it has no reason to.
assertThat(Modifier.isFinal(testClass.modifiers)).isTrue()

val adapterInstance = testClass.constructors.single().newInstance()
val log = testClass.getMethod("getLog").invoke(adapterInstance) as MutableList<*>

// Burst drops @Test from the original test.
val originalTest = adapterClass.methods.single { it.name == "test" && it.parameterCount == 2 }
val originalTest = testClass.methods.single { it.name == "test" && it.parameterCount == 2 }
assertThat(originalTest.isAnnotationPresent(Test::class.java)).isFalse()

// Burst adds a variant for each combination of parameters.
val sampleVariant = adapterClass.getMethod("test_Decaf_None")
assertThat(sampleVariant.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(sampleVariant.isAnnotationPresent(Ignore::class.java)).isFalse()
sampleVariant.invoke(adapterInstance)
// Burst adds a specialization for each combination of parameters.
val sampleFunction = testClass.getMethod("test_Decaf_None")
assertThat(sampleFunction.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(sampleFunction.isAnnotationPresent(Ignore::class.java)).isFalse()
sampleFunction.invoke(adapterInstance)
assertThat(log).containsExactly("running Decaf None")
log.clear()

// Burst adds a no-parameter function that calls each variant in sequence.
val noArgsTest = adapterClass.getMethod("test")
// Burst adds a no-parameter function that calls each specialization in sequence.
val noArgsTest = testClass.getMethod("test")
assertThat(noArgsTest.isAnnotationPresent(Test::class.java)).isTrue()
assertThat(noArgsTest.isAnnotationPresent(Ignore::class.java)).isTrue()
noArgsTest.invoke(adapterInstance)
Expand Down Expand Up @@ -113,6 +118,77 @@ class BurstKotlinPluginTest {
assertThat(result.messages)
.contains("CoffeeTest.kt:7:12 Expected an enum for @Burst test parameter")
}

@Test
fun constructorParameters() {
val result = compile(
sourceFile = SourceFile.kotlin(
"CoffeeTest.kt",
"""
import app.cash.burst.Burst
import kotlin.test.BeforeTest
import kotlin.test.Test
@Burst
class CoffeeTest(
private val espresso: Espresso,
private val dairy: Dairy,
) {
val log = mutableListOf<String>()
@BeforeTest
fun setUp() {
log += "set up ${'$'}espresso ${'$'}dairy"
}
@Test
fun test() {
log += "running ${'$'}espresso ${'$'}dairy"
}
}
enum class Espresso { Decaf, Regular, Double }
enum class Dairy { None, Milk, Oat }
""",
),
)
assertEquals(KotlinCompilation.ExitCode.OK, result.exitCode, result.messages)

val baseClass = result.classLoader.loadClass("CoffeeTest")

// Burst opens the class because it needs to subclass it. And it marks the entire class @Ignore.
assertThat(Modifier.isFinal(baseClass.modifiers)).isFalse()
assertThat(baseClass.isAnnotationPresent(Ignore::class.java)).isTrue()

// Burst adds a no-args constructor that binds the first enum value.
val baseConstructor = baseClass.constructors.single { it.parameterCount == 0 }
val baseInstance = baseConstructor.newInstance()
val baseLog = baseClass.getMethod("getLog").invoke(baseInstance) as MutableList<*>

// The setUp function gets the first value of each parameter.
baseClass.getMethod("setUp").invoke(baseInstance)
assertThat(baseLog).containsExactly("set up Decaf None")
baseLog.clear()

// The test function gets the same.
baseClass.getMethod("test").invoke(baseInstance)
assertThat(baseLog).containsExactly("running Decaf None")
baseLog.clear()

// It generates a subclass for each specialization.
val sampleClass = result.classLoader.loadClass("CoffeeTest_Regular_Oat")
val sampleConstructor = sampleClass.getConstructor()
val sampleInstance = sampleConstructor.newInstance()
val sampleLog = sampleClass.getMethod("getLog")
.invoke(sampleInstance) as MutableList<*>
sampleClass.getMethod("setUp").invoke(sampleInstance)
sampleClass.getMethod("test").invoke(sampleInstance)
assertThat(sampleLog).containsExactly(
"set up Regular Oat",
"running Regular Oat",
)
sampleLog.clear()
}
}

@ExperimentalCompilerApi
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ class BurstIrGenerationExtension(
return classDeclaration
}

if (classHasAtBurst) {
ClassSpecializer(
pluginContext = pluginContext,
burstApis = burstApis,
originalParent = currentFile,
original = classDeclaration,
).generateSpecializations()
}

// Snapshot the original functions because the loop mutates them.
val originalFunctions = classDeclaration.functions.toList()

Expand Down
Loading

0 comments on commit d37647f

Please sign in to comment.