Skip to content

Commit

Permalink
Fix/filter out zero rating (open-spaced-repetition#10)
Browse files (browse the repository at this point in the history)
* Fix/filter out zero rating

* Update version (4.5.4 → 4.5.5)
  • Loading branch information
L-M-Sherlock authored Jul 30, 2023
1 parent c7c1507 commit 118c2f5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.5.4"
version = "4.5.5"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
6 changes: 4 additions & 2 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def step(self, X: Tensor, state: Tensor) -> Tensor:
keys = keys.view(1, -1).expand(X[:,1].long().size(0), -1)
index = (X[:,1].long().unsqueeze(1) == keys).nonzero(as_tuple=True)
# first learn, init memory states
new_s = self.w[index[1]]
new_s = torch.ones_like(state[:,0])
new_s[index[0]] = self.w[index[1]]
new_d = self.w[4] - self.w[5] * (X[:,1] - 3)
new_d = new_d.clamp(1, 10)
else:
Expand Down Expand Up @@ -401,7 +402,7 @@ def cum_concat(x):
r_history = df.groupby('card_id', group_keys=False)['review_rating'].apply(lambda x: cum_concat([[i] for i in x]))
df['r_history']=[','.join(map(str, item[:-1])) for sublist in r_history for item in sublist]
df = df.groupby('card_id').filter(lambda group: group['review_time'].min() > time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000)
df = df[df['review_rating'] != 0].copy()
df = df[(df['review_rating'] != 0) & (df['r_history'].str.contains("0") == 0)].copy()
df['y'] = df['review_rating'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])

def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -415,6 +416,7 @@ def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
return group

df[df['i'] == 2] = df[df['i'] == 2].groupby(by=['r_history', 't_history'], as_index=False, group_keys=False).apply(remove_outliers)
df.dropna(inplace=True)

def remove_non_continuous_rows(group):
discontinuity = group['i'].diff().fillna(1).ne(1)
Expand Down

0 comments on commit 118c2f5

Please sign in to comment.