Skip to content

Commit

Permalink
Feat/update to FSRS-4.5 (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Mar 14, 2024
1 parent f1ecf3c commit bd339b4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
10 changes: 7 additions & 3 deletions fsrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ func (p *Parameters) Repeat(card Card, now time.Time) map[Rating]SchedulingInfo

s.schedule(now, hardInterval, goodInterval, easyInterval)
case Review:
interval := float64(card.ElapsedDays)
elapsedDays := float64(card.ElapsedDays)
lastD := card.Difficulty
lastS := card.Stability
retrievability := math.Pow(1+interval/(9*lastS), -1)
retrievability := p.forgettingCurve(elapsedDays, lastS)
p.nextDS(s, lastD, lastS, retrievability)

hardInterval := p.nextInterval(s.Hard.Stability)
Expand Down Expand Up @@ -121,6 +121,10 @@ func (s *schedulingCards) recordLog(card Card, now time.Time) map[Rating]Schedul
return m
}

func (p *Parameters) forgettingCurve(elapsedDays float64, stability float64) float64 {
return math.Pow(1+p.Factor*elapsedDays/stability, p.Decay)
}

func (p *Parameters) initDS(s *schedulingCards) {
s.Again.Difficulty = p.initDifficulty(Again)
s.Again.Stability = p.initStability(Again)
Expand Down Expand Up @@ -155,7 +159,7 @@ func constrainDifficulty(d float64) float64 {
}

func (p *Parameters) nextInterval(s float64) float64 {
newInterval := s * 9 * (1/p.RequestRetention - 1)
newInterval := s / p.Factor * (math.Pow(p.RequestRetention, 1/p.Decay) - 1)
return math.Max(math.Min(math.Round(newInterval), p.MaximumInterval), 1)
}

Expand Down
61 changes: 57 additions & 4 deletions fsrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ package fsrs
import (
"encoding/json"
"fmt"
"math"
"reflect"
"testing"
"time"
)

func roundFloat(val float64, precision uint) float64 {
ratio := math.Pow(10, float64(precision))
return math.Round(val*ratio) / ratio
}

func TestExample(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.14, 1.01, 5.44, 14.67, 5.3024, 1.5662, 1.2503, 0.0028, 1.5489, 0.1763, 0.9953, 2.7473, 0.0179, 0.3105, 0.3976, 0.0, 2.0902}
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
card := NewCard()
now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC)
var ivlList []uint64
Expand Down Expand Up @@ -38,12 +45,58 @@ func TestExample(t *testing.T) {
fmt.Println(ivlList)
fmt.Println(stateList)

wantIvlList := []uint64{0, 5, 16, 43, 106, 236, 0, 0, 12, 25, 47, 85, 147}
wantIvlList := []uint64{0, 4, 15, 49, 143, 379, 0, 0, 15, 37, 85, 184, 376}
if !reflect.DeepEqual(ivlList, wantIvlList) {
t.Errorf("excepted:%v, got:%v", ivlList, wantIvlList)
t.Errorf("excepted:%v, got:%v", wantIvlList, ivlList)
}
wantStateList := []State{New, Learning, Review, Review, Review, Review, Review, Relearning, Relearning, Review, Review, Review, Review}
if !reflect.DeepEqual(stateList, wantStateList) {
t.Errorf("excepted:%v, got:%v", stateList, wantStateList)
t.Errorf("excepted:%v, got:%v", wantStateList, stateList)
}
}

func TestMemoState(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
card := NewCard()
now := time.Date(2022, 11, 29, 12, 30, 0, 0, time.UTC)

schedulingCards := p.Repeat(card, now)
var ratings = []Rating{Again, Good, Good, Good, Good, Good}
var ivlList = []uint64{0, 0, 1, 3, 8, 21}
var rating Rating
for i := 0; i < len(ratings); i++ {
rating = ratings[i]
card = schedulingCards[rating].Card
now = now.Add(time.Duration(ivlList[i]) * 24 * time.Hour)
schedulingCards = p.Repeat(card, now)
}
wantStability := 43.0554
cardStability := roundFloat(schedulingCards[Good].Card.Stability, 4)
wantDifficulty := 7.7609
cardDifficulty := roundFloat(schedulingCards[Good].Card.Difficulty, 4)

if !reflect.DeepEqual(wantStability, cardStability) {
t.Errorf("excepted:%v, got:%v", wantStability, cardStability)
}

if !reflect.DeepEqual(wantDifficulty, cardDifficulty) {
t.Errorf("excepted:%v, got:%v", wantDifficulty, cardDifficulty)
}
}

func TestNextInterval(t *testing.T) {
p := DefaultParam()
p.W = Weights{1.0171, 1.8296, 4.4145, 10.9355, 5.0965, 1.3322, 1.017, 0.0, 1.6243, 0.1369, 1.0321,
2.1866, 0.0661, 0.336, 1.7766, 0.1693, 2.9244}
var ivlList []float64
for i := 1; i <= 10; i++ {
p.RequestRetention = float64(i) / 10
ivlList = append(ivlList, p.nextInterval(1))
}
wantIvlList := []float64{422, 102, 43, 22, 13, 8, 4, 2, 1, 1}
if !reflect.DeepEqual(ivlList, wantIvlList) {
t.Errorf("excepted:%v, got:%v", wantIvlList, ivlList)
}
}
11 changes: 10 additions & 1 deletion params.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
package fsrs

import "math"

type Weights [17]float64

type Parameters struct {
RequestRetention float64 `json:"RequestRetention"`
MaximumInterval float64 `json:"MaximumInterval"`
W Weights `json:"Weights"`
Decay float64 `json:"Decay"`
Factor float64 `json:"Factor"`
}

func DefaultParam() Parameters {
var Decay = -0.5
var Factor = math.Pow(0.9, 1/Decay) - 1
return Parameters{
RequestRetention: 0.9,
MaximumInterval: 36500,
W: DefaultWeights(),
Decay: Decay,
Factor: Factor,
}
}

func DefaultWeights() Weights {
return Weights{0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29, 2.61}
return Weights{0.5701, 1.4436, 4.1386, 10.9355, 5.1443, 1.2006, 0.8627, 0.0362, 1.629, 0.1342, 1.0166, 2.1174,
0.0839, 0.3204, 1.4676, 0.219, 2.8237}
}

0 comments on commit bd339b4

Please sign in to comment.