Skip to content

Commit

Permalink
Merge pull request #1636 from zcash/wallet/imported_spending_key_meta…
Browse files Browse the repository at this point in the history
…data

zcash_client_backend: Add optional derivation metadata when importing UFVKs with spending purpose.
  • Loading branch information
nuttycom authored Dec 6, 2024
2 parents c33ad67 + cc2dfbf commit f6040a1
Show file tree
Hide file tree
Showing 9 changed files with 145 additions and 84 deletions.
3 changes: 3 additions & 0 deletions zcash_client_backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this library adheres to Rust's notion of

## [Unreleased]

### Added
- `zcash_client_backend::data_api::AccountSource::key_derivation`

### Changed
- `zcash_client_backend::data_api::WalletRead`:
- The `create_account`, `import_account_hd`, and `import_account_ufvk`
Expand Down
62 changes: 56 additions & 6 deletions zcash_client_backend/src/data_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,40 @@ impl AccountBalance {
}
}

/// Source metadata for a ZIP 32-derived key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zip32Derivation {
seed_fingerprint: SeedFingerprint,
account_index: zip32::AccountId,
}

impl Zip32Derivation {
/// Constructs new derivation metadata from its constituent parts.
pub fn new(seed_fingerprint: SeedFingerprint, account_index: zip32::AccountId) -> Self {
Self {
seed_fingerprint,
account_index,
}
}

/// Returns the seed fingerprint.
pub fn seed_fingerprint(&self) -> &SeedFingerprint {
&self.seed_fingerprint
}

/// Returns the account-level index in the ZIP 32 derivation path.
pub fn account_index(&self) -> zip32::AccountId {
self.account_index
}
}

/// An enumeration used to control what information is tracked by the wallet for
/// notes received by a given account.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum AccountPurpose {
/// For spending accounts, the wallet will track information needed to spend
/// received notes.
Spending,
Spending { derivation: Option<Zip32Derivation> },
/// For view-only accounts, the wallet will not track spend information.
ViewOnly,
}
Expand All @@ -347,8 +374,7 @@ pub enum AccountPurpose {
pub enum AccountSource {
/// An account derived from a known seed.
Derived {
seed_fingerprint: SeedFingerprint,
account_index: zip32::AccountId,
derivation: Zip32Derivation,
key_source: Option<String>,
},

Expand All @@ -359,6 +385,28 @@ pub enum AccountSource {
},
}

impl AccountSource {
/// Returns the key derivation metadata for the account source, if any is available.
pub fn key_derivation(&self) -> Option<&Zip32Derivation> {
match self {
AccountSource::Derived { derivation, .. } => Some(derivation),
AccountSource::Imported {
purpose: AccountPurpose::Spending { derivation },
..
} => derivation.as_ref(),
_ => None,
}
}

/// Returns the application-level key source identifier.
pub fn key_source(&self) -> Option<&str> {
match self {
AccountSource::Derived { key_source, .. } => key_source.as_ref().map(|s| s.as_str()),
AccountSource::Imported { key_source, .. } => key_source.as_ref().map(|s| s.as_str()),
}
}
}

/// A set of capabilities that a client account must provide.
pub trait Account {
type AccountId: Copy;
Expand All @@ -376,8 +424,10 @@ pub trait Account {
/// Returns whether the account is a spending account or a view-only account.
fn purpose(&self) -> AccountPurpose {
match self.source() {
AccountSource::Derived { .. } => AccountPurpose::Spending,
AccountSource::Imported { purpose, .. } => *purpose,
AccountSource::Derived { derivation, .. } => AccountPurpose::Spending {
derivation: Some(derivation.clone()),
},
AccountSource::Imported { purpose, .. } => purpose.clone(),
}
}

Expand Down
52 changes: 23 additions & 29 deletions zcash_client_sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use zcash_client_backend::{
Account, AccountBirthday, AccountMeta, AccountPurpose, AccountSource, BlockMetadata,
DecryptedTransaction, InputSource, NoteFilter, NullifierQuery, ScannedBlock, SeedRelevance,
SentTransaction, SpendableNotes, TransactionDataRequest, WalletCommitmentTrees, WalletRead,
WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
WalletSummary, WalletWrite, Zip32Derivation, SAPLING_SHARD_HEIGHT,
},
keys::{
AddressGenerationError, UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey,
Expand Down Expand Up @@ -442,17 +442,12 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
seed: &SecretVec<u8>,
) -> Result<bool, Self::Error> {
if let Some(account) = self.get_account(account_id)? {
if let AccountSource::Derived {
seed_fingerprint,
account_index,
..
} = account.source()
{
if let AccountSource::Derived { derivation, .. } = account.source() {
wallet::seed_matches_derived_account(
&self.params,
seed,
seed_fingerprint,
*account_index,
derivation.seed_fingerprint(),
derivation.account_index(),
&account.uivk(),
)
} else {
Expand Down Expand Up @@ -480,19 +475,14 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
// way we could determine that is by brute-forcing the ZIP 32 account
// index space, which we're not going to do. The method name indicates to
// the caller that we only check derived accounts.
if let AccountSource::Derived {
seed_fingerprint,
account_index,
..
} = account.source()
{
if let AccountSource::Derived { derivation, .. } = account.source() {
has_derived = true;

if wallet::seed_matches_derived_account(
&self.params,
seed,
seed_fingerprint,
*account_index,
derivation.seed_fingerprint(),
derivation.account_index(),
&account.uivk(),
)? {
// The seed is relevant to this account.
Expand Down Expand Up @@ -873,8 +863,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
&wdb.params,
account_name,
&AccountSource::Derived {
seed_fingerprint,
account_index: zip32_account_index,
derivation: Zip32Derivation::new(seed_fingerprint, zip32_account_index),
key_source: key_source.map(|s| s.to_owned()),
},
wallet::ViewingKey::Full(Box::new(ufvk)),
Expand Down Expand Up @@ -911,8 +900,7 @@ impl<P: consensus::Parameters> WalletWrite for WalletDb<rusqlite::Connection, P>
&wdb.params,
account_name,
&AccountSource::Derived {
seed_fingerprint,
account_index,
derivation: Zip32Derivation::new(seed_fingerprint, account_index),
key_source: key_source.map(|s| s.to_owned()),
},
wallet::ViewingKey::Full(Box::new(ufvk)),
Expand Down Expand Up @@ -2023,7 +2011,7 @@ mod tests {
.build();
assert_matches!(
st.test_account().unwrap().account().source(),
AccountSource::Derived { account_index, .. } if *account_index == zip32::AccountId::ZERO);
AccountSource::Derived { derivation, .. } if derivation.account_index() == zip32::AccountId::ZERO);
}

#[test]
Expand All @@ -2046,7 +2034,7 @@ mod tests {
.unwrap();
assert_matches!(
first.0.source(),
AccountSource::Derived { account_index, .. } if *account_index == zip32_index_1);
AccountSource::Derived { derivation, .. } if derivation.account_index() == zip32_index_1);

let zip32_index_2 = zip32_index_1.next().unwrap();
let second = st
Expand All @@ -2055,7 +2043,7 @@ mod tests {
.unwrap();
assert_matches!(
second.0.source(),
AccountSource::Derived { account_index, .. } if *account_index == zip32_index_2);
AccountSource::Derived { derivation, .. } if derivation.account_index() == zip32_index_2);
}

fn check_collisions<C, DbT: WalletTest + WalletWrite, P: consensus::Parameters>(
Expand All @@ -2068,7 +2056,7 @@ mod tests {
{
assert_matches!(
st.wallet_mut()
.import_account_ufvk("", ufvk, birthday, AccountPurpose::Spending, None),
.import_account_ufvk("", ufvk, birthday, AccountPurpose::Spending { derivation: None }, None),
Err(e) if is_account_collision(&e)
);

Expand All @@ -2089,7 +2077,7 @@ mod tests {
"",
&subset_ufvk,
birthday,
AccountPurpose::Spending,
AccountPurpose::Spending { derivation: None },
None,
),
Err(e) if is_account_collision(&e)
Expand All @@ -2113,7 +2101,7 @@ mod tests {
"",
&subset_ufvk,
birthday,
AccountPurpose::Spending,
AccountPurpose::Spending { derivation: None },
None,
),
Err(e) if is_account_collision(&e)
Expand Down Expand Up @@ -2172,7 +2160,13 @@ mod tests {

let account = st
.wallet_mut()
.import_account_ufvk("", &ufvk, &birthday, AccountPurpose::Spending, None)
.import_account_ufvk(
"",
&ufvk,
&birthday,
AccountPurpose::Spending { derivation: None },
None,
)
.unwrap();
assert_eq!(
ufvk.encode(st.network()),
Expand All @@ -2182,7 +2176,7 @@ mod tests {
assert_matches!(
account.source(),
AccountSource::Imported {
purpose: AccountPurpose::Spending,
purpose: AccountPurpose::Spending { .. },
..
}
);
Expand Down
61 changes: 34 additions & 27 deletions zcash_client_sqlite/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree};
use uuid::Uuid;
use zcash_client_backend::data_api::{
AccountPurpose, DecryptedTransaction, Progress, TransactionDataRequest, TransactionStatus,
Zip32Derivation,
};

use zip32::fingerprint::SeedFingerprint;

use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -150,28 +152,36 @@ fn parse_account_source(
spending_key_available: bool,
key_source: Option<String>,
) -> Result<AccountSource, SqliteClientError> {
match (account_kind, hd_seed_fingerprint, hd_account_index) {
(0, Some(seed_fp), Some(account_index)) => Ok(AccountSource::Derived {
seed_fingerprint: SeedFingerprint::from_bytes(seed_fp),
account_index: zip32::AccountId::try_from(account_index).map_err(|_| {
SqliteClientError::CorruptedData(
"ZIP-32 account ID from wallet DB is out of range.".to_string(),
)
})?,
let derivation = hd_seed_fingerprint
.zip(hd_account_index)
.map(|(seed_fp, idx)| {
zip32::AccountId::try_from(idx)
.map_err(|_| {
SqliteClientError::CorruptedData(
"ZIP-32 account ID from wallet DB is out of range.".to_string(),
)
})
.map(|idx| Zip32Derivation::new(SeedFingerprint::from_bytes(seed_fp), idx))
})
.transpose()?;

match (account_kind, derivation) {
(0, Some(derivation)) => Ok(AccountSource::Derived {
derivation,
key_source,
}),
(1, None, None) => Ok(AccountSource::Imported {
(1, derivation) => Ok(AccountSource::Imported {
purpose: if spending_key_available {
AccountPurpose::Spending
AccountPurpose::Spending { derivation }
} else {
AccountPurpose::ViewOnly
},
key_source,
}),
(0, None, None) | (1, Some(_), Some(_)) => Err(SqliteClientError::CorruptedData(
(0, None) => Err(SqliteClientError::CorruptedData(
"Wallet DB account_kind constraint violated".to_string(),
)),
(_, _, _) => Err(SqliteClientError::CorruptedData(
(_, _) => Err(SqliteClientError::CorruptedData(
"Unrecognized account_kind".to_string(),
)),
}
Expand Down Expand Up @@ -378,21 +388,19 @@ pub(crate) fn add_account<P: consensus::Parameters>(

let account_uuid = AccountUuid(Uuid::new_v4());

let (hd_seed_fingerprint, hd_account_index, spending_key_available, key_source) = match kind {
let (derivation, spending_key_available, key_source) = match kind {
AccountSource::Derived {
seed_fingerprint,
account_index,
derivation,
key_source,
} => (
Some(seed_fingerprint),
Some(account_index),
true,
} => (Some(derivation), true, key_source),
AccountSource::Imported {
purpose: AccountPurpose::Spending { derivation },
key_source,
),
} => (derivation.as_ref(), true, key_source),
AccountSource::Imported {
purpose,
purpose: AccountPurpose::ViewOnly,
key_source,
} => (None, None, *purpose == AccountPurpose::Spending, key_source),
} => (None, false, key_source),
};

#[cfg(feature = "orchard")]
Expand Down Expand Up @@ -449,8 +457,8 @@ pub(crate) fn add_account<P: consensus::Parameters>(
":account_name": account_name,
":uuid": account_uuid.0,
":account_kind": account_kind_code(kind),
":hd_seed_fingerprint": hd_seed_fingerprint.as_ref().map(|fp| fp.to_bytes()),
":hd_account_index": hd_account_index.map(|i| u32::from(*i)),
":hd_seed_fingerprint": derivation.map(|d| d.seed_fingerprint().to_bytes()),
":hd_account_index": derivation.map(|d| u32::from(d.account_index())),
":key_source": key_source,
":ufvk": ufvk_encoded,
":uivk": viewing_key.uivk().encode(params),
Expand Down Expand Up @@ -832,8 +840,7 @@ pub(crate) fn get_derived_account<P: consensus::Parameters>(
name: account_name,
uuid: account_uuid,
kind: AccountSource::Derived {
seed_fingerprint: *seed_fp,
account_index,
derivation: Zip32Derivation::new(*seed_fp, account_index),
key_source,
},
viewing_key: ViewingKey::Full(Box::new(ufvk)),
Expand Down Expand Up @@ -3761,7 +3768,7 @@ mod tests {
let expected_account_index = zip32::AccountId::try_from(0).unwrap();
assert_matches!(
account_parameters.kind,
AccountSource::Derived{account_index, ..} if account_index == expected_account_index
AccountSource::Derived{derivation, ..} if derivation.account_index() == expected_account_index
);
}

Expand Down
Loading

0 comments on commit f6040a1

Please sign in to comment.