Skip to content

Commit

Permalink
Allow arbitrarily serializable content within rememberSaveable
Browse files Browse the repository at this point in the history
  • Loading branch information
veyndan committed Oct 17, 2023
1 parent 9be0922 commit 8ffa644
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private class RedwoodZiplineTreehouseUi(
this.composition = composition

this.saveableStateRegistry = SaveableStateRegistry(
restoredValues = stateSnapshot?.toValuesMap(),
restoredValues = stateSnapshot?.content,
// Note: values will only be restored by SaveableStateRegistry if `canBeSaved` returns true.
// With current serialization mechanism of stateSnapshot, this field is always true, an update
// to lambda of this field might be needed when serialization mechanism of stateSnapshot
Expand All @@ -111,7 +111,7 @@ private class RedwoodZiplineTreehouseUi(

override fun snapshotState(): StateSnapshot {
val savedState = saveableStateRegistry.performSave()
return savedState.toStateSnapshot()
return StateSnapshot(savedState)
}

override fun close() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,75 +18,69 @@ package app.cash.redwood.treehouse
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import kotlin.jvm.JvmInline
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Polymorphic
import kotlinx.serialization.PolymorphicSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.booleanOrNull
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.doubleOrNull
import kotlinx.serialization.json.intOrNull

private const val MutableStateKey = "androidx.compose.runtime.MutableState"
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.modules.subclass

@Serializable
public class StateSnapshot(
public val content: Map<String, List<JsonElement>>,
public val content: Map<String, List<@Polymorphic Any?>>,
) {
public fun toValuesMap(): Map<String, List<Any?>> {
return content.mapValues { entry ->
entry.value.map {
it.fromJsonElement()
}
}
}

@JvmInline
@Serializable
public value class Id(public val value: String?)
}

/**
* Supported types:
* String, Boolean, Int, List (of supported primitive types), Map (of supported primitive types)
*/
public fun Map<String, List<Any?>>.toStateSnapshot(): StateSnapshot = StateSnapshot(
mapValues { entry -> entry.value.map { element -> element.toJsonElement() } },
)

private fun Any?.toJsonElement(): JsonElement {
return when (this) {
is MutableState<*> -> JsonMutableState(value.toJsonElement())
is String -> JsonPrimitive(this)
is Int -> JsonPrimitive(this)
is List<*> -> JsonArray(map { it.toJsonElement() })
is JsonElement -> this
null -> JsonNull
else -> error("unexpected type: $this")
// TODO: add support to Map<*, *>
// TODO Add support for rest of built-ins serializers.
public val SaveableStateSerializersModule: SerializersModule = SerializersModule {
polymorphic(Any::class) {
subclass(Boolean::class)
subclass(Double::class)
subclass(Float::class)
subclass(Int::class)
subclass(String::class)
}
polymorphicDefaultSerializer(Any::class) { value ->
@Suppress("UNCHECKED_CAST")
when (value) {
is List<*> -> ListSerializer(PolymorphicSerializer(Any::class)) as SerializationStrategy<Any>
is MutableState<*> -> MutableStateSerializer as SerializationStrategy<Any>
else -> null
}
}
polymorphicDefaultDeserializer(Any::class) { className ->
when (className) {
"kotlin.collections.ArrayList" -> ListSerializer(PolymorphicSerializer(Any::class))
"MutableState" -> MutableStateSerializer
else -> null
}
}
}

private fun JsonElement.fromJsonElement(): Any? {
return when {
this is JsonNull -> null
this is JsonPrimitive -> {
if (this.isString) return content
return booleanOrNull ?: doubleOrNull ?: intOrNull ?: error("unexpected type: $this")
// TODO add other primitive types (float, long) when needed
}
@Serializable
@SerialName("MutableState")
private class MutableStateSurrogate(val value: @Polymorphic Any?)

this is JsonArray -> listOf({ this.forEach { it.toJsonElement() } })
this is JsonObject && containsKey(MutableStateKey) ->
mutableStateOf(getValue(MutableStateKey).fromJsonElement())
// TODO: map, numbers
// is Map<*, *> -> JsonElement
else -> error("unexpected type: $this")
private object MutableStateSerializer : KSerializer<MutableState<Any?>> {
override val descriptor = MutableStateSurrogate.serializer().descriptor

override fun serialize(encoder: Encoder, value: MutableState<Any?>) {
val surrogate = MutableStateSurrogate(value.value)
encoder.encodeSerializableValue(MutableStateSurrogate.serializer(), surrogate)
}
}

internal fun JsonMutableState(element: JsonElement): JsonObject = buildJsonObject {
put(MutableStateKey, element)
override fun deserialize(decoder: Decoder): MutableState<Any?> {
val surrogate = decoder.decodeSerializableValue(MutableStateSurrogate.serializer())
return mutableStateOf(surrogate.value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,66 +18,108 @@ package app.cash.redwood.treehouse
import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import assertk.assertThat
import assertk.assertions.containsOnly
import assertk.assertions.corresponds
import assertk.assertions.hasSize
import assertk.assertions.isEqualTo
import assertk.assertions.isInstanceOf
import assertk.assertions.isNotNull
import kotlin.test.Test
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

class StateSnapshotTest {

@Test
fun toValueMapWorksAsExpected() {
val stateSnapshot = stateSnapshot()
val valuesMap = stateSnapshot.toValuesMap()
assertThat(valuesMap).hasSize(5)
assertThat(valuesMap["key1"]!![0])
fun stateSnapshotSerializeThenDeserialize() {
val json = Json {
prettyPrint = true
useArrayPolymorphism = true
serializersModule = SaveableStateSerializersModule
}
val stateSnapshot = StateSnapshot(
mapOf(
"key1" to listOf(mutableStateOf(1)),
"key2" to listOf(1),
"key3" to listOf(mutableStateOf("str")),
"key4" to listOf("str"),
"key5" to listOf(null),
"key6" to listOf(true),
"key7" to listOf(1.0),
"key8" to listOf(1.0f),
"key9" to listOf(listOf(1, "str")),
),
)
val serialized = json.encodeToString(stateSnapshot)
val deserialized = json.decodeFromString<StateSnapshot>(serialized)

assertThat(serialized).isEqualTo("""
{
"content": {
"key1": [
["MutableState", {
"value": ["kotlin.Int", 1
]
}
]
],
"key2": [
["kotlin.Int", 1
]
],
"key3": [
["MutableState", {
"value": ["kotlin.String", "str"
]
}
]
],
"key4": [
["kotlin.String", "str"
]
],
"key5": [
null
],
"key6": [
["kotlin.Boolean", true
]
],
"key7": [
["kotlin.Double", 1.0
]
],
"key8": [
["kotlin.Float", 1.0
]
],
"key9": [
["kotlin.collections.ArrayList", [
["kotlin.Int", 1
],
["kotlin.String", "str"
]
]
]
]
}
}
""".trimIndent())
assertThat(deserialized.content).hasSize(stateSnapshot.content.size)
assertThat(deserialized.content["key1"]!![0])
.isNotNull()
.isInstanceOf<MutableState<*>>()
.corresponds(mutableStateOf(1.0), ::mutableStateCorrespondence)
assertThat(valuesMap["key2"]).isEqualTo(listOf(1.0))
assertThat(valuesMap["key3"]!![0])
.corresponds(mutableStateOf(1), ::mutableStateCorrespondence)
assertThat(deserialized.content["key2"]).isEqualTo(listOf(1))
assertThat(deserialized.content["key3"]!![0])
.isNotNull()
.isInstanceOf<MutableState<*>>()
.corresponds(mutableStateOf("str"), ::mutableStateCorrespondence)
assertThat(valuesMap["key4"]).isEqualTo(listOf("str"))
assertThat(valuesMap["key5"]).isEqualTo(listOf(null))
assertThat(deserialized.content["key4"]).isEqualTo(listOf("str"))
assertThat(deserialized.content["key5"]).isEqualTo(listOf(null))
assertThat(deserialized.content["key6"]).isEqualTo(listOf(true))
assertThat(deserialized.content["key7"]).isEqualTo(listOf(1.0))
assertThat(deserialized.content["key8"]).isEqualTo(listOf(1.0f))
assertThat(deserialized.content["key9"]).isEqualTo(listOf(listOf(1, "str")))
}

@Test
fun toStateSnapshotWorksAsExpected() {
val storedStateSnapshot = storedStateSnapshot()
val stateSnapshot = storedStateSnapshot.toStateSnapshot()
assertThat(stateSnapshot.content).containsOnly(
"key1" to listOf(JsonMutableState(JsonPrimitive(1))),
"key2" to listOf(JsonPrimitive(1)),
"key3" to listOf(JsonMutableState(JsonPrimitive("str"))),
"key4" to listOf(JsonPrimitive("str")),
"key5" to listOf(JsonNull),
)
}

private fun stateSnapshot() = StateSnapshot(
mapOf(
"key1" to listOf(JsonMutableState(JsonPrimitive(1))),
"key2" to listOf(JsonPrimitive(1)),
"key3" to listOf(JsonMutableState(JsonPrimitive("str"))),
"key4" to listOf(JsonPrimitive("str")),
"key5" to listOf(JsonNull),
),
)

private fun storedStateSnapshot() = mapOf(
"key1" to listOf(mutableStateOf(1)),
"key2" to listOf(1),
"key3" to listOf(mutableStateOf("str")),
"key4" to listOf("str"),
"key5" to listOf(null),
)
}

private fun mutableStateCorrespondence(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import app.cash.zipline.loader.asZiplineHttpClient
import app.cash.zipline.loader.withDevelopmentServerPush
import com.example.redwood.emojisearch.launcher.EmojiSearchAppSpec
import com.example.redwood.emojisearch.treehouse.EmojiSearchPresenter
import com.example.redwood.emojisearch.treehouse.emojiSearchSerializersModule
import com.example.redwood.emojisearch.widget.EmojiSearchProtocolNodeFactory
import com.example.redwood.emojisearch.widget.EmojiSearchWidgetFactories
import com.google.android.material.snackbar.Snackbar
Expand Down Expand Up @@ -124,7 +125,10 @@ class EmojiSearchActivity : ComponentActivity() {
embeddedDir = "/".toPath(),
embeddedFileSystem = applicationContext.assets.asFileSystem(),
stateStore = FileStateStore(
json = Json,
json = Json {
useArrayPolymorphism = true
serializersModule = emojiSearchSerializersModule
},
fileSystem = FileSystem.SYSTEM,
directory = applicationContext.getDir("TreehouseState", MODE_PRIVATE).toOkioPath(),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ import app.cash.redwood.treehouse.TreehouseApp
import app.cash.zipline.Zipline
import com.example.redwood.emojisearch.treehouse.EmojiSearchPresenter
import com.example.redwood.emojisearch.treehouse.HostApi
import com.example.redwood.emojisearch.treehouse.emojiSearchSerializersModule
import kotlinx.coroutines.flow.Flow

class EmojiSearchAppSpec(
override val manifestUrl: Flow<String>,
private val hostApi: HostApi,
) : TreehouseApp.Spec<EmojiSearchPresenter>() {
override val name = "emoji-search"
override val serializersModule = emojiSearchSerializersModule

override fun bindServices(zipline: Zipline) {
zipline.bind<HostApi>("HostApi", hostApi)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright (C) 2023 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.example.redwood.emojisearch.treehouse

import app.cash.redwood.treehouse.SaveableStateSerializersModule

val emojiSearchSerializersModule = SaveableStateSerializersModule
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package com.example.redwood.emojisearch.treehouse

import app.cash.zipline.Zipline

private val zipline by lazy { Zipline.get() }
private val zipline by lazy { Zipline.get(emojiSearchSerializersModule) }

@OptIn(ExperimentalJsExport::class)
@JsExport
Expand Down

0 comments on commit 8ffa644

Please sign in to comment.