From a9244120e016e9bf695f978b93a770cbdf6cce79 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Mon, 9 Dec 2024 16:38:02 +0800 Subject: [PATCH] Fix/consider short-term params when clipping PLS refer to: https://github.com/open-spaced-repetition/fsrs-optimizer/pull/150 --- src/inference.rs | 4 ++-- src/model.rs | 8 ++++++-- src/optimal_retention.rs | 8 ++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 9eb60a0..bd095fd 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -499,7 +499,7 @@ mod tests { let fsrs = FSRS::new(Some(&[]))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); - assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216326, 0.038727]); + assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216286, 0.038692]); let fsrs = FSRS::new(Some(PARAMETERS))?; let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); @@ -510,7 +510,7 @@ mod tests { .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true) .unwrap(); - assert_approx_eq([self_by_other, other_by_self], [0.016236, 0.031085]); + assert_approx_eq([self_by_other, other_by_self], [0.016570, 0.031037]); Ok(()) } diff --git a/src/model.rs b/src/model.rs index 4e5f172..9e2e091 100644 --- a/src/model.rs +++ b/src/model.rs @@ -94,9 +94,10 @@ impl Model { * last_d.pow(-self.w.get(12)) * ((last_s.clone() + 1).pow(self.w.get(13)) - 1) * ((-r + 1) * self.w.get(14)).exp(); + let new_s_min = last_s / (self.w.get(17) * self.w.get(18)).exp(); new_s .clone() - .mask_where(last_s.clone().lower(new_s), last_s) + .mask_where(new_s_min.clone().lower(new_s), new_s_min) } fn stability_short_term(&self, last_s: Tensor, rating: Tensor) -> Tensor { @@ -380,7 +381,10 @@ mod tests { &device, ); let state = model.forward(delta_ts, ratings, None); - dbg!(&state); + let stability = state.stability.to_data(); + let difficulty = state.difficulty.to_data(); + stability.assert_approx_eq(&Data::from([0.2619, 1.7074, 5.8691, 25.0124, 0.2859, 2.1482]), 4); + difficulty.assert_approx_eq(&Data::from([8.0827, 7.0405, 5.2729, 2.1301, 8.0827, 7.0405]), 4); } #[test] diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 09e3941..5de3007 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -83,7 +83,7 @@ fn stability_after_success(w: &[f32], s: f32, r: f32, d: f32, rating: usize) -> fn stability_after_failure(w: &[f32], s: f32, r: f32, d: f32) -> f32 { (w[11] * d.powf(-w[12]) * ((s + 1.0).powf(w[13]) - 1.0) * f32::exp((1.0 - r) * w[14])) - .clamp(S_MIN, s) + .clamp(S_MIN, s / (w[17] * w[18]).exp()) } fn stability_short_term(w: &[f32], s: f32, rating_offset: f32, session_len: f32) -> f32 { @@ -903,7 +903,7 @@ mod tests { simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?; assert_eq!( memorized_cnt_per_day[memorized_cnt_per_day.len() - 1], - 6919.944 + 6911.91 ); Ok(()) } @@ -1023,7 +1023,7 @@ mod tests { ..Default::default() }; let results = simulate(&config, &DEFAULT_PARAMETERS, 0.9, None, None)?; - assert_eq!(results.0[results.0.len() - 1], 6591.4854); + assert_eq!(results.0[results.0.len() - 1], 6559.517); Ok(()) } @@ -1076,7 +1076,7 @@ mod tests { ..Default::default() }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.84499365); + assert_eq!(optimal_retention, 0.84458643); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) }