From 6882c8c8ba96bfbcda1eaf825da11ddd77b5f678 Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Mon, 9 Dec 2024 18:46:33 +0800
Subject: [PATCH] Fix/consider short-term params when clipping PLS (#150)

* Fix/consider short-term params when clipping PLS

* update simulator

* update unit tests

* bump version
---
 pyproject.toml                       | 2 +-
 src/fsrs_optimizer/fsrs_optimizer.py | 6 ++++--
 src/fsrs_optimizer/fsrs_simulator.py | 2 +-
 tests/model_test.py                  | 2 +-
 tests/simulator_test.py              | 2 +-
 5 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 834d2c4..565999e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.4.1"
+version = "5.4.2"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index 652a780..1095682 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -90,13 +90,15 @@ def stability_after_success(
         return new_s
 
     def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor:
+        old_s = state[:, 0]
         new_s = (
             self.w[11]
             * torch.pow(state[:, 1], -self.w[12])
-            * (torch.pow(state[:, 0] + 1, self.w[13]) - 1)
+            * (torch.pow(old_s + 1, self.w[13]) - 1)
             * torch.exp((1 - r) * self.w[14])
         )
-        return torch.minimum(new_s, state[:, 0])
+        new_minimum_s = old_s / torch.exp(self.w[17] * self.w[18])
+        return torch.minimum(new_s, new_minimum_s)
 
     def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
         new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))
diff --git a/src/fsrs_optimizer/fsrs_simulator.py b/src/fsrs_optimizer/fsrs_simulator.py
index 300b046..68fe276 100644
--- a/src/fsrs_optimizer/fsrs_simulator.py
+++ b/src/fsrs_optimizer/fsrs_simulator.py
@@ -106,7 +106,7 @@ def stability_after_failure(s, r, d):
             * np.power(d, -w[12])
             * (np.power(s + 1, w[13]) - 1)
             * np.exp((1 - r) * w[14]),
-            s,
+            s / np.exp(w[17] * w[18]),
         ),
     )
 
diff --git a/tests/model_test.py b/tests/model_test.py
index 8b5180e..7470f82 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -65,7 +65,7 @@ def test_forward(self):
         difficulty = state[:, 1]
         assert torch.allclose(
             stability,
-            torch.tensor([0.2619, 1.7073, 5.8691, 25.0123, 0.3403, 2.1482]),
+            torch.tensor([0.2619, 1.7074, 5.8691, 25.0124, 0.2859, 2.1482]),
             atol=1e-4,
         )
         assert torch.allclose(
diff --git a/tests/simulator_test.py b/tests/simulator_test.py
index 3da1aba..13e6abf 100644
--- a/tests/simulator_test.py
+++ b/tests/simulator_test.py
@@ -11,7 +11,7 @@ def test_simulate(self):
             cost_per_day,
             revlogs,
         ) = simulate(w=DEFAULT_PARAMETER, request_retention=0.9)
-        assert memorized_cnt_per_day[-1] == 5875.025236206539
+        assert memorized_cnt_per_day[-1] == 5880.482440745369
 
     def test_optimal_retention(self):
         default_params = {
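For reference, here is a minimal standalone sketch of the new post-lapse stability (PLS) clipping, for reading the change outside the diff context. The weight values and the example card state are illustrative placeholders, not the library's defaults; only the clipping arithmetic mirrors the stability_after_failure change above, and any additional bounds the surrounding library code applies are omitted.

import numpy as np

# Illustrative weights only; the indices follow the FSRS layout used in the diff:
# w[11..14] drive post-lapse stability, w[17..18] drive short-term stability.
w = np.zeros(19)
w[11], w[12], w[13], w[14] = 1.5, 0.1, 0.3, 2.3  # placeholder values
w[17], w[18] = 0.5, 0.6                          # placeholder values

def stability_after_failure(s, r, d):
    # Raw post-lapse stability, same expression as in the patched functions above.
    raw = (
        w[11]
        * np.power(d, -w[12])
        * (np.power(s + 1, w[13]) - 1)
        * np.exp((1 - r) * w[14])
    )
    # The fix: the cap used to be the previous stability `s`; it is now
    # s / exp(w[17] * w[18]), so the short-term parameters are considered
    # when clipping PLS.
    new_minimum_s = s / np.exp(w[17] * w[18])
    return np.minimum(raw, new_minimum_s)

# Example: a lapsed card with stability 25 days, difficulty 5, retrievability 0.9.
print(stability_after_failure(s=25.0, r=0.9, d=5.0))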