Skip to content

Commit

Permalink
Fix/update difficulty before stability (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Sep 30, 2022
1 parent 5b9fd79 commit aa6224b
Showing 1 changed file with 85 additions and 81 deletions.
166 changes: 85 additions & 81 deletions fsrs4anki_optimizer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# FSRS4Anki v2.0.4 Optimizer"
"# FSRS4Anki v2.1.1 Optimizer"
]
},
{
Expand All @@ -13,7 +13,7 @@
"id": "lurCmW0Jqz3s"
},
"source": [
"[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v2.0.4/fsrs4anki_optimizer.ipynb)\n",
"[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v2.1.1/fsrs4anki_optimizer.ipynb)\n",
"\n",
"↑ Click the above button to open the optimizer on Google Colab.\n",
"\n",
Expand Down Expand Up @@ -145,7 +145,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 5166/5166 [00:16<00:00, 308.15it/s]\n"
"100%|██████████| 5166/5166 [00:17<00:00, 299.33it/s]\n"
]
},
{
Expand Down Expand Up @@ -296,21 +296,25 @@
" :param d: difficulty\n",
" :return:\n",
" '''\n",
" if torch.equal(s, torch.FloatTensor([0.0])):\n",
" if torch.equal(s, self.zero):\n",
" # first learn, init memory states\n",
" next_s = self.f_s[0] * (self.f_s[1] * (x[1] - 1) + 1)\n",
" next_d = self.f_d[0] * (self.f_d[1] * (x[1] - 4) + 1)\n",
" new_d = self.f_d[0] * (self.f_d[1] * (x[1] - 4) + 1)\n",
" new_s = self.f_s[0] * (self.f_s[1] * (x[1] - 1) + 1)\n",
" else:\n",
" r = torch.exp(np.log(0.9) * x[0] / s)\n",
" new_d = d + self.f_d[2] * (x[1] - 3)\n",
" new_d = self.mean_reversion(self.f_d[0] * (- self.f_d[1] + 1), new_d)\n",
" new_d = self.constrain(new_d)\n",
" # recall\n",
" if x[1] > 1:\n",
" next_s = s * (1 + torch.exp(self.s_w[0]) * torch.pow(d, self.s_w[1]) *\n",
" torch.pow(s, self.s_w[2]) *\n",
" (torch.exp((1 - r) * self.s_w[3]) - 1))\n",
" new_s = s * (1 + torch.exp(self.s_w[0]) * \n",
" torch.pow(new_d, self.s_w[1]) *\n",
" torch.pow(s, self.s_w[2]) *\n",
" (torch.exp((1 - r) * self.s_w[3]) - 1))\n",
" # forget\n",
" else:\n",
" next_s = self.s_w[4] * torch.pow(d, self.s_w[5]) * torch.pow(s, self.s_w[6]) * torch.exp((1 - r) * self.s_w[7])\n",
" next_d = d + self.f_d[2] * (x[1] - 3)\n",
" next_d = self.mean_reversion(self.f_d[0] * (- self.f_d[1] + 1), next_d)\n",
" return next_s, self.constrain(next_d)\n",
" new_s = self.s_w[4] * torch.pow(new_d, self.s_w[5]) * torch.pow(s, self.s_w[6]) * torch.exp((1 - r) * self.s_w[7])\n",
" return new_s, new_d\n",
"\n",
" def loss(self, s, t, r):\n",
" return - (r * np.log(0.9) * t / s + (1 - r) * torch.log(1 - torch.exp(np.log(0.9) * t / s)))\n",
Expand Down Expand Up @@ -388,7 +392,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56910/56910 [00:03<00:00, 15081.04it/s]\n"
"100%|██████████| 56910/56910 [00:03<00:00, 15362.11it/s]\n"
]
},
{
Expand All @@ -402,7 +406,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"train: 0%|\u001b[31m \u001b[0m| 27/56910 [00:00<03:35, 263.88it/s]"
"train: 0%|\u001b[31m \u001b[0m| 26/56910 [00:00<03:40, 258.03it/s]"
]
},
{
Expand All @@ -419,160 +423,160 @@
"name": "stderr",
"output_type": "stream",
"text": [
"train: 10%|\u001b[31m█ \u001b[0m| 5736/56910 [00:14<02:03, 415.67it/s]"
"train: 10%|\u001b[31m█ \u001b[0m| 5741/56910 [00:14<02:01, 421.39it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 5692\n",
"f_s: [1.2724, 1.2575]\n",
"f_d: [1.0094, -0.9884, -1.1116, 0.0636]\n",
"s_w: [3.0022, -0.8891, -0.1969, 1.2971, 2.1941, -0.3021, 0.315, 1.1898]\n"
"f_s: [1.2702, 1.2586]\n",
"f_d: [1.0024, -0.9651, -1.1192, 0.0521]\n",
"s_w: [3.0059, -0.8937, -0.19, 1.3011, 2.2132, -0.2858, 0.334, 1.2064]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 20%|\u001b[31m██ \u001b[0m| 11457/56910 [00:31<02:08, 352.56it/s]"
"train: 20%|\u001b[31m██ \u001b[0m| 11434/56910 [00:28<02:02, 372.66it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 11383\n",
"f_s: [1.4695, 1.4782]\n",
"f_d: [1.0014, -0.9704, -1.0768, 0.0801]\n",
"s_w: [3.0619, -0.8667, -0.1365, 1.3528, 2.1935, -0.2805, 0.3057, 1.1731]\n"
"f_s: [1.4602, 1.4738]\n",
"f_d: [1.0014, -0.9322, -1.1048, 0.0728]\n",
"s_w: [3.0657, -0.8794, -0.1265, 1.3568, 2.2175, -0.2703, 0.3296, 1.1948]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 30%|\u001b[31m███ \u001b[0m| 17147/56910 [00:47<01:39, 399.03it/s]"
"train: 30%|\u001b[31m███ \u001b[0m| 17120/56910 [00:43<01:36, 410.21it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 17074\n",
"f_s: [1.5863, 1.6589]\n",
"f_d: [1.0258, -0.9882, -1.0913, 0.091]\n",
"s_w: [3.0673, -0.8904, -0.11, 1.3537, 2.1634, -0.2992, 0.2965, 1.1262]\n"
"f_s: [1.5726, 1.65]\n",
"f_d: [1.0219, -0.9416, -1.1315, 0.0839]\n",
"s_w: [3.0711, -0.9074, -0.0979, 1.3578, 2.1922, -0.2911, 0.3235, 1.1508]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 40%|\u001b[31m████ \u001b[0m| 22789/56910 [01:02<01:55, 296.51it/s]"
"train: 40%|\u001b[31m████ \u001b[0m| 22813/56910 [00:58<01:34, 359.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 22765\n",
"f_s: [1.5906, 1.7348]\n",
"f_d: [1.0689, -1.0498, -1.1311, 0.063]\n",
"s_w: [3.0328, -0.9322, -0.1411, 1.3173, 2.0915, -0.3735, 0.2617, 1.0674]\n"
"f_s: [1.5704, 1.72]\n",
"f_d: [1.0559, -0.9928, -1.1881, 0.061]\n",
"s_w: [3.0357, -0.9603, -0.127, 1.3206, 2.1261, -0.3587, 0.2932, 1.0941]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 50%|\u001b[31m█████ \u001b[0m| 28491/56910 [01:19<01:26, 329.21it/s]"
"train: 50%|\u001b[31m█████ \u001b[0m| 28533/56910 [01:12<01:13, 384.47it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 28456\n",
"f_s: [1.7049, 1.8936]\n",
"f_d: [1.014, -1.0166, -1.139, 0.0642]\n",
"s_w: [3.0524, -0.9503, -0.1023, 1.3343, 2.1543, -0.2904, 0.3008, 1.1328]\n"
"f_s: [1.6784, 1.8725]\n",
"f_d: [1.0067, -0.9636, -1.1995, 0.0616]\n",
"s_w: [3.0564, -0.9843, -0.0857, 1.3387, 2.1906, -0.2843, 0.3328, 1.1568]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 60%|\u001b[31m██████ \u001b[0m| 34200/56910 [01:34<00:54, 417.26it/s]"
"train: 60%|\u001b[31m██████ \u001b[0m| 34228/56910 [01:26<00:53, 426.19it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 34147\n",
"f_s: [1.7278, 1.92]\n",
"f_d: [1.0126, -1.028, -1.1428, 0.0598]\n",
"s_w: [3.0563, -0.9658, -0.0951, 1.3386, 2.1339, -0.2862, 0.29, 1.1129]\n"
"f_s: [1.6983, 1.8963]\n",
"f_d: [1.0033, -0.9629, -1.2178, 0.0613]\n",
"s_w: [3.0599, -1.0047, -0.0774, 1.3426, 2.1754, -0.286, 0.3243, 1.1376]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 70%|\u001b[31m███████ \u001b[0m| 39887/56910 [01:49<00:43, 394.78it/s]"
"train: 70%|\u001b[31m███████ \u001b[0m| 39894/56910 [01:41<00:47, 357.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 39838\n",
"f_s: [1.7382, 1.9699]\n",
"f_d: [1.0222, -1.0376, -1.1101, 0.1054]\n",
"s_w: [3.0621, -0.9548, -0.0881, 1.339, 2.1527, -0.2527, 0.3084, 1.1179]\n"
"f_s: [1.7047, 1.9418]\n",
"f_d: [1.0155, -0.976, -1.1999, 0.1074]\n",
"s_w: [3.0646, -1.0041, -0.0691, 1.3421, 2.1959, -0.2599, 0.3419, 1.141]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 80%|\u001b[31m████████ \u001b[0m| 45563/56910 [02:03<00:27, 405.77it/s]"
"train: 80%|\u001b[31m████████ \u001b[0m| 45559/56910 [02:00<00:48, 236.34it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 45529\n",
"f_s: [1.7408, 2.0119]\n",
"f_d: [1.0258, -1.0531, -1.1412, 0.083]\n",
"s_w: [3.0376, -1.0004, -0.0954, 1.3099, 2.165, -0.2424, 0.3187, 1.106]\n"
"f_s: [1.7054, 1.9824]\n",
"f_d: [1.0233, -0.9943, -1.2296, 0.0866]\n",
"s_w: [3.0414, -1.0547, -0.0734, 1.314, 2.2104, -0.2478, 0.3548, 1.1255]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 90%|\u001b[31m█████████ \u001b[0m| 51272/56910 [02:18<00:15, 355.37it/s]"
"train: 90%|\u001b[31m█████████ \u001b[0m| 51267/56910 [02:17<00:16, 338.00it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration: 51220\n",
"f_s: [1.6901, 2.0208]\n",
"f_d: [1.0166, -1.0509, -1.1197, 0.0811]\n",
"s_w: [3.0567, -0.991, -0.0839, 1.3272, 2.158, -0.2477, 0.3095, 1.093]\n"
"f_s: [1.6519, 1.9887]\n",
"f_d: [1.006, -0.9738, -1.2296, 0.0881]\n",
"s_w: [3.062, -1.0546, -0.061, 1.3333, 2.2051, -0.2531, 0.3472, 1.1081]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train: 100%|\u001b[31m██████████\u001b[0m| 56910/56910 [02:33<00:00, 371.12it/s]"
"train: 100%|\u001b[31m██████████\u001b[0m| 56910/56910 [02:35<00:00, 365.71it/s]"
]
},
{
Expand Down Expand Up @@ -679,9 +683,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"var f_s = [1.6959,2.0195];\n",
"var f_d = [1.004,-1.0202,-1.108,0.0792];\n",
"var s_w = [3.0735,-0.9749,-0.0806,1.3396,2.1591,-0.2358,0.3103,1.0855];\n"
"var f_s = [1.6543,1.9849];\n",
"var f_d = [1.006,-0.9394,-1.2311,0.0923];\n",
"var s_w = [3.0788,-1.0484,-0.0571,1.3462,2.2089,-0.2442,0.3476,1.099];\n"
]
}
],
Expand Down Expand Up @@ -719,23 +723,23 @@
"\n",
"first rating: 1\n",
"rating history: 1,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,2,3,5,9,16,28,48,82,140,237,399,667,1107,1825,2986\n",
"difficulty history: 0,4.1,3.9,3.8,3.6,3.5,3.4,3.3,3.2,3.1,3.0,2.9,2.9,2.8,2.7,2.7\n",
"interval history: 0,2,3,6,10,18,33,60,108,195,352,634,1138,2035,3622,6414\n",
"difficulty history: 0,3.8,3.7,3.5,3.4,3.2,3.1,3.0,2.9,2.8,2.7,2.7,2.6,2.5,2.5,2.4\n",
"\n",
"first rating: 2\n",
"rating history: 2,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,5,10,19,35,65,118,212,377,662,1149,1970,3339,5594,9266,15177\n",
"difficulty history: 0,3.1,3.0,2.9,2.8,2.8,2.7,2.7,2.6,2.6,2.5,2.5,2.4,2.4,2.4,2.4\n",
"interval history: 0,5,10,19,37,72,139,266,504,949,1771,3275,6003,10900,19606,34928\n",
"difficulty history: 0,2.9,2.8,2.7,2.7,2.6,2.5,2.5,2.4,2.4,2.3,2.3,2.3,2.2,2.2,2.2\n",
"\n",
"first rating: 3\n",
"rating history: 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,9,20,45,96,200,403,789,1506,2805,5107,9099,15889,27226,45826,75847\n",
"interval history: 0,8,19,44,99,218,467,978,2004,4022,7914,15282,28985,54043,99131,179022\n",
"difficulty history: 0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0\n",
"\n",
"first rating: 4\n",
"rating history: 4,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3\n",
"interval history: 0,12,42,131,368,948,2265,5076,10758,21718,42006,78228,140859,246106,418504,694473\n",
"difficulty history: 0,1.0,1.1,1.2,1.2,1.3,1.4,1.4,1.5,1.5,1.5,1.6,1.6,1.6,1.7,1.7\n",
"interval history: 0,12,41,129,373,1001,2521,6010,13650,29702,62211,125918,247134,471701,877783,1596050\n",
"difficulty history: 0,1.0,1.1,1.2,1.2,1.3,1.4,1.4,1.5,1.5,1.6,1.6,1.6,1.7,1.7,1.7\n",
"\n"
]
}
Expand Down Expand Up @@ -796,26 +800,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
"(tensor(8.5456), tensor(2.0282))\n",
"(tensor(3.9854), tensor(4.0685))\n",
"(tensor(2.6557), tensor(5.9472))\n",
"(tensor(4.1695), tensor(5.6366))\n",
"(tensor(6.2157), tensor(5.3507))\n",
"(tensor(9.3430), tensor(5.0873))\n",
"(tensor(14.1110), tensor(4.8449))\n",
"(tensor(3.7684), tensor(6.6620))\n",
"(tensor(5.5235), tensor(6.2948))\n",
"(tensor(8.2224), tensor(5.9567))\n",
"(tensor(11.8946), tensor(5.6454))\n",
"(tensor(17.5314), tensor(5.3587))\n",
"(tensor(26.1555), tensor(5.0948))\n",
"(tensor(38.8209), tensor(4.8517))\n",
"(tensor(58.1248), tensor(4.6280))\n",
"(tensor(87.2211), tensor(4.4219))\n",
"(tensor(6.7796), tensor(6.2726))\n",
"(tensor(8.2213), tensor(1.9510))\n",
"(tensor(3.6043), tensor(4.1858))\n",
"(tensor(2.4923), tensor(6.2143))\n",
"(tensor(3.4277), tensor(5.8206))\n",
"(tensor(4.9015), tensor(5.4632))\n",
"(tensor(7.4729), tensor(5.1389))\n",
"(tensor(11.2073), tensor(4.8445))\n",
"(tensor(3.5671), tensor(6.8121))\n",
"(tensor(5.2437), tensor(6.3632))\n",
"(tensor(7.4362), tensor(5.9558))\n",
"(tensor(10.6537), tensor(5.5859))\n",
"(tensor(15.9462), tensor(5.2503))\n",
"(tensor(23.9526), tensor(4.9456))\n",
"(tensor(36.4160), tensor(4.6690))\n",
"(tensor(55.7545), tensor(4.4180))\n",
"(tensor(86.7965), tensor(4.1902))\n",
"(tensor(7.4459), tensor(6.2183))\n",
"rating history: 3,1,1,3,3,3,3,1,3,3,3,3,3,3,3,3,1\n",
"interval history: 0,9,4,3,4,6,9,14,4,6,8,12,18,26,39,58,87,7\n",
"difficulty history: 0,2.0,4.1,5.9,5.6,5.4,5.1,4.8,6.7,6.3,6.0,5.6,5.4,5.1,4.9,4.6,4.4,6.3\n"
"interval history: 0,8,4,2,3,5,7,11,4,5,7,11,16,24,36,56,87,7\n",
"difficulty history: 0,2.0,4.2,6.2,5.8,5.5,5.1,4.8,6.8,6.4,6.0,5.6,5.3,4.9,4.7,4.4,4.2,6.2\n"
]
}
],
Expand Down

0 comments on commit aa6224b

Please sign in to comment.