diff --git a/sdk/src/client/api/block_builder/mod.rs b/sdk/src/client/api/block_builder/mod.rs index 80b5b78280..78a1a5d4e8 100644 --- a/sdk/src/client/api/block_builder/mod.rs +++ b/sdk/src/client/api/block_builder/mod.rs @@ -50,7 +50,7 @@ impl ClientInner { let protocol_params = self.get_protocol_parameters().await?; - Ok(BlockWrapper::build( + BlockWrapper::build( BlockHeader::new( protocol_params.version(), protocol_params.network_id(), @@ -66,6 +66,6 @@ impl ClientInner { .finish_block()?, ) .sign_ed25519(secret_manager, chain) - .await?) + .await } } diff --git a/sdk/src/types/block/error.rs b/sdk/src/types/block/error.rs index 10368d0a33..a960abf307 100644 --- a/sdk/src/types/block/error.rs +++ b/sdk/src/types/block/error.rs @@ -173,6 +173,10 @@ pub enum Error { DuplicateOutputChain(ChainId), InvalidField(&'static str), NullDelegationValidatorId, + InvalidEpochDelta { + created: EpochIndex, + target: EpochIndex, + }, } #[cfg(feature = "std")] @@ -375,6 +379,9 @@ impl fmt::Display for Error { Self::DuplicateOutputChain(chain_id) => write!(f, "duplicate output chain {chain_id}"), Self::InvalidField(field) => write!(f, "invalid field: {field}"), Self::NullDelegationValidatorId => write!(f, "null delegation validator ID"), + Self::InvalidEpochDelta { created, target } => { + write!(f, "invalid epoch delta: created {created}, target {target}") + } } } } diff --git a/sdk/src/types/block/mana/structure.rs b/sdk/src/types/block/mana/structure.rs index 68d926884a..713b2a5152 100644 --- a/sdk/src/types/block/mana/structure.rs +++ b/sdk/src/types/block/mana/structure.rs @@ -4,7 +4,11 @@ use getset::CopyGetters; use packable::{prefix::BoxedSlicePrefix, Packable}; -use crate::types::block::{slot::EpochIndex, Error}; +use crate::types::block::{ + protocol::ProtocolParameters, + slot::{EpochIndex, SlotIndex}, + Error, +}; #[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Packable, CopyGetters)] #[cfg_attr( @@ -41,14 +45,53 @@ impl ManaStructure { } /// Returns the mana decay factor for the given epoch index. - pub fn decay_factor_at(&self, epoch_index: EpochIndex) -> Option { - self.decay_factors.get(*epoch_index as usize).copied() + pub fn decay_factor_at(&self, epoch_index: impl Into) -> Option { + self.decay_factors.get(*epoch_index.into() as usize).copied() } /// Returns the max mana that can exist with the mana bits defined. pub fn max_mana(&self) -> u64 { (1 << self.bits_count) - 1 } + + fn decay(&self, mana: u64, epoch_delta: u32) -> u64 { + if mana == 0 || epoch_delta == 0 || self.decay_factors().is_empty() { + return mana; + } + + // split the value into two u64 variables to prevent overflowing + let mut mana_hi = upper_bits(mana); + let mut mana_lo = lower_bits(mana); + + // we keep applying the lookup table factors as long as n epochs are left + let mut remaining_epochs = epoch_delta; + + while remaining_epochs > 0 { + let epochs_to_decay = remaining_epochs.min(self.decay_factors().len() as u32); + remaining_epochs -= epochs_to_decay; + + // Unwrap: Safe because the index is at most the length + let decay_factor = self.decay_factor_at(epochs_to_decay - 1).unwrap(); + + // apply the decay using fixed-point arithmetics. + (mana_hi, mana_lo) = + multiplication_and_shift(mana_hi, mana_lo, decay_factor, self.decay_factors_exponent()); + } + + // combine both u64 variables to get the actual value + mana_hi << 32 | mana_lo + } + + fn generate_mana(&self, amount: u64, slot_delta: u32) -> u64 { + if self.generation_rate() == 0 || slot_delta == 0 { + return 0; + } + fixed_point_multiply( + amount, + slot_delta * self.generation_rate() as u32, + self.generation_rate_exponent(), + ) + } } impl Default for ManaStructure { @@ -65,3 +108,356 @@ impl Default for ManaStructure { } } } + +impl ProtocolParameters { + /// Applies mana decay to the given mana. + pub fn mana_with_decay( + &self, + mana: u64, + slot_index_created: impl Into, + slot_index_target: impl Into, + ) -> Result { + let (slot_index_created, slot_index_target) = (slot_index_created.into(), slot_index_target.into()); + let (epoch_index_created, epoch_index_target) = ( + self.epoch_index_of(slot_index_created), + self.epoch_index_of(slot_index_target), + ); + if epoch_index_created > epoch_index_target { + return Err(Error::InvalidEpochDelta { + created: epoch_index_created, + target: epoch_index_target, + }); + } + Ok(self + .mana_structure() + .decay(mana, epoch_index_target.0 - epoch_index_created.0)) + } + + /// Applies mana decay to the given stored mana. + pub fn rewards_with_decay( + &self, + reward: u64, + reward_epoch: impl Into, + claimed_epoch: impl Into, + ) -> Result { + let (reward_epoch, claimed_epoch) = (reward_epoch.into(), claimed_epoch.into()); + if reward_epoch > claimed_epoch { + return Err(Error::InvalidEpochDelta { + created: reward_epoch, + target: claimed_epoch, + }); + } + Ok(self.mana_structure().decay(reward, claimed_epoch.0 - reward_epoch.0)) + } + + /// Calculates the potential mana that is generated by holding `amount` tokens from `slot_index_created` to + /// `slot_index_target` and applies the decay to the result + pub fn potential_mana( + &self, + amount: u64, + slot_index_created: impl Into, + slot_index_target: impl Into, + ) -> Result { + let (slot_index_created, slot_index_target) = (slot_index_created.into(), slot_index_target.into()); + let (epoch_index_created, epoch_index_target) = ( + self.epoch_index_of(slot_index_created), + self.epoch_index_of(slot_index_target), + ); + if epoch_index_created > epoch_index_target { + return Err(Error::InvalidEpochDelta { + created: epoch_index_created, + target: epoch_index_target, + }); + } + if slot_index_created >= slot_index_target { + return Ok(0); + } + let mana_structure = self.mana_structure(); + + Ok(if epoch_index_created == epoch_index_target { + mana_structure.generate_mana(amount, slot_index_target.0 - slot_index_created.0) + } else if epoch_index_target == epoch_index_created + 1 { + let slots_before_next_epoch = self.first_slot_of(epoch_index_created + 1) - slot_index_created; + let slots_since_epoch_start = slot_index_target - self.last_slot_of(epoch_index_target - 1); + let mana_decayed = mana_structure.decay(mana_structure.generate_mana(amount, slots_before_next_epoch.0), 1); + let mana_generated = mana_structure.generate_mana(amount, slots_since_epoch_start.0); + mana_decayed + mana_generated + } else { + let c = fixed_point_multiply( + amount, + mana_structure.decay_factor_epochs_sum(), + mana_structure.decay_factor_epochs_sum_exponent() + mana_structure.generation_rate_exponent() + - self.slots_per_epoch_exponent(), + ); + let slots_before_next_epoch = self.first_slot_of(epoch_index_created + 1) - slot_index_created; + let slots_since_epoch_start = slot_index_target - self.last_slot_of(epoch_index_target - 1); + let potential_mana_n = mana_structure.decay( + mana_structure.generate_mana(amount, slots_before_next_epoch.0), + epoch_index_target.0 - epoch_index_created.0, + ); + let potential_mana_n_1 = mana_structure.decay(c, epoch_index_target.0 - epoch_index_created.0); + let potential_mana_0 = c + mana_structure.generate_mana(amount, slots_since_epoch_start.0) + - (c >> mana_structure.generation_rate_exponent()); + potential_mana_0 - potential_mana_n_1 + potential_mana_n + }) + } +} + +/// Returns the upper 32 bits of a u64 value. +const fn upper_bits(v: u64) -> u64 { + v >> 32 +} + +/// Returns the lower n bits of a u64 value. +const fn lower_n_bits(v: u64, n: u8) -> u64 { + debug_assert!(n <= 64); + if n == 0 { + return 0; + } + v & u64::MAX >> (64 - n) +} + +/// Returns the lower 32 bits of a u64 value. +const fn lower_bits(v: u64) -> u64 { + v & 0xFFFFFFFF +} + +/// Returns the result of the multiplication ((value_hi << 32 + value_lo) * mult_factor) >> shift_factor +/// (where mult_factor is a uint32, value_hi and value_lo are uint64 smaller than 2^32, and 0 <= shift_factor <= +/// 32), using only uint64 multiplication functions, without overflowing. The returned result is split +/// in 2 factors: value_hi and value_lo, one containing the upper 32 bits of the result and the other +/// containing the lower 32 bits. +fn multiplication_and_shift(mut value_hi: u64, mut value_lo: u64, mult_factor: u32, shift_factor: u8) -> (u64, u64) { + debug_assert!(shift_factor <= 32); + // multiply the integer part of value_hi by mult_factor + value_hi *= mult_factor as u64; + + // the lower shift_factor bits of the result are extracted and shifted left to form the remainder. + // value_lo is multiplied by mult_factor and right-shifted by shift_factor bits. + // the sum of these two values forms the new lower part (value_lo) of the result. + value_lo = (lower_n_bits(value_hi, shift_factor) << (32 - shift_factor)) + + ((value_lo * mult_factor as u64) >> shift_factor); + + // the right-shifted value_hi and the upper 32 bits of value_lo form the new higher part (value_hi) of the + // result. + value_hi = (value_hi >> shift_factor) + upper_bits(value_lo); + + // the lower 32 bits of value_lo form the new lower part of the result. + value_lo = lower_bits(value_lo); + + // return the result as a fixed-point number composed of two 64-bit integers + (value_hi, value_lo) +} + +/// Wrapper for [`multiplication_and_shift`] that splits and re-combines the given value. +fn fixed_point_multiply(value: u64, mult_factor: u32, shift_factor: u8) -> u64 { + let value_hi = upper_bits(value); + let value_lo = lower_bits(value); + let (amount_hi, amount_lo) = multiplication_and_shift(value_hi, value_lo, mult_factor, shift_factor); + amount_hi << 32 | amount_lo +} + +#[cfg(test)] +mod test { + use super::*; + + // Tests from https://github.com/iotaledger/iota.go/blob/develop/mana_decay_provider_test.go + + const BETA_PER_YEAR: f64 = 1. / 3.; + + fn params() -> &'static ProtocolParameters { + use once_cell::sync::Lazy; + static PARAMS: Lazy = Lazy::new(|| { + // TODO: these params are clearly wrong as the calculation fails due to shifting > 32 bits + let mut params = ProtocolParameters { + slots_per_epoch_exponent: 13, + slot_duration_in_seconds: 10, + mana_structure: ManaStructure { + bits_count: 63, + generation_rate: 1, + generation_rate_exponent: 27, + decay_factors_exponent: 32, + decay_factor_epochs_sum_exponent: 20, + ..Default::default() + }, + ..Default::default() + }; + // TODO: Just use the generated values from go + params.mana_structure.decay_factors = { + let epochs_per_year = ((365_u64 * 24 * 60 * 60) as f64 / params.slot_duration_in_seconds() as f64) + / params.slots_per_epoch() as f64; + let beta_per_epoch_index = BETA_PER_YEAR / epochs_per_year; + (1..epochs_per_year.floor() as usize) + .map(|epoch| { + ((-beta_per_epoch_index * epoch as f64).exp() + * 2_f64.powf(params.mana_structure().decay_factors_exponent() as _)) + .floor() as u32 + }) + .collect::>() + } + .try_into() + .unwrap(); + params.mana_structure.decay_factor_epochs_sum = { + let delta = params.slots_per_epoch() as f64 * params.slot_duration_in_seconds() as f64 + / (365_u64 * 24 * 60 * 60) as f64; + (((-BETA_PER_YEAR * delta).exp() / (1. - (-BETA_PER_YEAR * delta).exp())) + * 2_f64.powf(params.mana_structure().decay_factor_epochs_sum_exponent() as _)) + .floor() as u32 + }; + params + }); + &*PARAMS + } + + #[test] + fn test_mana_decay_no_factors() { + let mana_structure = ManaStructure { + decay_factors: Box::<[_]>::default().try_into().unwrap(), + ..Default::default() + }; + assert_eq!(mana_structure.decay(100, 100), 100); + } + + #[test] + fn test_mana_decay_no_delta() { + assert_eq!( + params().mana_with_decay(100, params().first_slot_of(1), params().first_slot_of(1)), + Ok(100) + ); + } + + #[test] + fn test_mana_decay_no_mana() { + assert_eq!( + params().mana_with_decay(0, params().first_slot_of(1), params().first_slot_of(400)), + Ok(0) + ); + } + + #[test] + fn test_mana_decay_negative_delta() { + assert_eq!( + params().mana_with_decay(100, params().first_slot_of(2), params().first_slot_of(1)), + Err(Error::InvalidEpochDelta { + created: 2.into(), + target: 1.into() + }) + ); + } + + // TODO: Re-enable the commented tests once the test data is sorted out + // #[test] + // fn test_mana_decay_lookup_len_delta() { + // assert_eq!( + // params().mana_with_decay( + // u64::MAX, + // params().first_slot_of(1), + // params().first_slot_of(params().mana_structure().decay_factors().len() as u32 + 1) + // ), + // Ok(13228672242897911807) + // ); + // } + + // #[test] + // fn test_mana_decay_lookup_len_delta_multiple() { + // assert_eq!( + // params().mana_with_decay( + // u64::MAX, + // params().first_slot_of(1), + // params().first_slot_of(3 * params().mana_structure().decay_factors().len() as u32 + 1) + // ), + // Ok(6803138682699798504) + // ); + // } + + // #[test] + // fn test_mana_decay_max_mana() { + // assert_eq!( + // params().mana_with_decay(u64::MAX, params().first_slot_of(1), params().first_slot_of(401)), + // Ok(13046663022640287317) + // ); + // } + + #[test] + fn test_potential_mana_no_delta() { + assert_eq!( + params().potential_mana(100, params().first_slot_of(1), params().first_slot_of(1)), + Ok(0) + ); + } + + // #[test] + // fn test_potential_mana_no_mana() { + // assert_eq!( + // params().potential_mana(0, params().first_slot_of(1), params().first_slot_of(400)), + // Ok(0) + // ); + // } + + #[test] + fn test_potential_mana_negative_delta() { + assert_eq!( + params().potential_mana(100, params().first_slot_of(2), params().first_slot_of(1)), + Err(Error::InvalidEpochDelta { + created: 2.into(), + target: 1.into() + }) + ); + } + + // #[test] + // fn test_potential_mana_lookup_len_delta() { + // assert_eq!( + // params().potential_mana( + // u64::MAX, + // params().first_slot_of(1), + // params().first_slot_of(params().mana_structure().decay_factors().len() as u32 + 1) + // ), + // Ok(183827294847826527) + // ); + // } + + // #[test] + // fn test_potential_mana_lookup_len_delta_multiple() { + // assert_eq!( + // params().potential_mana( + // u64::MAX, + // params().first_slot_of(1), + // params().first_slot_of(3 * params().mana_structure().decay_factors().len() as u32 + 1) + // ), + // Ok(410192222442040018) + // ); + // } + + // #[test] + // fn test_potential_mana_same_epoch() { + // assert_eq!( + // params().potential_mana(u64::MAX, params().first_slot_of(1), params().last_slot_of(1)), + // Ok(562881233944575) + // ); + // } + + // #[test] + // fn test_potential_mana_one_epoch() { + // assert_eq!( + // params().potential_mana(u64::MAX, params().first_slot_of(1), params().last_slot_of(2)), + // Ok(1125343946211326) + // ); + // } + + // #[test] + // fn test_potential_mana_several_epochs() { + // assert_eq!( + // params().potential_mana(u64::MAX, params().first_slot_of(1), params().last_slot_of(3)), + // Ok(1687319975062367) + // ); + // } + + // #[test] + // fn test_potential_mana_max_mana() { + // assert_eq!( + // params().potential_mana(u64::MAX, params().first_slot_of(1), params().first_slot_of(401)), + // Ok(190239292158065300) + // ); + // } +} diff --git a/sdk/src/types/block/output/feature/block_issuer.rs b/sdk/src/types/block/output/feature/block_issuer.rs index c7682bdf33..5f86e3931f 100644 --- a/sdk/src/types/block/output/feature/block_issuer.rs +++ b/sdk/src/types/block/output/feature/block_issuer.rs @@ -67,10 +67,10 @@ impl Ed25519BlockIssuerKey { /// The block issuer key kind of an [`Ed25519BlockIssuerKey`]. pub const KIND: u8 = 0; /// Length of an ED25519 block issuer key. - pub const PUBLIC_KEY_LENGTH: usize = ed25519::PublicKey::LENGTH; + pub const LENGTH: usize = ed25519::PublicKey::LENGTH; /// Creates a new [`Ed25519BlockIssuerKey`] from bytes. - pub fn try_from_bytes(bytes: [u8; Self::PUBLIC_KEY_LENGTH]) -> Result { + pub fn try_from_bytes(bytes: [u8; Self::LENGTH]) -> Result { Ok(Self(ed25519::PublicKey::try_from_bytes(bytes)?)) } } @@ -94,7 +94,7 @@ impl Packable for Ed25519BlockIssuerKey { unpacker: &mut U, visitor: &Self::UnpackVisitor, ) -> Result> { - Self::try_from_bytes(<[u8; Self::PUBLIC_KEY_LENGTH]>::unpack::<_, VERIFY>(unpacker, visitor).coerce()?) + Self::try_from_bytes(<[u8; Self::LENGTH]>::unpack::<_, VERIFY>(unpacker, visitor).coerce()?) .map_err(UnpackError::Packable) } } @@ -148,11 +148,11 @@ impl IntoIterator for BlockIssuerKeys { } impl BlockIssuerKeys { - /// The minimum number of block_issuer_keys in a [`BlockIssuerFeature`]. + /// The minimum number of block issuer keys in a [`BlockIssuerFeature`]. pub const COUNT_MIN: u8 = 1; - /// The maximum number of block_issuer_keys in a [`BlockIssuerFeature`]. + /// The maximum number of block issuer keys in a [`BlockIssuerFeature`]. pub const COUNT_MAX: u8 = 128; - /// The range of valid numbers of block_issuer_keys. + /// The range of valid numbers of block issuer keys. pub const COUNT_RANGE: RangeInclusive = Self::COUNT_MIN..=Self::COUNT_MAX; // [1..128] /// Creates a new [`BlockIssuerKeys`] from a vec. @@ -186,9 +186,9 @@ impl BlockIssuerKeys { #[derive(Clone, Debug, Eq, PartialEq, Hash, packable::Packable)] #[packable(unpack_error = Error)] pub struct BlockIssuerFeature { - /// The slot index at which the Block Issuer Feature expires and can be removed. + /// The slot index at which the feature expires and can be removed. expiry_slot: SlotIndex, - /// The Block Issuer Keys. + /// The block issuer keys. block_issuer_keys: BlockIssuerKeys, } @@ -204,18 +204,19 @@ impl BlockIssuerFeature { ) -> Result { let block_issuer_keys = BlockIssuerKeys::from_vec(block_issuer_keys.into_iter().collect::>())?; + Ok(Self { expiry_slot: expiry_slot.into(), block_issuer_keys, }) } - /// Returns the Slot Index at which the Block Issuer Feature expires and can be removed. + /// Returns the expiry slot. pub fn expiry_slot(&self) -> SlotIndex { self.expiry_slot } - /// Returns the Block Issuer Keys. + /// Returns the block issuer keys. pub fn block_issuer_keys(&self) -> &[BlockIssuerKey] { &self.block_issuer_keys } diff --git a/sdk/src/types/block/protocol.rs b/sdk/src/types/block/protocol.rs index 1fe447b362..4deaafc7a8 100644 --- a/sdk/src/types/block/protocol.rs +++ b/sdk/src/types/block/protocol.rs @@ -11,7 +11,7 @@ use packable::{prefix::StringPrefix, Packable, PackableExt}; use super::{ address::Hrp, mana::{ManaStructure, RewardsParameters}, - slot::SlotIndex, + slot::{EpochIndex, SlotIndex}, }; use crate::types::block::{helper::network_name_to_id, output::RentStructure, ConvertTo, Error, PROTOCOL_VERSION}; @@ -170,6 +170,21 @@ impl ProtocolParameters { ) } + /// Gets the first [`SlotIndex`] of a given [`EpochIndex`]. + pub fn first_slot_of(&self, epoch_index: impl Into) -> SlotIndex { + epoch_index.into().first_slot_index(self.slots_per_epoch_exponent()) + } + + /// Gets the last [`SlotIndex`] of a given [`EpochIndex`]. + pub fn last_slot_of(&self, epoch_index: impl Into) -> SlotIndex { + epoch_index.into().last_slot_index(self.slots_per_epoch_exponent()) + } + + /// Gets the [`EpochIndex`] of a given [`SlotIndex`]. + pub fn epoch_index_of(&self, slot_index: impl Into) -> EpochIndex { + EpochIndex::from_slot_index(slot_index.into(), self.slots_per_epoch_exponent()) + } + /// Returns the hash of the [`ProtocolParameters`]. pub fn hash(&self) -> ProtocolParametersHash { ProtocolParametersHash::new(Blake2b256::digest(self.pack_to_vec()).into()) diff --git a/sdk/src/types/block/rand/slot.rs b/sdk/src/types/block/rand/slot.rs index 8e01489e1c..9b24c171e0 100644 --- a/sdk/src/types/block/rand/slot.rs +++ b/sdk/src/types/block/rand/slot.rs @@ -13,10 +13,10 @@ pub fn rand_slot_commitment_id() -> SlotCommitmentId { /// Generates a random slot index. pub fn rand_slot_index() -> SlotIndex { - SlotIndex::new(rand_number()) + SlotIndex(rand_number()) } /// Generates a random epoch index. pub fn rand_epoch_index() -> EpochIndex { - EpochIndex::new(rand_number()) + EpochIndex(rand_number()) } diff --git a/sdk/src/types/block/slot/epoch.rs b/sdk/src/types/block/slot/epoch.rs index b679c6d88e..4f49a43512 100644 --- a/sdk/src/types/block/slot/epoch.rs +++ b/sdk/src/types/block/slot/epoch.rs @@ -1,10 +1,9 @@ // Copyright 2023 IOTA Stiftung // SPDX-License-Identifier: Apache-2.0 -use derive_more::{Deref, Display, From, FromStr}; +use derive_more::{Add, AddAssign, Deref, Display, From, FromStr, Sub, SubAssign}; use super::SlotIndex; -use crate::types::block::Error; /// The tangle timeline is divided into epochs, and each epoch has a corresponding epoch index. Epochs are further /// subdivided into slots, each with a [`SlotIndex`]. @@ -32,52 +31,47 @@ use crate::types::block::Error; /// | 2 | 16 | 24 | // ... #[derive( - Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Hash, From, Deref, Display, FromStr, packable::Packable, + Copy, + Clone, + Debug, + Default, + Eq, + PartialEq, + Ord, + PartialOrd, + Hash, + From, + Deref, + Add, + AddAssign, + Sub, + SubAssign, + Display, + FromStr, + packable::Packable, )] #[repr(transparent)] -pub struct EpochIndex(u32); +pub struct EpochIndex(pub u32); impl EpochIndex { - /// Creates a new [`EpochIndex`]. - pub fn new(index: u32) -> Self { - Self::from(index) + /// Gets the range of slots this epoch contains. + pub fn slot_index_range(&self, slots_per_epoch_exponent: u8) -> core::ops::RangeInclusive { + self.first_slot_index(slots_per_epoch_exponent)..=self.last_slot_index(slots_per_epoch_exponent) } /// Gets the epoch index given a [`SlotIndex`]. - pub fn from_slot_index( - slot_index: SlotIndex, - slots_per_epoch_exponent_iter: impl Iterator, - ) -> Result { - let mut slot_index = *slot_index; - let mut res = 0; - let mut last = None; - for (start_epoch, exponent) in slots_per_epoch_exponent_iter { - if let Some((last_start_epoch, last_exponent)) = last { - if *start_epoch <= last_start_epoch { - return Err(Error::InvalidStartEpoch(start_epoch)); - } - // Get the number of slots this range of epochs represents - let slots_in_range = (*start_epoch - last_start_epoch) << last_exponent; - // Check whether the slot index is contained in this range - if slot_index > slots_in_range { - // Update the slot index so it is in the context of the next epoch - slot_index -= slots_in_range; - } else { - break; - } - } - if *start_epoch > res { - // We can't calculate the epoch if we don't have the exponent for the containing range - if slot_index > 0 { - return Err(Error::InvalidStartEpoch(start_epoch)); - } else { - break; - } - } - res = *start_epoch + (slot_index >> exponent); - last = Some((*start_epoch, exponent)); - } - Ok(Self(res)) + pub fn from_slot_index(slot_index: SlotIndex, slots_per_epoch_exponent: u8) -> Self { + Self(*slot_index >> slots_per_epoch_exponent) + } + + /// Gets the first [`SlotIndex`] of this epoch. + pub fn first_slot_index(self, slots_per_epoch_exponent: u8) -> SlotIndex { + SlotIndex::from_epoch_index(self, slots_per_epoch_exponent) + } + + /// Gets the last [`SlotIndex`] of this epoch. + pub fn last_slot_index(self, slots_per_epoch_exponent: u8) -> SlotIndex { + SlotIndex::from_epoch_index(self + 1, slots_per_epoch_exponent) - 1 } } @@ -93,6 +87,34 @@ impl PartialEq for EpochIndex { } } +impl core::ops::Add for EpochIndex { + type Output = Self; + + fn add(self, other: u32) -> Self { + Self(self.0 + other) + } +} + +impl core::ops::AddAssign for EpochIndex { + fn add_assign(&mut self, other: u32) { + self.0 += other; + } +} + +impl core::ops::Sub for EpochIndex { + type Output = Self; + + fn sub(self, other: u32) -> Self { + Self(self.0 - other) + } +} + +impl core::ops::SubAssign for EpochIndex { + fn sub_assign(&mut self, other: u32) { + self.0 -= other; + } +} + #[cfg(feature = "serde")] string_serde_impl!(EpochIndex); @@ -102,80 +124,26 @@ mod test { use crate::types::block::protocol::ProtocolParameters; #[test] - fn epoch_index_from_slot() { - let v3_params = ProtocolParameters { + fn epoch_index_to_from_slot() { + let params = ProtocolParameters { version: 3, slots_per_epoch_exponent: 10, ..Default::default() }; - let v4_params = ProtocolParameters { - version: 4, - slots_per_epoch_exponent: 11, - ..Default::default() - }; - let params = [(EpochIndex(0), v3_params.clone()), (EpochIndex(10), v4_params)]; - let slots_per_epoch_exponent_iter = params - .iter() - .map(|(start_index, params)| (*start_index, params.slots_per_epoch_exponent())); - - let slot_index = SlotIndex::new(3000); - let epoch_index = EpochIndex::from_slot_index(slot_index, slots_per_epoch_exponent_iter.clone()); - assert_eq!(epoch_index, Ok(EpochIndex(2))); - - let slot_index = SlotIndex::new(10 * v3_params.slots_per_epoch() + 3000); - let epoch_index = EpochIndex::from_slot_index(slot_index, slots_per_epoch_exponent_iter.clone()); - assert_eq!(epoch_index, Ok(EpochIndex(11))); - } - - #[test] - fn invalid_params() { - let v3_params = ProtocolParameters { - version: 3, - slots_per_epoch_exponent: 10, - ..Default::default() - }; - let v4_params = ProtocolParameters { - version: 4, - slots_per_epoch_exponent: 11, - ..Default::default() - }; - let v5_params = ProtocolParameters { - version: 5, - slots_per_epoch_exponent: 12, - ..Default::default() - }; - let slot_index = SlotIndex::new(100000); - - // Params must cover the entire history starting at epoch 0 - let params = [(EpochIndex(10), v4_params.clone()), (EpochIndex(20), v5_params.clone())]; - let slots_per_epoch_exponent_iter = params - .iter() - .map(|(start_index, params)| (*start_index, params.slots_per_epoch_exponent())); - let epoch_index = EpochIndex::from_slot_index(slot_index, slots_per_epoch_exponent_iter); - assert_eq!(epoch_index, Err(Error::InvalidStartEpoch(EpochIndex(10)))); - - // Params must not contain duplicate start epochs - let params = [ - (EpochIndex(0), v3_params.clone()), - (EpochIndex(10), v4_params.clone()), - (EpochIndex(10), v5_params.clone()), - ]; - let slots_per_epoch_exponent_iter = params - .iter() - .map(|(start_index, params)| (*start_index, params.slots_per_epoch_exponent())); - let epoch_index = EpochIndex::from_slot_index(slot_index, slots_per_epoch_exponent_iter); - assert_eq!(epoch_index, Err(Error::InvalidStartEpoch(EpochIndex(10)))); - - // Params must be properly ordered - let params = [ - (EpochIndex(10), v4_params), - (EpochIndex(0), v3_params), - (EpochIndex(20), v5_params), - ]; - let slots_per_epoch_exponent_iter = params - .iter() - .map(|(start_index, params)| (*start_index, params.slots_per_epoch_exponent())); - let epoch_index = EpochIndex::from_slot_index(slot_index, slots_per_epoch_exponent_iter); - assert_eq!(epoch_index, Err(Error::InvalidStartEpoch(EpochIndex(10)))); + let slot_index = SlotIndex(3000); + let epoch_index = EpochIndex::from_slot_index(slot_index, params.slots_per_epoch_exponent()); + assert_eq!(epoch_index, EpochIndex(2)); + assert_eq!( + epoch_index.slot_index_range(params.slots_per_epoch_exponent()), + SlotIndex(2048)..=SlotIndex(3071) + ); + + let slot_index = SlotIndex(10 * params.slots_per_epoch() + 2000); + let epoch_index = EpochIndex::from_slot_index(slot_index, params.slots_per_epoch_exponent()); + assert_eq!(epoch_index, EpochIndex(11)); + assert_eq!( + epoch_index.slot_index_range(params.slots_per_epoch_exponent()), + SlotIndex(11 * params.slots_per_epoch())..=SlotIndex(12 * params.slots_per_epoch() - 1) + ); } } diff --git a/sdk/src/types/block/slot/index.rs b/sdk/src/types/block/slot/index.rs index 963e77fbe1..472d417e16 100644 --- a/sdk/src/types/block/slot/index.rs +++ b/sdk/src/types/block/slot/index.rs @@ -4,7 +4,6 @@ use derive_more::{Add, AddAssign, Deref, Display, From, FromStr, Sub, SubAssign}; use super::EpochIndex; -use crate::types::block::Error; /// The tangle timeline is divided into epochs, and each epoch has a corresponding [`EpochIndex`]. Epochs are further /// subdivided into slots, each with a slot index. @@ -42,20 +41,16 @@ use crate::types::block::Error; packable::Packable, )] #[repr(transparent)] -pub struct SlotIndex(u32); +pub struct SlotIndex(pub u32); impl SlotIndex { - /// Creates a new [`SlotIndex`]. - pub fn new(index: u32) -> Self { - Self::from(index) + /// Gets the [`EpochIndex`] of this slot. + pub fn to_epoch_index(self, slots_per_epoch_exponent: u8) -> EpochIndex { + EpochIndex::from_slot_index(self, slots_per_epoch_exponent) } - /// Gets the [`EpochIndex`] of this slot. - pub fn to_epoch_index( - self, - slots_per_epoch_exponent_iter: impl Iterator, - ) -> Result { - EpochIndex::from_slot_index(self, slots_per_epoch_exponent_iter) + pub fn from_epoch_index(epoch_index: EpochIndex, slots_per_epoch_exponent: u8) -> Self { + Self(*epoch_index << slots_per_epoch_exponent) } /// Gets the slot index of a unix timestamp.