From 930aeb78cfe7624f1bb0b9f3edd5f278a5e0f095 Mon Sep 17 00:00:00 2001 From: Adrian Date: Sun, 30 Jul 2023 11:25:43 +0200 Subject: [PATCH] Add surprise based difficulty function Implements open-spaced-repetition/fsrs4anki#352 --- src/fsrs_optimizer/fsrs_optimizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index de87d5b..a2f7337 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -74,7 +74,8 @@ def step(self, X: Tensor, state: Tensor) -> Tensor: new_d = new_d.clamp(1, 10) else: r = power_forgetting_curve(X[:,0], state[:,0]) - new_d = state[:,1] - self.w[6] * (X[:,1] - 3) + a = self.surprise(r, X[:,1]) + new_d = a * (state[:,1] - self.w[6] * (X[:,1] - 3)) + (1-a) * state[:,1] new_d = self.mean_reversion(self.w[4], new_d) new_d = new_d.clamp(1, 10) condition = X[:,1] > 1 @@ -97,6 +98,9 @@ def forward(self, inputs: Tensor, state: Optional[Tensor]=None) -> Tensor: def mean_reversion(self, init: Tensor, current: Tensor) -> Tensor: return self.w[7] * init + (1-self.w[7]) * current + def surprise(self, retrievability: Tensor, grade: Tensor) -> Tensor: + return torch.exp(-1 - (retrievability - 0.5) * (grade - 2)) + class WeightClipper: def __init__(self, frequency: int=1): self.frequency = frequency