diff --git a/redwood-treehouse/src/commonMain/kotlin/app/cash/redwood/treehouse/StateSnapshot.kt b/redwood-treehouse/src/commonMain/kotlin/app/cash/redwood/treehouse/StateSnapshot.kt index 6faaefed49..7cb3f46520 100644 --- a/redwood-treehouse/src/commonMain/kotlin/app/cash/redwood/treehouse/StateSnapshot.kt +++ b/redwood-treehouse/src/commonMain/kotlin/app/cash/redwood/treehouse/StateSnapshot.kt @@ -21,23 +21,23 @@ import kotlin.jvm.JvmInline import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.booleanOrNull +import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.doubleOrNull import kotlinx.serialization.json.intOrNull +private const val MutableStateKey = "androidx.compose.runtime.MutableState" + @Serializable public class StateSnapshot( - public val content: Map>, + public val content: Map>, ) { public fun toValuesMap(): Map> { return content.mapValues { entry -> entry.value.map { - if (it.isMutableState) { - mutableStateOf(it.value.fromJsonElement()) - } else { - it.value.fromJsonElement() - } + it.fromJsonElement() } } } @@ -52,18 +52,12 @@ public class StateSnapshot( * String, Boolean, Int, List (of supported primitive types), Map (of supported primitive types) */ public fun Map>.toStateSnapshot(): StateSnapshot = StateSnapshot( - mapValues { entry -> - entry.value.map { element -> - when (element) { - is MutableState<*> -> Saveable(true, element.value.toJsonElement()) - else -> Saveable(false, element.toJsonElement()) - } - } - }, + mapValues { entry -> entry.value.map { element -> element.toJsonElement() } }, ) private fun Any?.toJsonElement(): JsonElement { return when (this) { + is MutableState<*> -> JsonMutableState(value.toJsonElement()) is String -> JsonPrimitive(this) is Int -> JsonPrimitive(this) is List<*> -> JsonArray(map { it.toJsonElement() }) @@ -74,22 +68,22 @@ private fun Any?.toJsonElement(): JsonElement { } private fun JsonElement?.fromJsonElement(): Any { - return when (this) { - is JsonPrimitive -> { + return when { + this is JsonPrimitive -> { if (this.isString) return content return booleanOrNull ?: doubleOrNull ?: intOrNull ?: error("unexpected type: $this") // TODO add other primitive types (float, long) when needed } - is JsonArray -> listOf({ this.forEach { it.toJsonElement() } }) + this is JsonArray -> listOf({ this.forEach { it.toJsonElement() } }) + this is JsonObject && containsKey(MutableStateKey) -> + mutableStateOf(getValue(MutableStateKey).fromJsonElement()) // TODO: map, numbers // is Map<*, *> -> JsonElement else -> error("unexpected type: $this") } } -@Serializable -public data class Saveable( - val isMutableState: Boolean, - val value: JsonElement, -) +internal fun JsonMutableState(element: JsonElement): JsonObject = buildJsonObject { + put(MutableStateKey, element) +} diff --git a/redwood-treehouse/src/commonTest/kotlin/app/cash/redwood/treehouse/StateSnapshotTest.kt b/redwood-treehouse/src/commonTest/kotlin/app/cash/redwood/treehouse/StateSnapshotTest.kt index 8d1bfc3987..9cdd45517f 100644 --- a/redwood-treehouse/src/commonTest/kotlin/app/cash/redwood/treehouse/StateSnapshotTest.kt +++ b/redwood-treehouse/src/commonTest/kotlin/app/cash/redwood/treehouse/StateSnapshotTest.kt @@ -47,19 +47,19 @@ class StateSnapshotTest { val storedStateSnapshot = storedStateSnapshot() val stateSnapshot = storedStateSnapshot.toStateSnapshot() assertThat(stateSnapshot.content).containsOnly( - "key1" to listOf(Saveable(true, JsonPrimitive(1))), - "key2" to listOf(Saveable(false, JsonPrimitive(1))), - "key3" to listOf(Saveable(true, JsonPrimitive("str"))), - "key4" to listOf(Saveable(false, JsonPrimitive("str"))), + "key1" to listOf(JsonMutableState(JsonPrimitive(1))), + "key2" to listOf(JsonPrimitive(1)), + "key3" to listOf(JsonMutableState(JsonPrimitive("str"))), + "key4" to listOf(JsonPrimitive("str")), ) } private fun stateSnapshot() = StateSnapshot( mapOf( - "key1" to listOf(Saveable(true, JsonPrimitive(1))), - "key2" to listOf(Saveable(false, JsonPrimitive(1))), - "key3" to listOf(Saveable(true, JsonPrimitive("str"))), - "key4" to listOf(Saveable(false, JsonPrimitive("str"))), + "key1" to listOf(JsonMutableState(JsonPrimitive(1))), + "key2" to listOf(JsonPrimitive(1)), + "key3" to listOf(JsonMutableState(JsonPrimitive("str"))), + "key4" to listOf(JsonPrimitive("str")), ), )