Commit: Fix/consider short-term params when clipping PLS (#150)
* Fix/consider short-term params when clipping PLS

* update simulator

* update unit tests

* bump version
L-M-Sherlock authored Dec 9, 2024
1 parent 348a6de commit 6882c8c
Showing 5 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.4.1"
+version = "5.4.2"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
6 changes: 4 additions & 2 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -90,13 +90,15 @@ def stability_after_success(
         return new_s
 
     def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor:
+        old_s = state[:, 0]
         new_s = (
             self.w[11]
             * torch.pow(state[:, 1], -self.w[12])
-            * (torch.pow(state[:, 0] + 1, self.w[13]) - 1)
+            * (torch.pow(old_s + 1, self.w[13]) - 1)
             * torch.exp((1 - r) * self.w[14])
         )
-        return torch.minimum(new_s, state[:, 0])
+        new_minimum_s = old_s / torch.exp(self.w[17] * self.w[18])
+        return torch.minimum(new_s, new_minimum_s)
 
     def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
         new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))
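In effect, the post-lapse stability (PLS) returned by stability_after_failure is no longer capped at the pre-lapse stability old_s but at old_s / exp(w[17] * w[18]), so the short-term parameters now take part in the clipping; applying stability_short_term with a rating of 3 to this new ceiling recovers exactly old_s. Below is a minimal sketch of the new clipping rule, not part of the commit; the parameter values are hypothetical and chosen so that the clip is active.

import torch

# Hypothetical parameter values for illustration only, not the FSRS defaults.
w = {11: 1.9, 12: 0.11, 13: 0.29, 14: 2.61, 17: 0.5, 18: 0.6}

old_s = torch.tensor([0.5])   # stability before the lapse
d = torch.tensor([1.0])       # difficulty
r = torch.tensor([0.2])       # retrievability at the time of the lapse

# Unclipped post-lapse stability, same formula as in the diff above.
new_s = (
    w[11]
    * torch.pow(d, -w[12])
    * (torch.pow(old_s + 1, w[13]) - 1)
    * torch.exp((1 - r) * w[14])
)

old_clip = torch.minimum(new_s, old_s)  # behaviour before this commit
new_ceiling = old_s / torch.exp(torch.tensor(w[17] * w[18]))
new_clip = torch.minimum(new_s, new_ceiling)  # behaviour after this commit
print(old_clip.item(), new_clip.item())  # the new ceiling is slightly lower here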
2 changes: 1 addition & 1 deletion src/fsrs_optimizer/fsrs_simulator.py
@@ -106,7 +106,7 @@ def stability_after_failure(s, r, d):
             * np.power(d, -w[12])
             * (np.power(s + 1, w[13]) - 1)
             * np.exp((1 - r) * w[14]),
-            s,
+            s / np.exp(w[17] * w[18]),
         ),
     )

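The simulator applies the same ceiling in its vectorized NumPy form. The following is a rough sketch using the same hypothetical parameter values as above; the real stability_after_failure in fsrs_simulator.py nests this np.minimum inside an outer call (visible as the extra closing parenthesis in the hunk), which is omitted here.

import numpy as np

# Hypothetical parameter values for illustration only, not the FSRS defaults.
w = np.zeros(19)
w[[11, 12, 13, 14, 17, 18]] = [1.9, 0.11, 0.29, 2.61, 0.5, 0.6]

s = np.array([0.5, 10.0, 100.0])  # pre-lapse stabilities
d = np.array([1.0, 5.0, 8.0])     # difficulties
r = 0.2                           # retrievability at the time of the lapse

pls = np.minimum(
    w[11]
    * np.power(d, -w[12])
    * (np.power(s + 1, w[13]) - 1)
    * np.exp((1 - r) * w[14]),
    s / np.exp(w[17] * w[18]),  # new ceiling; previously this argument was just s
)
print(pls)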
2 changes: 1 addition & 1 deletion tests/model_test.py
@@ -65,7 +65,7 @@ def test_forward(self):
         difficulty = state[:, 1]
         assert torch.allclose(
             stability,
-            torch.tensor([0.2619, 1.7073, 5.8691, 25.0123, 0.3403, 2.1482]),
+            torch.tensor([0.2619, 1.7074, 5.8691, 25.0124, 0.2859, 2.1482]),
             atol=1e-4,
         )
         assert torch.allclose(
2 changes: 1 addition & 1 deletion tests/simulator_test.py
@@ -11,7 +11,7 @@ def test_simulate(self):
             cost_per_day,
             revlogs,
         ) = simulate(w=DEFAULT_PARAMETER, request_retention=0.9)
-        assert memorized_cnt_per_day[-1] == 5875.025236206539
+        assert memorized_cnt_per_day[-1] == 5880.482440745369
 
     def test_optimal_retention(self):
         default_params = {
