Skip to content

Commit

Permalink
feat: android async backups
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasonvdb committed Nov 29, 2023
1 parent 1c54fb3 commit 92fef96
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 21 deletions.
5 changes: 3 additions & 2 deletions lib/android/src/main/java/com/reactnativeldk/LdkModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class LdkModule(reactContext: ReactApplicationContext) : ReactContextBaseJavaMod
private val channelManagerPersister: LdkChannelManagerPersister by lazy { LdkChannelManagerPersister() }

//Config required to setup below objects
private var chainMonitor: ChainMonitor? = null
private var keysManager: KeysManager? = null
private var channelManager: ChannelManager? = null
private var userConfig: UserConfig? = null
Expand All @@ -173,6 +172,8 @@ class LdkModule(reactContext: ReactApplicationContext) : ReactContextBaseJavaMod
companion object {
lateinit var accountStoragePath: String
lateinit var channelStoragePath: String

var chainMonitor: ChainMonitor? = null
}

init {
Expand Down Expand Up @@ -1023,7 +1024,7 @@ class LdkModule(reactContext: ReactApplicationContext) : ReactContextBaseJavaMod
}

if (remotePersist) {
BackupClient.persist(BackupClient.Label.MISC(fileName), file.readBytes())
BackupClient.addToPersistQueue(BackupClient.Label.MISC(fileName), file.readBytes())
}

handleResolve(promise, LdkCallbackResponses.file_write_success)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package com.reactnativeldk.classes
import com.reactnativeldk.EventTypes
import com.reactnativeldk.LdkErrors
import com.reactnativeldk.LdkEventEmitter
import com.reactnativeldk.handleReject
import com.reactnativeldk.hexEncodedString
import com.reactnativeldk.hexa
import org.json.JSONObject
Expand All @@ -13,24 +11,22 @@ import java.net.URL
import java.security.MessageDigest
import java.security.SecureRandom
import java.util.Random
import java.util.UUID
import java.util.concurrent.locks.ReentrantLock
import javax.crypto.Cipher
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.IvParameterSpec
import javax.crypto.spec.SecretKeySpec

class BackupError : Exception() {
companion object {
val requiresSetup = RequiresSetup()
val missingBackup = MissingBackup()
val invalidServerResponse = InvalidServerResponse(0)
val decryptFailed = DecryptFailed("")
val signingError = SigningError()
val serverChallengeResponseFailed = ServerChallengeResponseFailed()
val checkError = BackupCheckError()
}
}

class InvalidNetwork() : Exception("Invalid network passed to BackupClient setup")
class RequiresSetup() : Exception("BackupClient requires setup")
class MissingBackup() : Exception("Retrieve failed. Missing backup.")
class InvalidServerResponse(code: Int) : Exception("Invalid backup server response ($code)")
Expand All @@ -44,6 +40,15 @@ class CompleteBackup(
val channelFiles: Map<String, ByteArray>
)

typealias BackupCompleteCallback = () -> Unit

class BackupQueueEntry(
val uuid: UUID,
val label: BackupClient.Label,
val bytes: ByteArray,
val callback: BackupCompleteCallback? = null
)

class BackupClient {
sealed class Label(val string: String, channelId: String = "") {
data class PING(val customName: String = "") : Label("ping")
Expand All @@ -53,6 +58,13 @@ class BackupClient {

data class MISC(val customName: String) :
Label(customName.replace(".json", "").replace(".bin", ""))

val queueId: String
get() = when (this) {
is CHANNEL_MONITOR -> "$string-$channelId"
is MISC -> "$string-$customName"
else -> string
}
}

companion object {
Expand All @@ -72,6 +84,9 @@ class BackupClient {
private var version = "v1"
private var signedMessagePrefix = "react-native-ldk backup server auth:"

private var persistQueues: HashMap<String, ArrayList<BackupQueueEntry>> = HashMap()
private var persistQueuesLock: HashMap<String, Boolean> = HashMap()

var skipRemoteBackup = true //Allow dev to opt out (for development), will not throw error when attempting to persist

private var network: String? = null
Expand Down Expand Up @@ -175,7 +190,7 @@ class BackupClient {
}

@Throws(BackupError::class)
fun persist(label: Label, bytes: ByteArray, retry: Int) {
private fun persist(label: Label, bytes: ByteArray, retry: Int) {
var attempts = 0
var persistError: Exception? = null

Expand Down Expand Up @@ -205,7 +220,7 @@ class BackupClient {
}

@Throws(BackupError::class)
fun persist(label: Label, bytes: ByteArray) {
private fun persist(label: Label, bytes: ByteArray) {
if (skipRemoteBackup) {
return
}
Expand Down Expand Up @@ -491,5 +506,69 @@ class BackupClient {
cachedBearer = CachedBearer(bearer, expires)
return bearer
}

//Backup queue management
fun addToPersistQueue(label: BackupClient.Label, bytes: ByteArray, callback: (() -> Unit)? = null) {
if (BackupClient.skipRemoteBackup) {
callback?.invoke()
LdkEventEmitter.send(
EventTypes.native_log,
"Skipping remote backup queue append for ${label.string}"
)
return
}

persistQueues[label.queueId] = persistQueues[label.queueId] ?: ArrayList()
persistQueues[label.queueId]!!.add(BackupQueueEntry(UUID.randomUUID(), label, bytes, callback))

processPersistNextInQueue(label)
}

private val backupQueueLock = ReentrantLock()
private fun processPersistNextInQueue(label: Label) {
//Check if queue is locked, if not lock it and process next in queue
var backupEntry: BackupQueueEntry? = null
backupQueueLock.lock()
try {
if (persistQueuesLock[label.queueId] == true) {
return
}

persistQueuesLock[label.queueId] = true

backupEntry = persistQueues[label.queueId]?.firstOrNull()
if (backupEntry == null) {
persistQueuesLock[label.queueId] = false
return
}
} finally {
backupQueueLock.unlock()
}

Thread {
try {
persist(backupEntry!!.label, backupEntry.bytes, 10)
backupEntry.callback?.invoke()
} catch (e: Exception) {
LdkEventEmitter.send(
EventTypes.native_log,
"Remote persist failed for ${label.string} with error ${e.message}"
)
} finally {
backupQueueLock.lock()
try {
persistQueues[label.queueId]?.remove(backupEntry)
persistQueuesLock[label.queueId] = false
} finally {
backupQueueLock.unlock()
}

processPersistNextInQueue(label)
}
}.start()
}
}
}



Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,13 @@ class LdkChannelManagerPersister: ChannelManagerConstructor.EventHandler {

override fun persist_manager(channel_manager_bytes: ByteArray?) {
if (channel_manager_bytes != null && LdkModule.accountStoragePath != "") {
BackupClient.persist(BackupClient.Label.CHANNEL_MANAGER(), channel_manager_bytes, retry = 100)
BackupClient.addToPersistQueue(BackupClient.Label.CHANNEL_MANAGER(), channel_manager_bytes) {
LdkEventEmitter.send(EventTypes.native_log, "Remote persisted channel manager to disk")
}

File(LdkModule.accountStoragePath + "/" + LdkFileNames.channel_manager.fileName).writeBytes(channel_manager_bytes)

LdkEventEmitter.send(EventTypes.native_log, "Persisted channel manager to disk")
LdkEventEmitter.send(EventTypes.native_log, "Locally persisted channel manager to disk")
LdkEventEmitter.send(EventTypes.backup, "")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import org.ldk.structs.Persist.PersistInterface
import java.io.File

class LdkPersister {
fun handleChannel(id: OutPoint, data: ChannelMonitor): ChannelMonitorUpdateStatus {
fun handleChannel(id: OutPoint, data: ChannelMonitor, update_id: MonitorUpdateId): ChannelMonitorUpdateStatus {
val body = Arguments.createMap()
body.putHexString("channel_id", id.to_channel_id())
body.putHexString("counterparty_node_id", data._counterparty_node_id)
Expand All @@ -24,29 +24,36 @@ class LdkPersister {

val isNew = !file.exists()

BackupClient.persist(BackupClient.Label.CHANNEL_MONITOR(channelId=channelId), data.write(), retry = 100)
file.writeBytes(data.write())

LdkEventEmitter.send(EventTypes.native_log, "Persisted channel (${id.to_channel_id().hexEncodedString()}) to disk")
LdkEventEmitter.send(EventTypes.backup, "")
BackupClient.addToPersistQueue(BackupClient.Label.CHANNEL_MONITOR(channelId=channelId), data.write()) {
file.writeBytes(data.write())

//Update chainmonitor with successful persist
val res = LdkModule.chainMonitor?.channel_monitor_updated(id, update_id)
if (res == null || !res.is_ok) {
LdkEventEmitter.send(EventTypes.native_log, "Failed to update chain monitor with persisted channel (${id.to_channel_id().hexEncodedString()})")
} else {
LdkEventEmitter.send(EventTypes.native_log, "Persisted channel (${id.to_channel_id().hexEncodedString()}) to disk")
LdkEventEmitter.send(EventTypes.backup, "")
}
}

if (isNew) {
LdkEventEmitter.send(EventTypes.new_channel, body)
}

return ChannelMonitorUpdateStatus.LDKChannelMonitorUpdateStatus_Completed
return ChannelMonitorUpdateStatus.LDKChannelMonitorUpdateStatus_InProgress
} catch (e: Exception) {
return ChannelMonitorUpdateStatus.LDKChannelMonitorUpdateStatus_UnrecoverableError
}
}

var persister = Persist.new_impl(object : PersistInterface {
override fun persist_new_channel(id: OutPoint, data: ChannelMonitor, update_id: MonitorUpdateId): ChannelMonitorUpdateStatus {
return handleChannel(id, data)
return handleChannel(id, data, update_id)
}

override fun update_persisted_channel(id: OutPoint, update: ChannelMonitorUpdate?, data: ChannelMonitor, update_id: MonitorUpdateId): ChannelMonitorUpdateStatus {
return handleChannel(id, data)
return handleChannel(id, data, update_id)
}
})
}

0 comments on commit 92fef96

Please sign in to comment.