Commit 9809beb

Feat/option enable_short_term in training (#258)
L-M-Sherlock authored Jan 1, 2025
1 parent a7aaa40 · commit 9809beb
Showing 5 changed files with 55 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "1.5.0"
+version = "2.0.0"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
2 changes: 1 addition & 1 deletion examples/optimize.rs
@@ -18,7 +18,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("Default parameters: {:?}", DEFAULT_PARAMETERS);

     // Optimize the FSRS model using the created items
-    let optimized_parameters = fsrs.compute_parameters(fsrs_items, None)?;
+    let optimized_parameters = fsrs.compute_parameters(fsrs_items, None, false)?;

     println!("Optimized parameters: {:?}", optimized_parameters);
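Callers must now pass the third argument explicitly. A minimal sketch of the updated call, assuming an existing Vec<FSRSItem> named fsrs_items (the surrounding function is illustrative, not part of this commit):

    use fsrs::{FSRS, FSRSItem};

    fn optimize(fsrs_items: Vec<FSRSItem>) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
        let fsrs = FSRS::new(Some(&[]))?;
        // `false` keeps the previous long-term-only behavior;
        // `true` also trains the short-term (same-day) parameters.
        let optimized_parameters = fsrs.compute_parameters(fsrs_items, None, false)?;
        Ok(optimized_parameters)
    }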
10 changes: 8 additions & 2 deletions src/model.rs
@@ -41,12 +41,16 @@ impl<B: Backend, const N: usize> Pow<B, N> for Tensor<B, N> {
 impl<B: Backend> Model<B> {
     #[allow(clippy::new_without_default)]
     pub fn new(config: ModelConfig) -> Self {
-        let initial_params = config
+        let mut initial_params: Vec<f32> = config
             .initial_stability
             .unwrap_or_else(|| DEFAULT_PARAMETERS[0..4].try_into().unwrap())
             .into_iter()
             .chain(DEFAULT_PARAMETERS[4..].iter().copied())
             .collect();
+        if config.freeze_short_term_stability {
+            initial_params[17] = 0.0;
+            initial_params[18] = 0.0;
+        }

         Self {
             w: Param::from_tensor(Tensor::from_floats(
@@ -199,8 +203,10 @@ pub(crate) struct MemoryStateTensors<B: Backend> {
 #[derive(Config, Module, Debug, Default)]
 pub struct ModelConfig {
     #[config(default = false)]
-    pub freeze_stability: bool,
+    pub freeze_initial_stability: bool,
     pub initial_stability: Option<[f32; 4]>,
+    #[config(default = false)]
+    pub freeze_short_term_stability: bool,
 }

 impl ModelConfig {
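In FSRS-5, w[17] and w[18] govern the same-day (short-term) stability update, so zeroing them at initialization, together with the gradient masking in src/training.rs below, keeps the short-term term inert for the whole run. A self-contained sketch of the initialization logic (standalone names, not the crate's API; the placeholder defaults are for illustration only):

    // Placeholder defaults so the sketch compiles; the crate ships real values.
    const DEFAULT_PARAMETERS: [f32; 19] = [0.4; 19];

    fn init_params(initial_stability: Option<[f32; 4]>, freeze_short_term: bool) -> Vec<f32> {
        let mut params: Vec<f32> = initial_stability
            .unwrap_or_else(|| DEFAULT_PARAMETERS[0..4].try_into().unwrap())
            .into_iter()
            .chain(DEFAULT_PARAMETERS[4..].iter().copied())
            .collect();
        if freeze_short_term {
            // Pin the short-term stability parameters at zero.
            params[17] = 0.0;
            params[18] = 0.0;
        }
        params
    }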
68 changes: 44 additions & 24 deletions src/training.rs
@@ -83,6 +83,16 @@ impl<B: AutodiffBackend> Model<B> {
         self.w.grad_replace(&mut grad, updated_grad_tensor);
         grad
     }
+
+    fn free_short_term_stability(&self, mut grad: B::Gradients) -> B::Gradients {
+        let grad_tensor = self.w.grad(&grad).unwrap();
+        let updated_grad_tensor =
+            grad_tensor.slice_assign([17..19], Tensor::zeros([2], &B::Device::default()));
+
+        self.w.grad_remove(&mut grad);
+        self.w.grad_replace(&mut grad, updated_grad_tensor);
+        grad
+    }
 }

 #[derive(Debug, Default, Clone)]
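The freeze works by overwriting the gradient entries for w[17] and w[18] with zeros after backprop, so the optimizer step leaves those parameters where the initialization put them. The same masking idea on a plain array, free of burn's tensor API (illustrative only):

    /// Zero the gradient entries for the short-term parameters (indices 17 and 18)
    /// so an optimizer step cannot move them.
    fn mask_short_term_grads(grads: &mut [f32; 19]) {
        for g in &mut grads[17..19] {
            *g = 0.0;
        }
    }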
@@ -202,6 +212,7 @@ impl<B: Backend> FSRS<B> {
         &self,
         train_set: Vec<FSRSItem>,
         progress: Option<Arc<Mutex<CombinedProgressState>>>,
+        enable_short_term: bool,
     ) -> Result<Vec<f32>> {
         let finish_progress = || {
             if let Some(progress) = &progress {
@@ -235,8 +246,9 @@
         }
         let config = TrainingConfig::new(
             ModelConfig {
-                freeze_stability: false,
+                freeze_initial_stability: !enable_short_term,
                 initial_stability: Some(initial_stability),
+                freeze_short_term_stability: !enable_short_term,
             },
             AdamConfig::new().with_epsilon(1e-8),
         );
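Note the inversion: the single public enable_short_term flag drives both freezes, so disabling short-term training also pins the pretrained initial stabilities. A sketch of that mapping (hypothetical helper, not in the crate):

    struct FreezePlan {
        freeze_initial_stability: bool,
        freeze_short_term_stability: bool,
    }

    fn freeze_plan(enable_short_term: bool) -> FreezePlan {
        FreezePlan {
            // With short-term training disabled, both parameter groups stay fixed.
            freeze_initial_stability: !enable_short_term,
            freeze_short_term_stability: !enable_short_term,
        }
    }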
@@ -295,7 +307,7 @@
         Ok(optimized_parameters)
     }

-    pub fn benchmark(&self, mut train_set: Vec<FSRSItem>) -> Vec<f32> {
+    pub fn benchmark(&self, mut train_set: Vec<FSRSItem>, enable_short_term: bool) -> Vec<f32> {
         let average_recall = calculate_average_recall(&train_set);
         let (pre_train_set, _next_train_set) = train_set
             .clone()
@@ -304,8 +316,9 @@
         let initial_stability = pretrain(pre_train_set, average_recall).unwrap().0;
         let config = TrainingConfig::new(
             ModelConfig {
-                freeze_stability: false,
+                freeze_initial_stability: !enable_short_term,
                 initial_stability: Some(initial_stability),
+                freeze_short_term_stability: !enable_short_term,
             },
             AdamConfig::new().with_epsilon(1e-8),
         );
@@ -377,9 +390,12 @@ fn train<B: AutodiffBackend>(
             Reduction::Sum,
         );
         let mut gradients = loss.backward();
-        if model.config.freeze_stability {
+        if model.config.freeze_initial_stability {
             gradients = model.freeze_initial_stability(gradients);
         }
+        if model.config.freeze_short_term_stability {
+            gradients = model.free_short_term_stability(gradients);
+        }
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
         model.w = parameter_clipper(model.w);
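Ordering matters here: backprop first, then the freeze masks, then the optimizer step, then parameter clipping. A schematic of one step with plain arrays instead of burn tensors (the SGD-style update stands in for Adam; assumed index ranges: w[0..4] for initial stability, w[17..19] for short-term):

    fn train_step(
        w: &mut [f32; 19],
        grads: &mut [f32; 19],
        lr: f32,
        freeze_initial: bool,
        freeze_short_term: bool,
    ) {
        // 1. (The backward pass has already produced `grads`.)
        // 2. Mask the gradients of frozen parameter groups.
        if freeze_initial {
            grads[0..4].iter_mut().for_each(|g| *g = 0.0);
        }
        if freeze_short_term {
            grads[17..19].iter_mut().for_each(|g| *g = 0.0);
        }
        // 3. Optimizer update (SGD stand-in for Adam).
        for (wi, gi) in w.iter_mut().zip(grads.iter()) {
            *wi -= lr * gi;
        }
        // 4. A clipper like `parameter_clipper` would clamp each w[i] here.
    }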
@@ -663,26 +679,30 @@ mod tests {
             .unwrap();
         }
         for items in [anki21_sample_file_converted_to_fsrs(), data_from_csv()] {
-            let progress = CombinedProgressState::new_shared();
-            let progress2 = Some(progress.clone());
-            thread::spawn(move || {
-                let mut finished = false;
-                while !finished {
-                    thread::sleep(Duration::from_millis(500));
-                    let guard = progress.lock().unwrap();
-                    finished = guard.finished();
-                    println!("progress: {}/{}", guard.current(), guard.total());
-                }
-            });
-
-            let fsrs = FSRS::new(Some(&[])).unwrap();
-            let parameters = fsrs.compute_parameters(items.clone(), progress2).unwrap();
-            dbg!(&parameters);
-
-            // evaluate
-            let model = FSRS::new(Some(&parameters)).unwrap();
-            let metrics = model.evaluate(items, |_| true).unwrap();
-            dbg!(&metrics);
+            for enable_short_term in [true, false] {
+                let progress = CombinedProgressState::new_shared();
+                let progress2 = Some(progress.clone());
+                thread::spawn(move || {
+                    let mut finished = false;
+                    while !finished {
+                        thread::sleep(Duration::from_millis(500));
+                        let guard = progress.lock().unwrap();
+                        finished = guard.finished();
+                        println!("progress: {}/{}", guard.current(), guard.total());
+                    }
+                });
+
+                let fsrs = FSRS::new(Some(&[])).unwrap();
+                let parameters = fsrs
+                    .compute_parameters(items.clone(), progress2, enable_short_term)
+                    .unwrap();
+                dbg!(&parameters);
+
+                // evaluate
+                let model = FSRS::new(Some(&parameters)).unwrap();
+                let metrics = model.evaluate(items.clone(), |_| true).unwrap();
+                dbg!(&metrics);
+            }
         }
     }
 }
