From dc48003b762bc997b991d0fe6a33bcd08e133005 Mon Sep 17 00:00:00 2001 From: Thibault Martinez Date: Wed, 15 Mar 2023 20:45:39 +0100 Subject: [PATCH] Consistent storage set and get (#1958) * Consistent storage set and get * Changelog * Remove an explicit type parameter * Remove parse_accounts * Changelog * Fix compilation --- wallet/CHANGELOG.md | 1 + wallet/src/storage/manager.rs | 58 +++++++++-------------------- wallet/src/storage/mod.rs | 14 +++---- wallet/src/storage/participation.rs | 34 ++++++++--------- wallet/tests/backup_restore.rs | 2 +- 5 files changed, 41 insertions(+), 68 deletions(-) diff --git a/wallet/CHANGELOG.md b/wallet/CHANGELOG.md index da9bb98f5b..f22b41aee0 100644 --- a/wallet/CHANGELOG.md +++ b/wallet/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Resync outputs if a transaction got confirmed between syncing outputs and pending transactions to prevent not having unspent outputs afterwards; - Cache participations for spent outputs; +- Make `{Storage, StorageManager}::get` generic over a `T: Deserialize` and return a `T`, avoiding always having to deserialize after; ### Fixed diff --git a/wallet/src/storage/manager.rs b/wallet/src/storage/manager.rs index 751462636f..7a8d0fb2f7 100644 --- a/wallet/src/storage/manager.rs +++ b/wallet/src/storage/manager.rs @@ -1,9 +1,8 @@ // Copyright 2021 IOTA Stiftung // SPDX-License-Identifier: Apache-2.0 -use std::{str::FromStr, sync::Arc}; +use std::sync::Arc; -use crypto::ciphers::chacha; use iota_client::secret::{SecretManager, SecretManagerDto}; use serde::{Deserialize, Serialize}; use tokio::sync::{Mutex, RwLock}; @@ -50,10 +49,7 @@ pub(crate) async fn new_storage_manager( encryption_key, }; // Get the db version or set it - let db_schema_version = storage.get(DATABASE_SCHEMA_VERSION_KEY).await?; - if let Some(db_schema_version) = db_schema_version { - let db_schema_version = u8::from_str(&db_schema_version) - .map_err(|_| crate::Error::Storage("invalid db_schema_version".to_string()))?; + if let Some(db_schema_version) = storage.get::(DATABASE_SCHEMA_VERSION_KEY).await? { if db_schema_version != DATABASE_SCHEMA_VERSION { return Err(crate::Error::Storage(format!( "unsupported database schema version {db_schema_version}" @@ -65,10 +61,8 @@ pub(crate) async fn new_storage_manager( .await?; }; - let account_indexes = match storage.get(ACCOUNTS_INDEXATION_KEY).await? { - Some(account_indexes) => serde_json::from_str(&account_indexes)?, - None => Vec::new(), - }; + let account_indexes = storage.get(ACCOUNTS_INDEXATION_KEY).await?.unwrap_or_default(); + let storage_manager = StorageManager { storage, account_indexes, @@ -95,7 +89,7 @@ impl StorageManager { self.storage.encryption_key.is_some() } - pub async fn get(&self, key: &str) -> crate::Result> { + pub async fn get Deserialize<'de>>(&self, key: &str) -> crate::Result> { self.storage.get(key).await } @@ -125,13 +119,16 @@ impl StorageManager { pub async fn get_account_manager_data(&self) -> crate::Result> { log::debug!("get_account_manager_data"); - if let Some(data) = self.storage.get(ACCOUNT_MANAGER_INDEXATION_KEY).await? { - log::debug!("get_account_manager_data {data:?}"); - let mut builder: AccountManagerBuilder = serde_json::from_str(&data)?; + if let Some(mut builder) = self + .storage + .get::(ACCOUNT_MANAGER_INDEXATION_KEY) + .await? + { + log::debug!("get_account_manager_data {builder:?}"); + + if let Some(secret_manager_dto) = self.storage.get::(SECRET_MANAGER_KEY).await? { + log::debug!("get_secret_manager {secret_manager_dto:?}"); - if let Some(data) = self.storage.get(SECRET_MANAGER_KEY).await? { - log::debug!("get_secret_manager {data}"); - let secret_manager_dto: SecretManagerDto = serde_json::from_str(&data)?; // Only secret_managers that aren't SecretManagerDto::Mnemonic can be restored, because there the Seed // can't be serialized, so we can't create the SecretManager again match secret_manager_dto { @@ -149,9 +146,9 @@ impl StorageManager { } pub async fn get_accounts(&mut self) -> crate::Result> { - if let Some(record) = self.storage.get(ACCOUNTS_INDEXATION_KEY).await? { + if let Some(account_indexes) = self.storage.get(ACCOUNTS_INDEXATION_KEY).await? { if self.account_indexes.is_empty() { - self.account_indexes = serde_json::from_str(&record)?; + self.account_indexes = account_indexes; } } else { return Ok(Vec::new()); @@ -168,7 +165,7 @@ impl StorageManager { ); } - parse_accounts(&accounts, &self.storage.encryption_key) + Ok(accounts) } pub async fn save_account(&mut self, account: &Account) -> crate::Result<()> { @@ -195,24 +192,3 @@ impl StorageManager { .await } } - -// Parse accounts from strings and decrypt them first if necessary -fn parse_accounts(accounts: &[String], encryption_key: &Option<[u8; 32]>) -> crate::Result> { - let mut parsed_accounts: Vec = Vec::new(); - for account in accounts { - let account_json = if account.starts_with('{') { - Some(account.to_string()) - } else if let Some(key) = encryption_key { - Some(String::from_utf8_lossy(&chacha::aead_decrypt(key, account.as_bytes())?).into_owned()) - } else { - None - }; - if let Some(json) = account_json { - let acc = serde_json::from_str::(&json)?; - parsed_accounts.push(acc); - } else { - return Err(crate::Error::StorageIsEncrypted); - } - } - Ok(parsed_accounts) -} diff --git a/wallet/src/storage/mod.rs b/wallet/src/storage/mod.rs index 253c92ba6e..685b3a5ed7 100644 --- a/wallet/src/storage/mod.rs +++ b/wallet/src/storage/mod.rs @@ -15,7 +15,7 @@ mod participation; use std::collections::HashMap; use crypto::ciphers::chacha; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use self::adapter::StorageAdapter; @@ -30,19 +30,19 @@ impl Storage { self.inner.id() } - async fn get(&self, key: &str) -> crate::Result> { + async fn get Deserialize<'de>>(&self, key: &str) -> crate::Result> { match self.inner.get(key).await? { Some(record) => { if let Some(key) = &self.encryption_key { if serde_json::from_str::>(&record).is_ok() { - Ok(Some( - String::from_utf8_lossy(&chacha::aead_decrypt(key, record.as_bytes())?).into_owned(), - )) + Ok(Some(serde_json::from_str(&String::from_utf8_lossy( + &chacha::aead_decrypt(key, record.as_bytes())?, + ))?)) } else { - Ok(Some(record)) + Ok(Some(serde_json::from_str(&record)?)) } } else { - Ok(Some(record)) + Ok(Some(serde_json::from_str(&record)?)) } } None => Ok(None), diff --git a/wallet/src/storage/participation.rs b/wallet/src/storage/participation.rs index 3cb5497e38..fdfd49dd41 100644 --- a/wallet/src/storage/participation.rs +++ b/wallet/src/storage/participation.rs @@ -22,14 +22,13 @@ impl StorageManager { ) -> crate::Result<()> { log::debug!("insert_participation_event {}", event_with_nodes.id); - let mut events: HashMap = match self + let mut events = self .storage - .get(&format!("{PARTICIPATION_EVENTS}{account_index}")) + .get::>(&format!( + "{PARTICIPATION_EVENTS}{account_index}" + )) .await? - { - Some(events) => serde_json::from_str(&events)?, - None => HashMap::new(), - }; + .unwrap_or_default(); events.insert(event_with_nodes.id, event_with_nodes); @@ -47,12 +46,14 @@ impl StorageManager { ) -> crate::Result<()> { log::debug!("remove_participation_event {id}"); - let mut events: HashMap = match self + let mut events = match self .storage - .get(&format!("{PARTICIPATION_EVENTS}{account_index}")) + .get::>(&format!( + "{PARTICIPATION_EVENTS}{account_index}" + )) .await? { - Some(events) => serde_json::from_str(&events)?, + Some(events) => events, None => return Ok(()), }; @@ -71,14 +72,11 @@ impl StorageManager { ) -> crate::Result> { log::debug!("get_participation_events"); - match self + Ok(self .storage .get(&format!("{PARTICIPATION_EVENTS}{account_index}")) .await? - { - Some(events) => Ok(serde_json::from_str(&events)?), - None => Ok(HashMap::new()), - } + .unwrap_or_default()) } pub(crate) async fn set_cached_participation_output_status( @@ -94,6 +92,7 @@ impl StorageManager { outputs_participation, ) .await?; + Ok(()) } @@ -103,13 +102,10 @@ impl StorageManager { ) -> crate::Result> { log::debug!("get_cached_participation"); - match self + Ok(self .storage .get(&format!("{PARTICIPATION_CACHED_OUTPUTS}{account_index}")) .await? - { - Some(cached_outputs) => Ok(serde_json::from_str(&cached_outputs)?), - None => Ok(HashMap::new()), - } + .unwrap_or_default()) } } diff --git a/wallet/tests/backup_restore.rs b/wallet/tests/backup_restore.rs index e8ba525163..c4e585f1b0 100644 --- a/wallet/tests/backup_restore.rs +++ b/wallet/tests/backup_restore.rs @@ -289,7 +289,7 @@ async fn backup_and_restore_different_coin_type() -> Result<()> { #[cfg(all(feature = "stronghold", feature = "storage"))] // Backup and restore with Stronghold async fn backup_and_restore_same_coin_type() -> Result<()> { - let storage_path = "test-storage/backup_and_restore_different_coin_type"; + let storage_path = "test-storage/backup_and_restore_same_coin_type"; common::setup(storage_path)?; let client_options = ClientOptions::new().with_node(common::NODE_LOCAL)?;