diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index 86e3f3cd..5aaa2d56 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,4 +1,4 @@
 [toolchain]
 # older versions may fail to compile; newer versions may fail the clippy tests
-channel = "1.74"
+channel = "1.74.1"
 components = ["rustfmt", "clippy"]
diff --git a/src/inference.rs b/src/inference.rs
index eb95e96f..0dc22b59 100644
--- a/src/inference.rs
+++ b/src/inference.rs
@@ -12,7 +12,9 @@
 use crate::model::Model;
 use crate::training::BCELoss;
 use crate::{FSRSError, FSRSItem};
 use burn::tensor::ElementConversion;
-
+pub(crate) const DECAY: f64 = -0.5;
+/// (9/10) ^ (1 / DECAY) - 1
+pub(crate) const FACTOR: f64 = 19f64 / 81f64;
 /// This is a slice for efficiency, but should always be 17 in length.
 pub type Weights = [f32];
 
@@ -57,8 +59,8 @@ impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
     }
 }
 
-fn next_interval(stability: f32, request_retention: f32) -> u32 {
-    (9.0 * stability * (1.0 / request_retention - 1.0))
+pub fn next_interval(stability: f32, desired_retention: f32) -> u32 {
+    (stability / FACTOR as f32 * (desired_retention.powf(1.0 / DECAY as f32) - 1.0))
         .round()
         .max(1.0) as u32
 }
@@ -365,7 +367,7 @@ mod tests {
         assert_eq!(
             fsrs.memory_state(item, None).unwrap(),
             MemoryState {
-                stability: 51.344814,
+                stability: 51.31289,
                 difficulty: 7.005062
             }
         );
@@ -383,7 +385,7 @@ mod tests {
                 .good
                 .memory,
             MemoryState {
-                stability: 51.344814,
+                stability: 51.339684,
                 difficulty: 7.005062
             }
         );
     }
@@ -392,12 +394,12 @@ mod tests {
     #[test]
     fn test_next_interval() {
-        let request_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
-        let intervals = request_retentions
+        let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
+        let intervals = desired_retentions
             .iter()
             .map(|r| next_interval(1.0, *r))
             .collect::<Vec<_>>();
-        assert_eq!(intervals, [81, 36, 21, 14, 9, 6, 4, 2, 1, 1,]);
+        assert_eq!(intervals, [422, 102, 43, 22, 13, 8, 4, 2, 1, 1]);
     }
 
     #[test]
@@ -408,13 +410,13 @@ mod tests {
         let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
 
         Data::from([metrics.log_loss, metrics.rmse_bins])
-            .assert_approx_eq(&Data::from([0.21600282, 0.06387164]), 5);
+            .assert_approx_eq(&Data::from([0.21364396810531616, 0.05370686203241348]), 5);
 
         let fsrs = FSRS::new(Some(WEIGHTS))?;
         let metrics = fsrs.evaluate(items, |_| true).unwrap();
 
         Data::from([metrics.log_loss, metrics.rmse_bins])
-            .assert_approx_eq(&Data::from([0.203_217_7, 0.015_836_29]), 5);
+            .assert_approx_eq(&Data::from([0.20306083, 0.01326745]), 5);
         Ok(())
     }
@@ -447,28 +449,28 @@ mod tests {
             NextStates {
                 again: ItemState {
                     memory: MemoryState {
-                        stability: 4.5802255,
+                        stability: 4.577856,
                         difficulty: 8.881129,
                     },
                     interval: 5
                 },
                 hard: ItemState {
                     memory: MemoryState {
-                        stability: 27.7025,
+                        stability: 27.6745,
                         difficulty: 7.9430957
                     },
                     interval: 28,
                 },
                 good: ItemState {
                     memory: MemoryState {
-                        stability: 51.344814,
+                        stability: 51.31289,
                         difficulty: 7.005062
                     },
                     interval: 51,
                 },
                 easy: ItemState {
                     memory: MemoryState {
-                        stability: 101.98282,
+                        stability: 101.94249,
                         difficulty: 6.0670285
                     },
                     interval: 102,
diff --git a/src/model.rs b/src/model.rs
index 6cadb8bb..6a5e5d7a 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,5 +1,5 @@
 use crate::error::{FSRSError, Result};
-use crate::inference::Weights;
+use crate::inference::{Weights, DECAY, FACTOR};
 use crate::weight_clipper::clip_weights;
 use crate::DEFAULT_WEIGHTS;
 use burn::backend::ndarray::NdArrayDevice;
@@ -58,7 +58,7 @@ impl<B: Backend> Model<B> {
     }
 
     pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
-        (t / (s * 9) + 1).powf(-1.0)
+        (t / s * FACTOR + 1).powf(DECAY as f32)
     }
 
     fn stability_after_success(
@@ -267,7 +267,7 @@ mod tests {
         let retention = model.power_forgetting_curve(delta_t, stability);
         assert_eq!(
             retention.to_data(),
-            Data::from([1.0, 0.9473684, 0.9310345, 0.92307687, 0.9, 0.7826087])
+            Data::from([1.0, 0.946059, 0.9299294, 0.9221679, 0.9, 0.79394597])
         )
     }
diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs
index 5883f9b7..b657299e 100644
--- a/src/optimal_retention.rs
+++ b/src/optimal_retention.rs
@@ -1,5 +1,5 @@
 use crate::error::{FSRSError, Result};
-use crate::inference::{ItemProgress, Weights};
+use crate::inference::{next_interval, ItemProgress, Weights, DECAY, FACTOR};
 use crate::{DEFAULT_WEIGHTS, FSRS};
 use burn::tensor::backend::Backend;
 use itertools::izip;
@@ -90,7 +90,7 @@ fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
     .clamp(0.1, s)
 }
 
-fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option<u64>) -> f64 {
+fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: Option<u64>) -> f64 {
     let SimulatorConfig {
         deck_size,
         learn_span,
@@ -140,11 +140,15 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
     let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability
 
+    fn power_forgetting_curve(t: f64, s: f64) -> f64 {
+        (t / s * FACTOR + 1.0).powf(DECAY)
+    }
+
     // Calculate retrievability for entries where has_learned is true
     izip!(&mut retrievability, &delta_t, &old_stability, &has_learned)
         .filter(|(.., &has_learned_flag)| has_learned_flag)
         .for_each(|(retrievability, &delta_t, &stability, ..)| {
-            *retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1)
+            *retrievability = power_forgetting_curve(delta_t, stability)
         });
 
     // Set 'cost' column to 0
@@ -315,8 +319,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
     izip!(&mut new_interval, &new_stability, &true_review, &true_learn)
         .filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag)
         .for_each(|(new_ivl, &new_stab, ..)| {
-            *new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0))
-                .round()
+            *new_ivl = (next_interval(new_stab as f32, desired_retention as f32) as f64)
                 .clamp(1.0, max_ivl);
         });
 
@@ -354,7 +357,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
 fn sample<F>(
     config: &SimulatorConfig,
     weights: &[f64],
-    request_retention: f64,
+    desired_retention: f64,
     n: usize,
     progress: &mut F,
 ) -> Result<f64>
@@ -370,7 +373,7 @@ where
             simulate(
                 config,
                 weights,
-                request_retention,
+                desired_retention,
                 Some((i + 42).try_into().unwrap()),
             )
         })
@@ -626,7 +629,7 @@ mod tests {
             0.9,
             None,
         );
-        assert_eq!(memorization, 2405.020202735966)
+        assert_eq!(memorization, 2380.9836436993573)
     }
 
     #[test]
@@ -634,7 +637,7 @@ mod tests {
         let config = SimulatorConfig::default();
         let fsrs = FSRS::new(None)?;
         let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
-        assert_eq!(optimal_retention, 0.8608067460076987);
+        assert_eq!(optimal_retention, 0.8568971936549108);
         assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
         Ok(())
     }
diff --git a/src/pre_training.rs b/src/pre_training.rs
index 1f81f928..d91a0831 100644
--- a/src/pre_training.rs
+++ b/src/pre_training.rs
@@ -1,4 +1,5 @@
 use crate::error::{FSRSError, Result};
+use crate::inference::{DECAY, FACTOR};
 use crate::FSRSItem;
 use crate::DEFAULT_WEIGHTS;
 use itertools::Itertools;
@@ -87,7 +88,7 @@ fn total_rating_count(
 }
 
 fn power_forgetting_curve(t: &Array1<f32>, s: f32) -> Array1<f32> {
-    1.0 / (1.0 + t / (9.0 * s))
+    (t / s * FACTOR as f32 + 1.0).mapv(|v| v.powf(DECAY as f32))
 }
 
 fn loss(
@@ -100,10 +101,9 @@ fn loss(
     let y_pred = power_forgetting_curve(delta_t, init_s0);
     let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln())
         + (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln()))
-        * count
-        / count.sum())
+        * count.mapv(|v| v.sqrt()))
     .sum();
-    let l1 = (init_s0 - default_s0).abs() / count.sum() / 16.0;
+    let l1 = (init_s0 - default_s0).abs() / 16.0;
     logloss + l1
 }
@@ -313,7 +313,7 @@ mod tests {
         let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
         let s = 1.0;
         let y = power_forgetting_curve(&t, s);
-        let expected = Array1::from(vec![1.0, 0.9, 0.8181818, 0.75]);
+        let expected = Array1::from(vec![1.0, 0.90000004, 0.82502866, 0.76613086]);
         assert_eq!(y, expected);
     }
@@ -324,8 +324,9 @@ mod tests {
         let count = Array1::from(vec![100.0, 100.0, 100.0]);
         let init_s0 = 1.0;
         let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
-        assert_eq!(actual, 0.45385247);
-        assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48355862);
+        assert_eq!(actual, 13.624332);
+        Data::from([loss(&delta_t, &recall, &count, 2.0, init_s0)])
+            .assert_approx_eq(&Data::from([14.5771]), 5);
     }
 
     #[test]
@@ -356,7 +357,8 @@ mod tests {
             ],
         )]);
         let actual = search_parameters(pretrainset, 0.9);
-        Data::from([actual.get(&4).unwrap().clone()]).assert_approx_eq(&Data::from([1.2390649]), 4);
+        Data::from([actual.get(&4).unwrap().clone()])
+            .assert_approx_eq(&Data::from([1.2301323413848877]), 4);
     }
 
     #[test]
@@ -365,9 +367,14 @@ mod tests {
         let items = anki21_sample_file_converted_to_fsrs();
         let average_recall = calculate_average_recall(&items);
         let pretrainset = split_data(items, 1).0;
-        assert_eq!(
-            pretrain(pretrainset, average_recall).unwrap(),
-            [0.94550645, 1.6813093, 3.9867811, 8.992397,],
+        Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
+            &Data::from([
+                0.9560174345970154,
+                1.694406509399414,
+                3.998023509979248,
+                8.26822280883789,
+            ]),
+            4,
         )
     }