Skip to content

Commit

Permalink
Feat/flat power forgetting curve (#134)
Browse files Browse the repository at this point in the history
* Feat/flat power forgetting curve

* sqrt(count) as weights for pretrain

open-spaced-repetition/fsrs4anki#461 (comment)

* float eq

* use assert_approx_eq

* channel up

* make DECAY, FACTOR const and pub

* clippy fix

* pub(crate)

* fix test

---------

Co-authored-by: AsukaMinato <[email protected]>
  • Loading branch information
L-M-Sherlock and asukaminato0721 authored Dec 18, 2023
1 parent 2c7cdf9 commit f0715e0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 38 deletions.
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.74"
channel = "1.74.1"
components = ["rustfmt", "clippy"]
30 changes: 16 additions & 14 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use crate::model::Model;
use crate::training::BCELoss;
use crate::{FSRSError, FSRSItem};
use burn::tensor::ElementConversion;

/// Exponent of the power forgetting curve, `(t / s * FACTOR + 1) ^ DECAY`.
pub(crate) const DECAY: f64 = -0.5;
/// Time-scaling factor of the power forgetting curve: `(9/10) ^ (1 / DECAY) - 1`.
///
/// With this value the curve evaluates to 0.9 when the elapsed time `t`
/// equals the stability `s`.
pub(crate) const FACTOR: f64 = 19f64 / 81f64;
/// This is a slice for efficiency, but should always be 17 in length.
pub type Weights = [f32];

Expand Down Expand Up @@ -57,8 +59,8 @@ impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
}
}

fn next_interval(stability: f32, request_retention: f32) -> u32 {
(9.0 * stability * (1.0 / request_retention - 1.0))
/// Returns the interval after which the power forgetting curve
/// `(t / s * FACTOR + 1) ^ DECAY` drops to `desired_retention` for a memory
/// with the given `stability`.
///
/// The result is rounded to the nearest whole number and is never below 1.
pub fn next_interval(stability: f32, desired_retention: f32) -> u32 {
    let factor = FACTOR as f32;
    let inverse_decay = 1.0 / DECAY as f32;
    // Inverse of the forgetting curve, solved for the elapsed time.
    let raw_interval = stability / factor * (desired_retention.powf(inverse_decay) - 1.0);
    raw_interval.round().max(1.0) as u32
}
Expand Down Expand Up @@ -365,7 +367,7 @@ mod tests {
assert_eq!(
fsrs.memory_state(item, None).unwrap(),
MemoryState {
stability: 51.344814,
stability: 51.31289,
difficulty: 7.005062
}
);
Expand All @@ -383,7 +385,7 @@ mod tests {
.good
.memory,
MemoryState {
stability: 51.344814,
stability: 51.339684,
difficulty: 7.005062
}
);
Expand All @@ -392,12 +394,12 @@ mod tests {

#[test]
fn test_next_interval() {
let request_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
let intervals = request_retentions
let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
let intervals = desired_retentions
.iter()
.map(|r| next_interval(1.0, *r))
.collect::<Vec<_>>();
assert_eq!(intervals, [81, 36, 21, 14, 9, 6, 4, 2, 1, 1,]);
assert_eq!(intervals, [422, 102, 43, 22, 13, 8, 4, 2, 1, 1]);
}

#[test]
Expand All @@ -408,13 +410,13 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.21600282, 0.06387164]), 5);
.assert_approx_eq(&Data::from([0.21364396810531616, 0.05370686203241348]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.203_217_7, 0.015_836_29]), 5);
.assert_approx_eq(&Data::from([0.20306083, 0.01326745]), 5);
Ok(())
}

Expand Down Expand Up @@ -447,28 +449,28 @@ mod tests {
NextStates {
again: ItemState {
memory: MemoryState {
stability: 4.5802255,
stability: 4.577856,
difficulty: 8.881129,
},
interval: 5
},
hard: ItemState {
memory: MemoryState {
stability: 27.7025,
stability: 27.6745,
difficulty: 7.9430957
},
interval: 28,
},
good: ItemState {
memory: MemoryState {
stability: 51.344814,
stability: 51.31289,
difficulty: 7.005062
},
interval: 51,
},
easy: ItemState {
memory: MemoryState {
stability: 101.98282,
stability: 101.94249,
difficulty: 6.0670285
},
interval: 102,
Expand Down
6 changes: 3 additions & 3 deletions src/model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::{FSRSError, Result};
use crate::inference::Weights;
use crate::inference::{Weights, DECAY, FACTOR};
use crate::weight_clipper::clip_weights;
use crate::DEFAULT_WEIGHTS;
use burn::backend::ndarray::NdArrayDevice;
Expand Down Expand Up @@ -58,7 +58,7 @@ impl<B: Backend> Model<B> {
}

pub fn power_forgetting_curve(&self, t: Tensor<B, 1>, s: Tensor<B, 1>) -> Tensor<B, 1> {
(t / (s * 9) + 1).powf(-1.0)
(t / s * FACTOR + 1).powf(DECAY as f32)
}

fn stability_after_success(
Expand Down Expand Up @@ -267,7 +267,7 @@ mod tests {
let retention = model.power_forgetting_curve(delta_t, stability);
assert_eq!(
retention.to_data(),
Data::from([1.0, 0.9473684, 0.9310345, 0.92307687, 0.9, 0.7826087])
Data::from([1.0, 0.946059, 0.9299294, 0.9221679, 0.9, 0.79394597])
)
}

Expand Down
21 changes: 12 additions & 9 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::{FSRSError, Result};
use crate::inference::{ItemProgress, Weights};
use crate::inference::{next_interval, ItemProgress, Weights, DECAY, FACTOR};
use crate::{DEFAULT_WEIGHTS, FSRS};
use burn::tensor::backend::Backend;
use itertools::izip;
Expand Down Expand Up @@ -90,7 +90,7 @@ fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
.clamp(0.1, s)
}

fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: Option<u64>) -> f64 {
fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: Option<u64>) -> f64 {
let SimulatorConfig {
deck_size,
learn_span,
Expand Down Expand Up @@ -140,11 +140,15 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O

let mut retrievability = Array1::zeros(deck_size); // Create an array for retrievability

/// Evaluates the power forgetting curve `(t / s * FACTOR + 1) ^ DECAY`
/// for elapsed time `t` and stability `s`.
fn power_forgetting_curve(t: f64, s: f64) -> f64 {
    let scaled_elapsed = t / s * FACTOR;
    (scaled_elapsed + 1.0).powf(DECAY)
}

// Calculate retrievability for entries where has_learned is true
izip!(&mut retrievability, &delta_t, &old_stability, &has_learned)
.filter(|(.., &has_learned_flag)| has_learned_flag)
.for_each(|(retrievability, &delta_t, &stability, ..)| {
*retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1)
*retrievability = power_forgetting_curve(delta_t, stability)
});

// Set 'cost' column to 0
Expand Down Expand Up @@ -315,8 +319,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
izip!(&mut new_interval, &new_stability, &true_review, &true_learn)
.filter(|(.., &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag)
.for_each(|(new_ivl, &new_stab, ..)| {
*new_ivl = (9.0 * new_stab * (1.0 / request_retention - 1.0))
.round()
*new_ivl = (next_interval(new_stab as f32, desired_retention as f32) as f64)
.clamp(1.0, max_ivl);
});

Expand Down Expand Up @@ -354,7 +357,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
fn sample<F>(
config: &SimulatorConfig,
weights: &[f64],
request_retention: f64,
desired_retention: f64,
n: usize,
progress: &mut F,
) -> Result<f64>
Expand All @@ -370,7 +373,7 @@ where
simulate(
config,
weights,
request_retention,
desired_retention,
Some((i + 42).try_into().unwrap()),
)
})
Expand Down Expand Up @@ -626,15 +629,15 @@ mod tests {
0.9,
None,
);
assert_eq!(memorization, 2405.020202735966)
assert_eq!(memorization, 2380.9836436993573)
}

#[test]
fn optimal_retention() -> Result<()> {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8608067460076987);
assert_eq!(optimal_retention, 0.8568971936549108);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down
29 changes: 18 additions & 11 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::error::{FSRSError, Result};
use crate::inference::{DECAY, FACTOR};
use crate::FSRSItem;
use crate::DEFAULT_WEIGHTS;
use itertools::Itertools;
Expand Down Expand Up @@ -87,7 +88,7 @@ fn total_rating_count(
}

fn power_forgetting_curve(t: &Array1<f32>, s: f32) -> Array1<f32> {
1.0 / (1.0 + t / (9.0 * s))
(t / s * FACTOR as f32 + 1.0).mapv(|v| v.powf(DECAY as f32))
}

fn loss(
Expand All @@ -100,10 +101,9 @@ fn loss(
let y_pred = power_forgetting_curve(delta_t, init_s0);
let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln())
+ (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln()))
* count
/ count.sum())
* count.mapv(|v| v.sqrt()))
.sum();
let l1 = (init_s0 - default_s0).abs() / count.sum() / 16.0;
let l1 = (init_s0 - default_s0).abs() / 16.0;
logloss + l1
}

Expand Down Expand Up @@ -313,7 +313,7 @@ mod tests {
let t = Array1::from(vec![0.0, 1.0, 2.0, 3.0]);
let s = 1.0;
let y = power_forgetting_curve(&t, s);
let expected = Array1::from(vec![1.0, 0.9, 0.8181818, 0.75]);
let expected = Array1::from(vec![1.0, 0.90000004, 0.82502866, 0.76613086]);
assert_eq!(y, expected);
}

Expand All @@ -324,8 +324,9 @@ mod tests {
let count = Array1::from(vec![100.0, 100.0, 100.0]);
let init_s0 = 1.0;
let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
assert_eq!(actual, 0.45385247);
assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48355862);
assert_eq!(actual, 13.624332);
Data::from([loss(&delta_t, &recall, &count, 2.0, init_s0)])
.assert_approx_eq(&Data::from([14.5771]), 5);
}

#[test]
Expand Down Expand Up @@ -356,7 +357,8 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
Data::from([actual.get(&4).unwrap().clone()]).assert_approx_eq(&Data::from([1.2390649]), 4);
Data::from([actual.get(&4).unwrap().clone()])
.assert_approx_eq(&Data::from([1.2301323413848877]), 4);
}

#[test]
Expand All @@ -365,9 +367,14 @@ mod tests {
let items = anki21_sample_file_converted_to_fsrs();
let average_recall = calculate_average_recall(&items);
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.94550645, 1.6813093, 3.9867811, 8.992397,],
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
&Data::from([
0.9560174345970154,
1.694406509399414,
3.998023509979248,
8.26822280883789,
]),
4,
)
}

Expand Down

0 comments on commit f0715e0

Please sign in to comment.