Skip to content

Commit

Permalink
Consistent storage set and get (#1958)
Browse files Browse the repository at this point in the history
* Consistent storage set and get

* Changelog

* Remove an explicit type parameter

* Remove parse_accounts

* Changelog

* Fix compilation
  • Loading branch information
thibault-martinez authored Mar 15, 2023
1 parent 5cf2702 commit dc48003
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 68 deletions.
1 change: 1 addition & 0 deletions wallet/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Resync outputs if a transaction got confirmed between syncing outputs and pending transactions to prevent not having unspent outputs afterwards;
- Cache participations for spent outputs;
- Make `{Storage, StorageManager}::get` generic over a `T: Deserialize` and return a `T`, avoiding always having to deserialize after;

### Fixed

Expand Down
58 changes: 17 additions & 41 deletions wallet/src/storage/manager.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
// Copyright 2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

use std::{str::FromStr, sync::Arc};
use std::sync::Arc;

use crypto::ciphers::chacha;
use iota_client::secret::{SecretManager, SecretManagerDto};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
Expand Down Expand Up @@ -50,10 +49,7 @@ pub(crate) async fn new_storage_manager(
encryption_key,
};
// Get the db version or set it
let db_schema_version = storage.get(DATABASE_SCHEMA_VERSION_KEY).await?;
if let Some(db_schema_version) = db_schema_version {
let db_schema_version = u8::from_str(&db_schema_version)
.map_err(|_| crate::Error::Storage("invalid db_schema_version".to_string()))?;
if let Some(db_schema_version) = storage.get::<u8>(DATABASE_SCHEMA_VERSION_KEY).await? {
if db_schema_version != DATABASE_SCHEMA_VERSION {
return Err(crate::Error::Storage(format!(
"unsupported database schema version {db_schema_version}"
Expand All @@ -65,10 +61,8 @@ pub(crate) async fn new_storage_manager(
.await?;
};

let account_indexes = match storage.get(ACCOUNTS_INDEXATION_KEY).await? {
Some(account_indexes) => serde_json::from_str(&account_indexes)?,
None => Vec::new(),
};
let account_indexes = storage.get(ACCOUNTS_INDEXATION_KEY).await?.unwrap_or_default();

let storage_manager = StorageManager {
storage,
account_indexes,
Expand All @@ -95,7 +89,7 @@ impl StorageManager {
self.storage.encryption_key.is_some()
}

pub async fn get(&self, key: &str) -> crate::Result<Option<String>> {
pub async fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> crate::Result<Option<T>> {
self.storage.get(key).await
}

Expand Down Expand Up @@ -125,13 +119,16 @@ impl StorageManager {

pub async fn get_account_manager_data(&self) -> crate::Result<Option<AccountManagerBuilder>> {
log::debug!("get_account_manager_data");
if let Some(data) = self.storage.get(ACCOUNT_MANAGER_INDEXATION_KEY).await? {
log::debug!("get_account_manager_data {data:?}");
let mut builder: AccountManagerBuilder = serde_json::from_str(&data)?;
if let Some(mut builder) = self
.storage
.get::<AccountManagerBuilder>(ACCOUNT_MANAGER_INDEXATION_KEY)
.await?
{
log::debug!("get_account_manager_data {builder:?}");

if let Some(secret_manager_dto) = self.storage.get::<SecretManagerDto>(SECRET_MANAGER_KEY).await? {
log::debug!("get_secret_manager {secret_manager_dto:?}");

if let Some(data) = self.storage.get(SECRET_MANAGER_KEY).await? {
log::debug!("get_secret_manager {data}");
let secret_manager_dto: SecretManagerDto = serde_json::from_str(&data)?;
// Only secret_managers that aren't SecretManagerDto::Mnemonic can be restored, because there the Seed
// can't be serialized, so we can't create the SecretManager again
match secret_manager_dto {
Expand All @@ -149,9 +146,9 @@ impl StorageManager {
}

pub async fn get_accounts(&mut self) -> crate::Result<Vec<Account>> {
if let Some(record) = self.storage.get(ACCOUNTS_INDEXATION_KEY).await? {
if let Some(account_indexes) = self.storage.get(ACCOUNTS_INDEXATION_KEY).await? {
if self.account_indexes.is_empty() {
self.account_indexes = serde_json::from_str(&record)?;
self.account_indexes = account_indexes;
}
} else {
return Ok(Vec::new());
Expand All @@ -168,7 +165,7 @@ impl StorageManager {
);
}

parse_accounts(&accounts, &self.storage.encryption_key)
Ok(accounts)
}

pub async fn save_account(&mut self, account: &Account) -> crate::Result<()> {
Expand All @@ -195,24 +192,3 @@ impl StorageManager {
.await
}
}

// Parse accounts from strings and decrypt them first if necessary
fn parse_accounts(accounts: &[String], encryption_key: &Option<[u8; 32]>) -> crate::Result<Vec<Account>> {
let mut parsed_accounts: Vec<Account> = Vec::new();
for account in accounts {
let account_json = if account.starts_with('{') {
Some(account.to_string())
} else if let Some(key) = encryption_key {
Some(String::from_utf8_lossy(&chacha::aead_decrypt(key, account.as_bytes())?).into_owned())
} else {
None
};
if let Some(json) = account_json {
let acc = serde_json::from_str::<Account>(&json)?;
parsed_accounts.push(acc);
} else {
return Err(crate::Error::StorageIsEncrypted);
}
}
Ok(parsed_accounts)
}
14 changes: 7 additions & 7 deletions wallet/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod participation;
use std::collections::HashMap;

use crypto::ciphers::chacha;
use serde::Serialize;
use serde::{Deserialize, Serialize};

use self::adapter::StorageAdapter;

Expand All @@ -30,19 +30,19 @@ impl Storage {
self.inner.id()
}

async fn get(&self, key: &str) -> crate::Result<Option<String>> {
async fn get<T: for<'de> Deserialize<'de>>(&self, key: &str) -> crate::Result<Option<T>> {
match self.inner.get(key).await? {
Some(record) => {
if let Some(key) = &self.encryption_key {
if serde_json::from_str::<Vec<u8>>(&record).is_ok() {
Ok(Some(
String::from_utf8_lossy(&chacha::aead_decrypt(key, record.as_bytes())?).into_owned(),
))
Ok(Some(serde_json::from_str(&String::from_utf8_lossy(
&chacha::aead_decrypt(key, record.as_bytes())?,
))?))
} else {
Ok(Some(record))
Ok(Some(serde_json::from_str(&record)?))
}
} else {
Ok(Some(record))
Ok(Some(serde_json::from_str(&record)?))
}
}
None => Ok(None),
Expand Down
34 changes: 15 additions & 19 deletions wallet/src/storage/participation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ impl StorageManager {
) -> crate::Result<()> {
log::debug!("insert_participation_event {}", event_with_nodes.id);

let mut events: HashMap<ParticipationEventId, ParticipationEventWithNodes> = match self
let mut events = self
.storage
.get(&format!("{PARTICIPATION_EVENTS}{account_index}"))
.get::<HashMap<ParticipationEventId, ParticipationEventWithNodes>>(&format!(
"{PARTICIPATION_EVENTS}{account_index}"
))
.await?
{
Some(events) => serde_json::from_str(&events)?,
None => HashMap::new(),
};
.unwrap_or_default();

events.insert(event_with_nodes.id, event_with_nodes);

Expand All @@ -47,12 +46,14 @@ impl StorageManager {
) -> crate::Result<()> {
log::debug!("remove_participation_event {id}");

let mut events: HashMap<ParticipationEventId, ParticipationEventWithNodes> = match self
let mut events = match self
.storage
.get(&format!("{PARTICIPATION_EVENTS}{account_index}"))
.get::<HashMap<ParticipationEventId, ParticipationEventWithNodes>>(&format!(
"{PARTICIPATION_EVENTS}{account_index}"
))
.await?
{
Some(events) => serde_json::from_str(&events)?,
Some(events) => events,
None => return Ok(()),
};

Expand All @@ -71,14 +72,11 @@ impl StorageManager {
) -> crate::Result<HashMap<ParticipationEventId, ParticipationEventWithNodes>> {
log::debug!("get_participation_events");

match self
Ok(self
.storage
.get(&format!("{PARTICIPATION_EVENTS}{account_index}"))
.await?
{
Some(events) => Ok(serde_json::from_str(&events)?),
None => Ok(HashMap::new()),
}
.unwrap_or_default())
}

pub(crate) async fn set_cached_participation_output_status(
Expand All @@ -94,6 +92,7 @@ impl StorageManager {
outputs_participation,
)
.await?;

Ok(())
}

Expand All @@ -103,13 +102,10 @@ impl StorageManager {
) -> crate::Result<HashMap<OutputId, OutputStatusResponse>> {
log::debug!("get_cached_participation");

match self
Ok(self
.storage
.get(&format!("{PARTICIPATION_CACHED_OUTPUTS}{account_index}"))
.await?
{
Some(cached_outputs) => Ok(serde_json::from_str(&cached_outputs)?),
None => Ok(HashMap::new()),
}
.unwrap_or_default())
}
}
2 changes: 1 addition & 1 deletion wallet/tests/backup_restore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ async fn backup_and_restore_different_coin_type() -> Result<()> {
#[cfg(all(feature = "stronghold", feature = "storage"))]
// Backup and restore with Stronghold
async fn backup_and_restore_same_coin_type() -> Result<()> {
let storage_path = "test-storage/backup_and_restore_different_coin_type";
let storage_path = "test-storage/backup_and_restore_same_coin_type";
common::setup(storage_path)?;

let client_options = ClientOptions::new().with_node(common::NODE_LOCAL)?;
Expand Down

0 comments on commit dc48003

Please sign in to comment.