diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt index e413667d9..0fb2a01da 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/SnapshotManager.kt @@ -12,12 +12,18 @@ import com.stevesoltys.seedvault.header.VERSION import com.stevesoltys.seedvault.proto.Snapshot import com.stevesoltys.seedvault.transport.restore.Loader import io.github.oshai.kotlinlogging.KotlinLogging -import okio.Buffer -import okio.buffer -import okio.sink import org.calyxos.seedvault.core.backends.AppBackupFileType +import org.calyxos.seedvault.core.toHexString +import java.io.ByteArrayOutputStream +import java.io.File +import java.io.IOException +/** + * Manages interactions with snapshots, such as loading, saving and removing them. + * Also keeps a reference to the [latestSnapshot] that holds important re-usable data. + */ internal class SnapshotManager( + private val snapshotFolder: File, private val crypto: Crypto, private val loader: Loader, private val backendManager: BackendManager, @@ -32,37 +38,73 @@ internal class SnapshotManager( var latestSnapshot: Snapshot? = null private set + /** + * Call this before starting a backup run with the [handles] of snapshots + * currently available on the backend. + */ suspend fun onSnapshotsLoaded(handles: List): List { return handles.map { snapshotHandle -> // TODO set up local snapshot cache, so we don't need to download those all the time // TODO is it a fatal error when one snapshot is corrupted or couldn't get loaded? - val snapshot = loader.loadFile(snapshotHandle).use { inputStream -> - Snapshot.parseFrom(inputStream) - } + val snapshot = loadSnapshot(snapshotHandle) // update latest snapshot if this one is more recent if (snapshot.token > (latestSnapshot?.token ?: 0)) latestSnapshot = snapshot snapshot } } + /** + * Saves the given [snapshot] to the backend and local cache. + * + * @throws IOException or others if saving fails. + */ + @Throws(IOException::class) suspend fun saveSnapshot(snapshot: Snapshot) { - val buffer = Buffer() - val bufferStream = buffer.outputStream() - bufferStream.write(VERSION.toInt()) - crypto.newEncryptingStream(bufferStream, crypto.getAdForVersion()).use { cryptoStream -> + val byteStream = ByteArrayOutputStream() + byteStream.write(VERSION.toInt()) + crypto.newEncryptingStream(byteStream, crypto.getAdForVersion()).use { cryptoStream -> ZstdOutputStream(cryptoStream).use { zstdOutputStream -> snapshot.writeTo(zstdOutputStream) } } - val sha256ByteString = buffer.sha256() - val handle = AppBackupFileType.Snapshot(crypto.repoId, sha256ByteString.hex()) - // TODO exception handling - backendManager.backend.save(handle).use { outputStream -> - outputStream.sink().buffer().apply { - writeAll(buffer) - flush() // needs flushing + val bytes = byteStream.toByteArray() + val sha256 = crypto.sha256(bytes).toHexString() + val snapshotHandle = AppBackupFileType.Snapshot(crypto.repoId, sha256) + backendManager.backend.save(snapshotHandle).use { outputStream -> + outputStream.write(bytes) + } + // save to local cache while at it + try { + File(snapshotFolder, snapshotHandle.name).outputStream().use { outputStream -> + outputStream.write(bytes) } + } catch (e: Exception) { // we'll let this one pass + log.error(e) { "Error saving snapshot ${snapshotHandle.hash} to cache" } + } + } + + /** + * Removes the snapshot referenced by the given [snapshotHandle] from the backend + * and local cache. + */ + suspend fun removeSnapshot(snapshotHandle: AppBackupFileType.Snapshot) { + backendManager.backend.remove(snapshotHandle) + // remove from cache as well + File(snapshotFolder, snapshotHandle.name).delete() + } + + /** + * Loads and parses the snapshot referenced by the given [snapshotHandle]. + * If a locally cached version exists, the backend will not be hit. + */ + private suspend fun loadSnapshot(snapshotHandle: AppBackupFileType.Snapshot): Snapshot { + val file = File(snapshotFolder, snapshotHandle.name) + val inputStream = if (file.isFile) { + loader.loadFile(file, snapshotHandle.hash) + } else { + loader.loadFile(snapshotHandle, file) } + return inputStream.use { Snapshot.parseFrom(it) } } } diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt index 3cba267a3..455a12bd0 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/AppBackupManager.kt @@ -49,26 +49,31 @@ internal class AppBackupManager( blobCache.populateCache(blobInfos, snapshots) } - suspend fun afterBackupFinished(success: Boolean) { + suspend fun afterBackupFinished(success: Boolean): Boolean { log.info { "After backup finished. Success: $success" } MemoryLogger.log() // free up memory by clearing blobs cache blobCache.clear() + var result = false try { if (success) { val snapshot = snapshotCreator?.finalizeSnapshot() ?: error("Had no snapshotCreator") - keepTrying { + keepTrying { // saving this is so important, we even keep trying snapshotManager.saveSnapshot(snapshot) } settingsManager.token = snapshot.token // after snapshot was written, we can clear local cache as its info is in snapshot blobCache.clearLocalCache() } + result = true + } catch (e: Exception) { + log.error(e) { "Error finishing backup" } } finally { snapshotCreator = null } MemoryLogger.log() + return result } private suspend fun keepTrying(n: Int = 3, block: suspend () -> Unit) { diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt index 4bb151236..cbde0a221 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/backup/BackupModule.kt @@ -8,13 +8,17 @@ package com.stevesoltys.seedvault.transport.backup import com.stevesoltys.seedvault.transport.SnapshotManager import org.koin.android.ext.koin.androidContext import org.koin.dsl.module +import java.io.File val backupModule = module { single { BackupInitializer(get()) } single { BackupReceiver(get(), get(), get()) } single { BlobCache(androidContext()) } single { BlobCreator(get(), get()) } - single { SnapshotManager(get(), get(), get()) } + single { + val snapshotFolder = File(androidContext().filesDir, "snapshots") + SnapshotManager(snapshotFolder, get(), get(), get()) + } single { SnapshotCreatorFactory(androidContext(), get(), get(), get()) } single { InputFactory() } single { diff --git a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt index 66a3246ed..9046b5127 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/transport/restore/Loader.kt @@ -5,15 +5,18 @@ package com.stevesoltys.seedvault.transport.restore +import com.android.internal.R.attr.handle import com.github.luben.zstd.ZstdInputStream import com.stevesoltys.seedvault.backend.BackendManager import com.stevesoltys.seedvault.crypto.Crypto import com.stevesoltys.seedvault.header.UnsupportedVersionException import com.stevesoltys.seedvault.header.VERSION +import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.runBlocking import org.calyxos.seedvault.core.backends.AppBackupFileType import org.calyxos.seedvault.core.toHexString import java.io.ByteArrayInputStream +import java.io.File import java.io.InputStream import java.io.SequenceInputStream import java.security.GeneralSecurityException @@ -24,21 +27,57 @@ internal class Loader( private val backendManager: BackendManager, ) { + private val log = KotlinLogging.logger {} + + /** + * Downloads the given [fileHandle], decrypts and decompresses its content + * and returns the content as a decrypted and decompressed stream. + * + * Attention: The responsibility with closing the returned stream lies with the caller. + * + * @param cacheFile if non-null, the ciphertext of the loaded file will be cached there + * for later loading with [loadFile]. + */ + suspend fun loadFile(fileHandle: AppBackupFileType, cacheFile: File? = null): InputStream { + val expectedHash = when (fileHandle) { + is AppBackupFileType.Snapshot -> fileHandle.hash + is AppBackupFileType.Blob -> fileHandle.name + } + return loadFromStream(backendManager.backend.load(fileHandle), expectedHash, cacheFile) + } + /** * The responsibility with closing the returned stream lies with the caller. */ - suspend fun loadFile(handle: AppBackupFileType): InputStream { + fun loadFile(file: File, expectedHash: String): InputStream { + return loadFromStream(file.inputStream(), expectedHash) + } + + suspend fun loadFiles(handles: List): InputStream { + val enumeration: Enumeration = object : Enumeration { + val iterator = handles.iterator() + + override fun hasMoreElements(): Boolean { + return iterator.hasNext() + } + + override fun nextElement(): InputStream { + return runBlocking { loadFile(iterator.next()) } + } + } + return SequenceInputStream(enumeration) + } + + private fun loadFromStream( + inputStream: InputStream, + expectedHash: String, + cacheFile: File? = null, + ): InputStream { // We load the entire ciphertext into memory, // so we can check the SHA-256 hash before decrypting and parsing the data. - val cipherText = backendManager.backend.load(handle).use { inputStream -> - inputStream.readAllBytes() - } + val cipherText = inputStream.use { it.readAllBytes() } // check SHA-256 hash first thing val sha256 = crypto.sha256(cipherText).toHexString() - val expectedHash = when (handle) { - is AppBackupFileType.Snapshot -> handle.hash - is AppBackupFileType.Blob -> handle.name - } if (sha256 != expectedHash) { throw GeneralSecurityException("File had wrong SHA-256 hash: $handle") } @@ -46,26 +85,20 @@ internal class Loader( val version = cipherText[0] if (version <= 1) throw GeneralSecurityException("Unexpected version: $version") if (version > VERSION) throw UnsupportedVersionException(version) + // cache ciperText in cacheFile, if existing + try { + cacheFile?.outputStream()?.use { outputStream -> + outputStream.write(cipherText) + } + } catch (e: Exception) { + log.error(e) { "Error writing cache file $cacheFile" } + } // get associated data for version, used for authenticated decryption val ad = crypto.getAdForVersion(version) // skip first version byte when creating cipherText stream - val inputStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1) + val byteStream = ByteArrayInputStream(cipherText, 1, cipherText.size - 1) // decrypt and decompress cipherText stream and parse snapshot - return ZstdInputStream(crypto.newDecryptingStream(inputStream, ad)) + return ZstdInputStream(crypto.newDecryptingStream(byteStream, ad)) } - suspend fun loadFiles(handles: List): InputStream { - val enumeration: Enumeration = object : Enumeration { - val iterator = handles.iterator() - - override fun hasMoreElements(): Boolean { - return iterator.hasNext() - } - - override fun nextElement(): InputStream { - return runBlocking { loadFile(iterator.next()) } - } - } - return SequenceInputStream(enumeration) - } } diff --git a/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt b/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt index 0b8dc5998..f26485319 100644 --- a/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt +++ b/app/src/main/java/com/stevesoltys/seedvault/ui/notification/NotificationBackupObserver.kt @@ -11,6 +11,7 @@ import android.app.backup.IBackupObserver import android.content.Context import android.content.pm.ApplicationInfo.FLAG_SYSTEM import android.content.pm.PackageManager.NameNotFoundException +import android.os.Looper import android.util.Log import android.util.Log.INFO import android.util.Log.isLoggable @@ -136,7 +137,7 @@ internal class NotificationBackupObserver( if (isLoggable(TAG, INFO)) { Log.i(TAG, "Backup finished $numPackages/$requestedPackages. Status: $status") } - val success = status == 0 + var success = status == 0 val size = if (success) metadataManager.getPackagesBackupSize() else 0L val total = try { packageService.allUserPackages.size @@ -144,11 +145,10 @@ internal class NotificationBackupObserver( Log.e(TAG, "Error getting number of all user packages: ", e) requestedPackages } - // TODO handle exceptions runBlocking { - // TODO check if UI thread - Log.d("TAG", "Finalizing backup...") - appBackupManager.afterBackupFinished(success) + check(!Looper.getMainLooper().isCurrentThread) + Log.d(TAG, "Finalizing backup...") + success = appBackupManager.afterBackupFinished(success) } nm.onBackupFinished(success, numPackagesToReport, total, size) } diff --git a/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt b/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt index a4f4d54ed..45f02b49b 100644 --- a/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt +++ b/app/src/test/java/com/stevesoltys/seedvault/transport/SnapshotManagerTest.kt @@ -19,10 +19,13 @@ import org.calyxos.seedvault.core.toHexString import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream +import java.io.File import java.io.InputStream import java.io.OutputStream +import java.nio.file.Path import java.security.MessageDigest import kotlin.random.Random @@ -32,15 +35,28 @@ internal class SnapshotManagerTest : TransportTest() { private val backend: Backend = mockk() private val loader = Loader(crypto, backendManager) // need a real loader - private val snapshotManager = SnapshotManager(crypto, loader, backendManager) private val ad = Random.nextBytes(1) private val passThroughOutputStream = slot() private val passThroughInputStream = slot() private val snapshotHandle = slot() + // @Test + // fun `test onSnapshotsLoaded sets latestSnapshot`(@TempDir tmpDir: Path) = runBlocking { + // val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + // + // val snapshotHandle1 = AppBackupFileType.Snapshot(repoId, chunkId1) + // val snapshotHandle2 = AppBackupFileType.Snapshot(repoId, chunkId2) + // snapshotManager.onSnapshotsLoaded(listOf(snapshotHandle1, snapshotHandle2)) + // Unit + // } + @Test - fun `test saving and loading`() = runBlocking { + fun `test saving and loading`(@TempDir tmpDir: Path) = runBlocking { + val snapshotManager = getSnapshotManager(File(tmpDir.toString())) + + val messageDigest = MessageDigest.getInstance("SHA-256") + val bytes = slot() val outputStream = ByteArrayOutputStream() every { crypto.getAdForVersion() } returns ad @@ -49,12 +65,16 @@ internal class SnapshotManagerTest : TransportTest() { } every { crypto.repoId } returns repoId every { backendManager.backend } returns backend + every { crypto.sha256(capture(bytes)) } answers { + messageDigest.digest(bytes.captured) + } coEvery { backend.save(capture(snapshotHandle)) } returns outputStream snapshotManager.saveSnapshot(snapshot) + println(snapshotHandle.captured) + // check that file content hash matches snapshot hash - val messageDigest = MessageDigest.getInstance("SHA-256") assertEquals( messageDigest.digest(outputStream.toByteArray()).toHexString(), snapshotHandle.captured.hash, @@ -75,4 +95,8 @@ internal class SnapshotManagerTest : TransportTest() { assertEquals(snapshot, snapshots[0]) } } + + private fun getSnapshotManager(tmpFolder: File): SnapshotManager { + return SnapshotManager(tmpFolder, crypto, loader, backendManager) + } }