From e9ad4a61559a276c23910c6740c43f924bfea347 Mon Sep 17 00:00:00 2001 From: Kris Nuttycombe Date: Thu, 17 Oct 2024 11:55:47 -0600 Subject: [PATCH] zcash_client_backend: Generalize & extend wallet metadata query API This generalizes the previous wallet metadata query API to be able to represent more complex queries, and also to return note totals in addition to note counts. --- components/zcash_protocol/src/value.rs | 22 +- zcash_client_backend/CHANGELOG.md | 2 + zcash_client_backend/src/data_api.rs | 121 +++++++++- zcash_client_backend/src/data_api/testing.rs | 4 +- .../src/data_api/testing/pool.rs | 28 +-- zcash_client_backend/src/fees.rs | 100 ++++++-- zcash_client_backend/src/fees/common.rs | 11 +- zcash_client_backend/src/fees/zip317.rs | 24 +- zcash_client_sqlite/CHANGELOG.md | 1 + zcash_client_sqlite/src/error.rs | 5 + zcash_client_sqlite/src/lib.rs | 36 +-- zcash_client_sqlite/src/wallet/common.rs | 218 +++++++++++++++--- zcash_client_sqlite/src/wallet/init.rs | 3 + 13 files changed, 467 insertions(+), 108 deletions(-) diff --git a/components/zcash_protocol/src/value.rs b/components/zcash_protocol/src/value.rs index 94f2b7ba64..aa4d1709bc 100644 --- a/components/zcash_protocol/src/value.rs +++ b/components/zcash_protocol/src/value.rs @@ -2,7 +2,7 @@ use std::convert::{Infallible, TryFrom}; use std::error; use std::iter::Sum; use std::num::NonZeroU64; -use std::ops::{Add, Mul, Neg, Sub}; +use std::ops::{Add, Div, Mul, Neg, Sub}; use memuse::DynamicUsage; @@ -321,6 +321,7 @@ impl Zatoshis { /// Divides this `Zatoshis` value by the given divisor and returns the quotient and remainder. pub fn div_with_remainder(&self, divisor: NonZeroU64) -> QuotRem { let divisor = u64::from(divisor); + // `self` is already bounds-checked, so we don't need to re-check it in division QuotRem { quotient: Zatoshis(self.0 / divisor), remainder: Zatoshis(self.0 % divisor), @@ -394,11 +395,19 @@ impl Sub for Option { } } +impl Mul for Zatoshis { + type Output = Option; + + fn mul(self, rhs: u64) -> Option { + Zatoshis::from_u64(self.0.checked_mul(rhs)?).ok() + } +} + impl Mul for Zatoshis { type Output = Option; fn mul(self, rhs: usize) -> Option { - Zatoshis::from_u64(self.0.checked_mul(u64::try_from(rhs).ok()?)?).ok() + self * u64::try_from(rhs).ok()? } } @@ -414,6 +423,15 @@ impl<'a> Sum<&'a Zatoshis> for Option { } } +impl Div for Zatoshis { + type Output = Zatoshis; + + fn div(self, rhs: NonZeroU64) -> Zatoshis { + // `self` is already bounds-checked, so we don't need to re-check it + Zatoshis(self.0 / u64::from(rhs)) + } +} + /// A type for balance violations in amount addition and subtraction /// (overflow and underflow of allowed ranges) #[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/zcash_client_backend/CHANGELOG.md b/zcash_client_backend/CHANGELOG.md index aebf460767..377fd0e3ba 100644 --- a/zcash_client_backend/CHANGELOG.md +++ b/zcash_client_backend/CHANGELOG.md @@ -13,6 +13,8 @@ and this library adheres to Rust's notion of - `WalletSummary::progress` - `WalletMeta` - `impl Default for wallet::input_selection::GreedyInputSelector` + - `BoundedU8` + - `NoteSelector` - `zcash_client_backend::fees` - `SplitPolicy` - `StandardFeeRule` has been moved here from `zcash_primitives::fees`. Relative diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index ab827f70a6..d30ec29413 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -804,20 +804,28 @@ impl SpendableNotes { /// the wallet. pub struct WalletMeta { sapling_note_count: usize, + sapling_total_value: NonNegativeAmount, #[cfg(feature = "orchard")] orchard_note_count: usize, + #[cfg(feature = "orchard")] + orchard_total_value: NonNegativeAmount, } impl WalletMeta { /// Constructs a new [`WalletMeta`] value from its constituent parts. pub fn new( sapling_note_count: usize, + sapling_total_value: NonNegativeAmount, #[cfg(feature = "orchard")] orchard_note_count: usize, + #[cfg(feature = "orchard")] orchard_total_value: NonNegativeAmount, ) -> Self { Self { sapling_note_count, + sapling_total_value, #[cfg(feature = "orchard")] orchard_note_count, + #[cfg(feature = "orchard")] + orchard_total_value, } } @@ -838,6 +846,11 @@ impl WalletMeta { self.sapling_note_count } + /// Returns the total value of Sapling notes represented by [`Self::sapling_note_count`]. + pub fn sapling_total_value(&self) -> NonNegativeAmount { + self.sapling_total_value + } + /// Returns the number of unspent Orchard notes belonging to the account for which this was /// generated. #[cfg(feature = "orchard")] @@ -845,11 +858,112 @@ impl WalletMeta { self.orchard_note_count } + /// Returns the total value of Orchard notes represented by [`Self::orchard_note_count`]. + #[cfg(feature = "orchard")] + pub fn orchard_total_value(&self) -> NonNegativeAmount { + self.orchard_total_value + } + /// Returns the total number of unspent shielded notes belonging to the account for which this /// was generated. pub fn total_note_count(&self) -> usize { self.sapling_note_count + self.note_count(ShieldedProtocol::Orchard) } + + /// Returns the total value of shielded notes represented by [`Self::total_note_count`] + pub fn total_value(&self) -> NonNegativeAmount { + #[cfg(feature = "orchard")] + let orchard_value = self.orchard_total_value; + #[cfg(not(feature = "orchard"))] + let orchard_value = NonNegativeAmount::ZERO; + + (self.sapling_total_value + orchard_value).expect("Does not overflow Zcash maximum value.") + } +} + +/// A `u8` value in the range 0..=MAX +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct BoundedU8(u8); + +impl BoundedU8 { + /// Creates a constant `BoundedU8` from a [`u8`] value. + /// + /// Panics: if the value is outside the range `0..=MAX`. + pub const fn new_const(value: u8) -> Self { + assert!(value <= MAX); + Self(value) + } + + /// Creates a `BoundedU8` from a [`u8`] value. + /// + /// Returns `None` if the provided value is outside the range `0..=MAX`. + pub fn new(value: u8) -> Option { + if value <= MAX { + Some(Self(value)) + } else { + None + } + } + + /// Returns the wrapped [`u8`] value. + pub fn value(&self) -> u8 { + self.0 + } +} + +impl From> for u8 { + fn from(value: BoundedU8) -> Self { + value.0 + } +} + +impl From> for usize { + fn from(value: BoundedU8) -> Self { + usize::from(value.0) + } +} + +/// A small query language for filtering notes belonging to an account. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum NoteSelector { + /// Selects notes having value greater than or equal to the provided value. + ExceedsMinValue(NonNegativeAmount), + /// Selects notes having value greater than or equal to the n'th percentile of previously sent + /// notes in the wallet. The wrapped value must be in the range `1..=99`. `n` may be rounded + /// to a multiple of 10 as part of this computation. + ExceedsPriorSendPercentile(BoundedU8<99>), + /// Selects notes having value greater than or equal to the specified percentage of the wallet + /// balance. The wrapped value must be in the range `1..=99` + ExceedsBalancePercentage(BoundedU8<99>), + /// A note will be selected if it satisfies both of the specified conditions. + /// + /// If it is not possible to evaluate one of the conditions (for example, + /// [`NoteSelector::ExceedsPriorSendPercentile`] cannot be evaluated if no sends have been + /// performed) then that condition will be ignored. + And(Box, Box), + /// A note will be selected if it satisfies the first condition; if it is not possible to + /// evaluate that condition (for example, [`NoteSelector::ExceedsPriorSendPercentile`] cannot + /// be evaluated if no sends have been performed) then the second condition will be used for + /// evaluation. + Attempt { + condition: Box, + fallback: Box, + }, +} + +impl NoteSelector { + /// Constructs a [`NoteSelector::And`] query node. + pub fn and(l: NoteSelector, r: NoteSelector) -> Self { + Self::And(Box::new(l), Box::new(r)) + } + + /// Constructs a [`NoteSelector::Attempt`] query node. + pub fn attempt(condition: NoteSelector, fallback: NoteSelector) -> Self { + Self::Attempt { + condition: Box::new(condition), + fallback: Box::new(fallback), + } + } } /// A trait representing the capability to query a data store for unspent transaction outputs @@ -900,12 +1014,15 @@ pub trait InputSource { /// /// The returned metadata value must exclude: /// - spent notes; - /// - unspent notes having value less than the specified minimum value; + /// - unspent notes excluded by the provided selector; /// - unspent notes identified in the given `exclude` list. + /// + /// Implementations of this method may limit the complexity of supported queries. Such + /// limitations should be clearly documented for the implementing type. fn get_wallet_metadata( &self, account: Self::AccountId, - min_value: NonNegativeAmount, + selector: &NoteSelector, exclude: &[Self::NoteRef], ) -> Result; diff --git a/zcash_client_backend/src/data_api/testing.rs b/zcash_client_backend/src/data_api/testing.rs index 0c117c7c8e..8fc2416d86 100644 --- a/zcash_client_backend/src/data_api/testing.rs +++ b/zcash_client_backend/src/data_api/testing.rs @@ -59,7 +59,6 @@ use crate::{ ShieldedProtocol, }; -use super::error::Error; use super::{ chain::{scan_cached_blocks, BlockSource, ChainState, CommitmentTreeRoot, ScanSummary}, scanning::ScanRange, @@ -74,6 +73,7 @@ use super::{ WalletCommitmentTrees, WalletMeta, WalletRead, WalletSummary, WalletTest, WalletWrite, SAPLING_SHARD_HEIGHT, }; +use super::{error::Error, NoteSelector}; #[cfg(feature = "transparent-inputs")] use { @@ -2354,7 +2354,7 @@ impl InputSource for MockWalletDb { fn get_wallet_metadata( &self, _account: Self::AccountId, - _min_value: NonNegativeAmount, + _selector: &NoteSelector, _exclude: &[Self::NoteRef], ) -> Result { Err(()) diff --git a/zcash_client_backend/src/data_api/testing/pool.rs b/zcash_client_backend/src/data_api/testing/pool.rs index 846a681136..c332fe7c8d 100644 --- a/zcash_client_backend/src/data_api/testing/pool.rs +++ b/zcash_client_backend/src/data_api/testing/pool.rs @@ -362,7 +362,7 @@ pub fn send_with_multiple_change_outputs( Some(change_memo.clone().into()), T::SHIELDED_PROTOCOL, DustOutputPolicy::default(), - SplitPolicy::new( + SplitPolicy::with_min_output_value( NonZeroUsize::new(2).unwrap(), NonNegativeAmount::const_from_u64(100_0000), ), @@ -465,7 +465,7 @@ pub fn send_with_multiple_change_outputs( Some(change_memo.into()), T::SHIELDED_PROTOCOL, DustOutputPolicy::default(), - SplitPolicy::new( + SplitPolicy::with_min_output_value( NonZeroUsize::new(8).unwrap(), NonNegativeAmount::const_from_u64(10_0000), ), @@ -530,7 +530,7 @@ pub fn send_multi_step_proposed_transfer( // Add funds to the wallet. add_funds(st, value); - let expected_step0_fee = (zip317::MARGINAL_FEE * 3).unwrap(); + let expected_step0_fee = (zip317::MARGINAL_FEE * 3u64).unwrap(); let expected_step1_fee = zip317::MINIMUM_FEE; let expected_ephemeral = (transfer_amount + expected_step1_fee).unwrap(); let expected_step0_change = @@ -1123,7 +1123,7 @@ pub fn spend_fails_on_unverified_notes( st.scan_cached_blocks(h2 + 1, 8); // Total balance is value * number of blocks scanned (10). - assert_eq!(st.get_total_balance(account_id), (value * 10).unwrap()); + assert_eq!(st.get_total_balance(account_id), (value * 10u64).unwrap()); // Spend still fails assert_matches!( @@ -1150,15 +1150,15 @@ pub fn spend_fails_on_unverified_notes( st.scan_cached_blocks(h11, 1); // Total balance is value * number of blocks scanned (11). - assert_eq!(st.get_total_balance(account_id), (value * 11).unwrap()); + assert_eq!(st.get_total_balance(account_id), (value * 11u64).unwrap()); // Spendable balance at 10 confirmations is value * 2. assert_eq!( st.get_spendable_balance(account_id, 10), - (value * 2).unwrap() + (value * 2u64).unwrap() ); assert_eq!( st.get_pending_shielded_balance(account_id, 10), - (value * 9).unwrap() + (value * 9u64).unwrap() ); // Should now be able to generate a proposal @@ -1192,7 +1192,7 @@ pub fn spend_fails_on_unverified_notes( // TODO: send to an account so that we can check its balance. assert_eq!( st.get_total_balance(account_id), - ((value * 11).unwrap() + ((value * 11u64).unwrap() - (amount_sent + NonNegativeAmount::from_u64(10000).unwrap()).unwrap()) .unwrap() ); @@ -2124,7 +2124,7 @@ pub fn fully_funded_fully_private( st.generate_next_block(&p1_fvk, AddressType::DefaultExternal, note_value); st.scan_cached_blocks(account.birthday().height(), 2); - let initial_balance = (note_value * 2).unwrap(); + let initial_balance = (note_value * 2u64).unwrap(); assert_eq!(st.get_total_balance(account.id()), initial_balance); assert_eq!(st.get_spendable_balance(account.id(), 1), initial_balance); @@ -2307,7 +2307,7 @@ pub fn multi_pool_checkpoint( let next_to_scan = scanned.scanned_range().end; - let initial_balance = (note_value * 3).unwrap(); + let initial_balance = (note_value * 3u64).unwrap(); assert_eq!(st.get_total_balance(acct_id), initial_balance); assert_eq!(st.get_spendable_balance(acct_id, 1), initial_balance); @@ -2352,7 +2352,7 @@ pub fn multi_pool_checkpoint( let expected_change = (note_value - transfer_amount - expected_fee).unwrap(); assert_eq!( st.get_total_balance(acct_id), - ((note_value * 2).unwrap() + expected_change).unwrap() + ((note_value * 2u64).unwrap() + expected_change).unwrap() ); assert_eq!(st.get_pending_change(acct_id, 1), expected_change); @@ -2396,8 +2396,8 @@ pub fn multi_pool_checkpoint( ); let expected_final = (initial_balance + note_value - - (transfer_amount * 3).unwrap() - - (expected_fee * 3).unwrap()) + - (transfer_amount * 3u64).unwrap() + - (expected_fee * 3u64).unwrap()) .unwrap(); assert_eq!(st.get_total_balance(acct_id), expected_final); diff --git a/zcash_client_backend/src/fees.rs b/zcash_client_backend/src/fees.rs index a10136db51..343d8b72d2 100644 --- a/zcash_client_backend/src/fees.rs +++ b/zcash_client_backend/src/fees.rs @@ -10,13 +10,14 @@ use zcash_primitives::{ components::{amount::NonNegativeAmount, OutPoint}, fees::{ transparent::{self, InputSize}, - zip317 as prim_zip317, FeeRule, + zip317::{self as prim_zip317}, + FeeRule, }, }, }; use zcash_protocol::{PoolType, ShieldedProtocol}; -use crate::data_api::InputSource; +use crate::data_api::{BoundedU8, InputSource}; pub mod common; #[cfg(feature = "non-standard-fees")] @@ -355,18 +356,27 @@ impl Default for DustOutputPolicy { #[derive(Clone, Copy, Debug)] pub struct SplitPolicy { target_output_count: NonZeroUsize, - min_split_output_size: NonNegativeAmount, + min_split_output_value: Option, + notes_must_exceed_prior_send_percentile: Option>, + notes_must_exceed_balance_percentage: Option>, } impl SplitPolicy { - /// Constructs a new [`SplitPolicy`] from its constituent parts. - pub fn new( + /// In the case that no other conditions provided by the user are available to fall back on, + /// a default value of [`MARGINAL_FEE`] * 100 will be used as the "minimum usable note value" + /// when retrieving wallet metadata. + pub(crate) const MIN_NOTE_VALUE: NonNegativeAmount = NonNegativeAmount::const_from_u64(500000); + + /// Constructs a new [`SplitPolicy`] that splits using a fixed minimum note value. + pub fn with_min_output_value( target_output_count: NonZeroUsize, - min_split_output_size: NonNegativeAmount, + min_split_output_value: NonNegativeAmount, ) -> Self { Self { target_output_count, - min_split_output_size, + min_split_output_value: Some(min_split_output_value), + notes_must_exceed_prior_send_percentile: None, + notes_must_exceed_balance_percentage: None, } } @@ -374,7 +384,9 @@ impl SplitPolicy { pub fn single_output() -> Self { Self { target_output_count: NonZeroUsize::MIN, - min_split_output_size: NonNegativeAmount::ZERO, + min_split_output_value: None, + notes_must_exceed_prior_send_percentile: None, + notes_must_exceed_balance_percentage: None, } } @@ -382,36 +394,74 @@ impl SplitPolicy { /// /// If splitting change would result in notes of value less than the minimum split output size, /// a smaller number of splits should be chosen. - pub fn min_split_output_size(&self) -> NonNegativeAmount { - self.min_split_output_size + pub fn min_split_output_value(&self) -> Option { + self.min_split_output_value + } + + /// Returns the bound on output size that is used to evaluate against prior send behavior. + /// + /// If splitting change would result in notes of value less than the `n`'th percentile of prior + /// send values, a smaller number of splits should be chosen. + pub fn notes_must_exceed_prior_send_percentile(&self) -> Option> { + self.notes_must_exceed_prior_send_percentile + } + + /// Returns the bound on output size that is used to evaluate against wallet balance. + /// + /// If splitting change would result in notes of value less than `n` percent of the wallet + /// balance, a smaller number of splits should be chosen. + pub fn notes_must_exceed_balance_percentage(&self) -> Option> { + self.notes_must_exceed_balance_percentage } /// Returns the number of output notes to produce from the given total change value, given the - /// number of existing unspent notes in the account and this policy. + /// total value and number of existing unspent notes in the account and this policy. pub fn split_count( &self, existing_notes: usize, + existing_notes_total: NonNegativeAmount, total_change: NonNegativeAmount, ) -> NonZeroUsize { + fn to_nonzero_u64(value: usize) -> NonZeroU64 { + NonZeroU64::new(u64::try_from(value).expect("usize fits into u64")) + .expect("NonZeroU64 input derived from NonZeroUsize") + } + let mut split_count = NonZeroUsize::new(usize::from(self.target_output_count).saturating_sub(existing_notes)) .unwrap_or(NonZeroUsize::MIN); - loop { - let per_output_change = total_change.div_with_remainder( - NonZeroU64::new( - u64::try_from(usize::from(split_count)).expect("usize fits into u64"), - ) - .unwrap(), - ); - if *per_output_change.quotient() >= self.min_split_output_size { - return split_count; - } else if let Some(new_count) = NonZeroUsize::new(usize::from(split_count) - 1) { - split_count = new_count; - } else { - // We always create at least one change output. - return NonZeroUsize::MIN; + let min_split_output_value = self.min_split_output_value.or_else(|| { + // If no minimum split output size is set, we choose the minimum split size to be a + // quarter of the average value of notes in the wallet after the transaction. + (existing_notes_total + total_change).map(|total| { + *total + .div_with_remainder(to_nonzero_u64( + usize::from(self.target_output_count).saturating_mul(4), + )) + .quotient() + }) + }); + + if let Some(min_split_output_value) = min_split_output_value { + loop { + let per_output_change = + total_change.div_with_remainder(to_nonzero_u64(usize::from(split_count))); + if *per_output_change.quotient() >= min_split_output_value { + return split_count; + } else if let Some(new_count) = NonZeroUsize::new(usize::from(split_count) - 1) { + split_count = new_count; + } else { + // We always create at least one change output. + return NonZeroUsize::MIN; + } } + } else { + // This is purely defensive; this case would only arise in the case that the addition + // of the existing notes with the total change overflows the maximum monetary amount. + // Since it's always safe to fall back to a single change value, this is better than a + // panic. + NonZeroUsize::MIN } } } diff --git a/zcash_client_backend/src/fees/common.rs b/zcash_client_backend/src/fees/common.rs index 2a1a4a4e2d..37d70cea94 100644 --- a/zcash_client_backend/src/fees/common.rs +++ b/zcash_client_backend/src/fees/common.rs @@ -429,8 +429,11 @@ where // available in the wallet, irrespective of pool. If we don't have any wallet metadata // available, we fall back to generating a single change output. let split_count = wallet_meta.map_or(NonZeroUsize::MIN, |wm| { - cfg.split_policy - .split_count(wm.total_note_count(), proposed_change) + cfg.split_policy.split_count( + wm.total_note_count(), + wm.total_value(), + proposed_change, + ) }); let per_output_change = proposed_change.div_with_remainder( NonZeroU64::new( @@ -531,8 +534,8 @@ where // We can add a change output if necessary. assert!(fee_with_change <= fee_with_dust); - let reasonable_fee = - (fee_with_change + (MINIMUM_FEE * 10).unwrap()).ok_or_else(overflow)?; + let reasonable_fee = (fee_with_change + (MINIMUM_FEE * 10u64).unwrap()) + .ok_or_else(overflow)?; if fee_with_dust > reasonable_fee { // Defend against losing money by using AddDustToFee with a too-high diff --git a/zcash_client_backend/src/fees/zip317.rs b/zcash_client_backend/src/fees/zip317.rs index 7d89897be6..a16b9be4ad 100644 --- a/zcash_client_backend/src/fees/zip317.rs +++ b/zcash_client_backend/src/fees/zip317.rs @@ -14,7 +14,7 @@ use zcash_primitives::{ use zcash_protocol::value::{BalanceError, Zatoshis}; use crate::{ - data_api::{InputSource, WalletMeta}, + data_api::{InputSource, NoteSelector, WalletMeta}, fees::StandardFeeRule, ShieldedProtocol, }; @@ -216,7 +216,13 @@ where account: ::AccountId, exclude: &[::NoteRef], ) -> Result::Error> { - meta_source.get_wallet_metadata(account, self.split_policy.min_split_output_size(), exclude) + let note_selector = NoteSelector::ExceedsMinValue( + self.split_policy + .min_split_output_value() + .unwrap_or(SplitPolicy::MIN_NOTE_VALUE), + ); + + meta_source.get_wallet_metadata(account, ¬e_selector, exclude) } fn compute_balance( @@ -334,7 +340,7 @@ mod tests { None, ShieldedProtocol::Sapling, DustOutputPolicy::default(), - SplitPolicy::new( + SplitPolicy::with_min_output_value( NonZeroUsize::new(5).unwrap(), NonNegativeAmount::const_from_u64(100_0000), ), @@ -342,7 +348,7 @@ mod tests { { // spend a single Sapling note and produce 5 outputs - let balance = |existing_notes| { + let balance = |existing_notes, total| { change_strategy.compute_balance( &Network::TestNetwork, Network::TestNetwork @@ -365,14 +371,17 @@ mod tests { None, &WalletMeta::new( existing_notes, + total, #[cfg(feature = "orchard")] 0, + #[cfg(feature = "orchard")] + NonNegativeAmount::ZERO, ), ) }; assert_matches!( - balance(0), + balance(0, NonNegativeAmount::ZERO), Ok(balance) if balance.proposed_change() == [ ChangeValue::sapling(NonNegativeAmount::const_from_u64(129_4000), None), @@ -385,7 +394,7 @@ mod tests { ); assert_matches!( - balance(2), + balance(2, NonNegativeAmount::const_from_u64(100_0000)), Ok(balance) if balance.proposed_change() == [ ChangeValue::sapling(NonNegativeAmount::const_from_u64(216_0000), None), @@ -421,8 +430,11 @@ mod tests { None, &WalletMeta::new( 0, + NonNegativeAmount::ZERO, #[cfg(feature = "orchard")] 0, + #[cfg(feature = "orchard")] + NonNegativeAmount::ZERO, ), ); diff --git a/zcash_client_sqlite/CHANGELOG.md b/zcash_client_sqlite/CHANGELOG.md index 035313f541..0851eb471e 100644 --- a/zcash_client_sqlite/CHANGELOG.md +++ b/zcash_client_sqlite/CHANGELOG.md @@ -15,6 +15,7 @@ and this library adheres to Rust's notion of - MSRV is now 1.77.0. - Migrated from `schemer` to our fork `schemerz`. - Migrated to `rusqlite 0.32`. +- `error::SqliteClientError` has additional variant `NoteSelectorInvalid` ## [0.12.2] - 2024-10-21 diff --git a/zcash_client_sqlite/src/error.rs b/zcash_client_sqlite/src/error.rs index 741003b6ea..eebbe2dd5d 100644 --- a/zcash_client_sqlite/src/error.rs +++ b/zcash_client_sqlite/src/error.rs @@ -5,6 +5,7 @@ use std::fmt; use shardtree::error::ShardTreeError; use zcash_address::ParseError; +use zcash_client_backend::data_api::NoteSelector; use zcash_client_backend::PoolType; use zcash_keys::keys::AddressGenerationError; use zcash_primitives::zip32; @@ -121,6 +122,9 @@ pub enum SqliteClientError { /// An error occurred in computing wallet balance BalanceError(BalanceError), + /// A note selection query contained an invalid constant or was otherwise not supported. + NoteSelectorInvalid(NoteSelector), + /// The proposal cannot be constructed until transactions with previously reserved /// ephemeral address outputs have been mined. The parameters are the account id and /// the index that could not safely be reserved. @@ -187,6 +191,7 @@ impl fmt::Display for SqliteClientError { SqliteClientError::ChainHeightUnknown => write!(f, "Chain height unknown; please call `update_chain_tip`"), SqliteClientError::UnsupportedPoolType(t) => write!(f, "Pool type is not currently supported: {}", t), SqliteClientError::BalanceError(e) => write!(f, "Balance error: {}", e), + SqliteClientError::NoteSelectorInvalid(s) => write!(f, "Could not evaluate selection query: {:?}", s), #[cfg(feature = "transparent-inputs")] SqliteClientError::ReachedGapLimit(account_id, bad_index) => write!(f, "The proposal cannot be constructed until transactions with previously reserved ephemeral address outputs have been mined. \ diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index f110229273..30f847a342 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -51,9 +51,10 @@ use zcash_client_backend::{ chain::{BlockSource, ChainState, CommitmentTreeRoot}, scanning::{ScanPriority, ScanRange}, Account, AccountBirthday, AccountPurpose, AccountSource, BlockMetadata, - DecryptedTransaction, InputSource, NullifierQuery, ScannedBlock, SeedRelevance, - SentTransaction, SpendableNotes, TransactionDataRequest, WalletCommitmentTrees, WalletMeta, - WalletRead, WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT, + DecryptedTransaction, InputSource, NoteSelector, NullifierQuery, ScannedBlock, + SeedRelevance, SentTransaction, SpendableNotes, TransactionDataRequest, + WalletCommitmentTrees, WalletMeta, WalletRead, WalletSummary, WalletWrite, + SAPLING_SHARD_HEIGHT, }, keys::{ AddressGenerationError, UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey, @@ -128,7 +129,7 @@ pub mod error; pub mod wallet; use wallet::{ commitment_tree::{self, put_shard_roots}, - common::count_outputs, + common::spendable_notes_meta, SubtreeProgressEstimator, }; @@ -348,38 +349,43 @@ impl, P: consensus::Parameters> InputSource for ) } + /// Returns metadata for the spendable notes in the wallet. At present, + /// only [`NoteSelector::ExceedsMinValue`] is supported. fn get_wallet_metadata( &self, account_id: Self::AccountId, - min_value: NonNegativeAmount, + selector: &NoteSelector, exclude: &[Self::NoteRef], ) -> Result { let chain_tip_height = wallet::chain_tip_height(self.conn.borrow())? .ok_or(SqliteClientError::ChainHeightUnknown)?; - let sapling_note_count = count_outputs( + let sapling_pool_meta = spendable_notes_meta( self.conn.borrow(), - account_id, - min_value, - exclude, ShieldedProtocol::Sapling, chain_tip_height, + account_id, + selector, + exclude, )?; #[cfg(feature = "orchard")] - let orchard_note_count = count_outputs( + let orchard_pool_meta = spendable_notes_meta( self.conn.borrow(), - account_id, - min_value, - exclude, ShieldedProtocol::Orchard, chain_tip_height, + account_id, + selector, + exclude, )?; Ok(WalletMeta::new( - sapling_note_count, + sapling_pool_meta.note_count, + sapling_pool_meta.total_value, + #[cfg(feature = "orchard")] + orchard_pool_meta.note_count, #[cfg(feature = "orchard")] - orchard_note_count, + orchard_pool_meta.total_value, )) } } diff --git a/zcash_client_sqlite/src/wallet/common.rs b/zcash_client_sqlite/src/wallet/common.rs index 378abbe1a1..8a45e788f0 100644 --- a/zcash_client_sqlite/src/wallet/common.rs +++ b/zcash_client_sqlite/src/wallet/common.rs @@ -1,11 +1,14 @@ //! Functions common to Sapling and Orchard support in the wallet. use rusqlite::{named_params, types::Value, Connection, Row}; -use std::rc::Rc; +use std::{num::NonZeroU64, rc::Rc}; -use zcash_client_backend::{wallet::ReceivedNote, ShieldedProtocol}; -use zcash_primitives::transaction::{components::amount::NonNegativeAmount, TxId}; -use zcash_protocol::consensus::{self, BlockHeight}; +use zcash_client_backend::{data_api::NoteSelector, wallet::ReceivedNote, ShieldedProtocol}; +use zcash_primitives::transaction::{components::amount::NonNegativeAmount, fees::zip317, TxId}; +use zcash_protocol::{ + consensus::{self, BlockHeight}, + value::BalanceError, +}; use super::wallet_birthday; use crate::{error::SqliteClientError, AccountId, ReceivedNoteId, SAPLING_TABLES_PREFIX}; @@ -226,14 +229,19 @@ where .collect::>() } -pub(crate) fn count_outputs( +pub(crate) struct PoolMeta { + pub(crate) note_count: usize, + pub(crate) total_value: NonNegativeAmount, +} + +pub(crate) fn spendable_notes_meta( conn: &rusqlite::Connection, - account: AccountId, - min_value: NonNegativeAmount, - exclude: &[ReceivedNoteId], protocol: ShieldedProtocol, chain_tip_height: BlockHeight, -) -> Result { + account: AccountId, + selector: &NoteSelector, + exclude: &[ReceivedNoteId], +) -> Result { let (table_prefix, _, _) = per_protocol_names(protocol); let excluded: Vec = exclude @@ -248,33 +256,167 @@ pub(crate) fn count_outputs( .collect(); let excluded_ptr = Rc::new(excluded); - conn.query_row( - &format!( - "SELECT COUNT(*) - FROM {table_prefix}_received_notes rn - INNER JOIN accounts ON accounts.id = rn.account_id - INNER JOIN transactions ON transactions.id_tx = rn.tx - WHERE value >= :min_value - AND accounts.id = :account_id - AND accounts.ufvk IS NOT NULL - AND recipient_key_scope IS NOT NULL - AND transactions.mined_height IS NOT NULL - AND rn.id NOT IN rarray(:exclude) - AND rn.id NOT IN ( - SELECT {table_prefix}_received_note_id - FROM {table_prefix}_received_note_spends - JOIN transactions stx ON stx.id_tx = transaction_id - WHERE stx.block IS NOT NULL -- the spending tx is mined - OR stx.expiry_height IS NULL -- the spending tx will not expire - OR stx.expiry_height > :chain_tip_height -- the spending tx is unexpired - )" - ), - named_params![ - ":account_id": account.0, - ":min_value": u64::from(min_value), - ":exclude": &excluded_ptr, - ":chain_tip_height": u32::from(chain_tip_height) - ], - |row| row.get(0), - ) + fn zatoshis(value: i64) -> Result { + NonNegativeAmount::from_nonnegative_i64(value).map_err(|_| { + SqliteClientError::CorruptedData(format!("Negative received note value: {}", value)) + }) + } + + let run_selection = |min_value| { + conn.query_row_and_then::<_, SqliteClientError, _, _>( + &format!( + "SELECT COUNT(*), SUM(rn.value) + FROM {table_prefix}_received_notes rn + INNER JOIN transactions ON transactions.id_tx = rn.tx + WHERE rn.account_id = :account_id + AND rn.value >= :min_value + AND transactions.mined_height IS NOT NULL + AND rn.id NOT IN rarray(:exclude) + AND rn.id NOT IN ( + SELECT {table_prefix}_received_note_id + FROM {table_prefix}_received_note_spends rns + JOIN transactions stx ON stx.id_tx = rns.transaction_id + WHERE stx.block IS NOT NULL -- the spending tx is mined + OR stx.expiry_height IS NULL -- the spending tx will not expire + OR stx.expiry_height > :chain_tip_height -- the spending tx is unexpired + )" + ), + named_params![ + ":account_id": account.0, + ":min_value": u64::from(min_value), + ":exclude": &excluded_ptr, + ":chain_tip_height": u32::from(chain_tip_height) + ], + |row| { + Ok(( + row.get::<_, usize>(0)?, + row.get::<_, Option>(1)?.map(zatoshis).transpose()?, + )) + }, + ) + }; + + fn min_note_value( + conn: &rusqlite::Connection, + table_prefix: &str, + account: AccountId, + selector: &NoteSelector, + chain_tip_height: BlockHeight, + ) -> Result, SqliteClientError> { + match selector { + NoteSelector::ExceedsMinValue(v) => Ok(Some(*v)), + NoteSelector::ExceedsPriorSendPercentile(n) => { + let mut bucket_query = conn.prepare( + "WITH bucketed AS ( + SELECT s.value, NTILE(10) OVER (ORDER BY s.value) AS percentile + FROM sent_notes s + JOIN transactions t ON s.tx = t.id_tx + WHERE s.from_account_id = :account_id + -- only count mined transactions + t.mined_height IS NOT NULL + -- exclude change and account-internal sends + AND (s.to_account_id IS NULL OR s.from_account_id != s.to_account_id) + ) + SELECT percentile, MAX(value) as value + FROM bucketed + GROUP BY percentile", + )?; + + let bucket_maxima = bucket_query + .query_and_then::<_, SqliteClientError, _, _>( + named_params![":account_id": account.0], + |row| { + let value = + NonNegativeAmount::from_nonnegative_i64(row.get::<_, i64>(1)?) + .map_err(|_| { + SqliteClientError::CorruptedData(format!( + "Negative received note value: {}", + n.value() + )) + })?; + + Ok(value) + }, + )? + .collect::, _>>()?; + + // Pick a bucket index by scaling the requested percentile to the number of buckets + let i = bucket_maxima.len() * usize::from(*n) / 100; + Ok(bucket_maxima.get(i).copied()) + } + NoteSelector::ExceedsBalancePercentage(p) => { + let balance = conn.query_row_and_then::<_, SqliteClientError, _, _>( + &format!( + "SELECT SUM(rn.value) + FROM {table_prefix}_received_notes rn + INNER JOIN transactions ON transactions.id_tx = rn.tx + WHERE rn.account_id = :account_id + AND transactions.mined_height IS NOT NULL + AND rn.id NOT IN ( + SELECT {table_prefix}_received_note_id + FROM {table_prefix}_received_note_spends rns + JOIN transactions stx ON stx.id_tx = rns.transaction_id + WHERE stx.block IS NOT NULL -- the spending tx is mined + OR stx.expiry_height IS NULL -- the spending tx will not expire + OR stx.expiry_height > :chain_tip_height -- the spending tx is unexpired + )" + ), + named_params![ + ":account_id": account.0, + ":chain_tip_height": u32::from(chain_tip_height) + ], + |row| row.get::<_, Option>(1)?.map(zatoshis).transpose(), + )?; + + Ok(match balance { + None => None, + Some(b) => { + let numerator = (b * u64::from(p.value())).ok_or(BalanceError::Overflow)?; + Some(numerator / NonZeroU64::new(100).expect("Constant is nonzero.")) + } + }) + } + NoteSelector::And(a, b) => { + // All the existing note selectors set lower bounds on note value, so the "and" + // operation is just taking the maximum of the two lower bounds. + let a_min_value = + min_note_value(conn, table_prefix, account, a.as_ref(), chain_tip_height)?; + let b_min_value = + min_note_value(conn, table_prefix, account, b.as_ref(), chain_tip_height)?; + Ok(a_min_value + .zip(b_min_value) + .map(|(av, bv)| std::cmp::max(av, bv)) + .or(a_min_value) + .or(b_min_value)) + } + NoteSelector::Attempt { + condition, + fallback, + } => { + let cond = min_note_value( + conn, + table_prefix, + account, + condition.as_ref(), + chain_tip_height, + )?; + if cond.is_none() { + min_note_value(conn, table_prefix, account, fallback, chain_tip_height) + } else { + Ok(cond) + } + } + } + } + + // In the absence of any limitations on note values, we return metadata for all notes with + // value >= 10x the marginal fee + let min_value = min_note_value(conn, table_prefix, account, selector, chain_tip_height)? + .unwrap_or((zip317::MARGINAL_FEE * 10u64).unwrap()); + let (note_count, total_value) = run_selection(min_value)?; + + Ok(PoolMeta { + note_count, + total_value: total_value.unwrap_or(NonNegativeAmount::ZERO), + }) } diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index baed0a1297..5ef18818ab 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -230,6 +230,9 @@ fn sqlite_client_error_to_wallet_migration_error(e: SqliteClientError) -> Wallet SqliteClientError::EphemeralAddressReuse(_, _) => { unreachable!("we don't do ephemeral address tracking") } + SqliteClientError::NoteSelectorInvalid(_) => { + unreachable!("we don't do note selection in migrations") + } } }