Skip to content

Commit

Permalink
Feat/cosine_annealing_lr (#25)
Browse files Browse the repository at this point in the history
* Feat/cosine_annealing_lr

* is_empty for FSRSDataset

* replace f64 with LearningRate && add test

---------

Co-authored-by: Asuka Minato <[email protected]>
  • Loading branch information
L-M-Sherlock and asukaminato0721 authored Aug 25, 2023
1 parent 7d80f4e commit 87d11f0
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
92 changes: 92 additions & 0 deletions src/cosine_annealing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use burn::{lr_scheduler::LRScheduler, LearningRate};
/// Learning-rate scheduler that moves the rate along a cosine curve.
///
/// The rate starts at `init_lr`, falls to `eta_min` over `t_max` steps and
/// rises back towards `init_lr` over the following `t_max` steps, so one full
/// period spans `2 * t_max` steps.
#[derive(Clone, Debug)]
pub struct CosineAnnealingLR {
/// Number of steps in one half period of the cosine (peak to trough).
t_max: f64,
/// Lowest learning rate the schedule is annealed towards.
eta_min: f64,
/// Learning rate the schedule starts from (the peak of the cosine).
init_lr: LearningRate,
/// Number of steps taken so far; kept as `f64` because it feeds directly
/// into the floating-point cosine arithmetic.
step_count: f64,
/// Learning rate produced by the most recent step (`init_lr` before any step).
current_lr: LearningRate,
}

impl CosineAnnealingLR {
pub fn init(t_max: f64, init_lr: LearningRate) -> CosineAnnealingLR {
CosineAnnealingLR {
t_max,
eta_min: 0.0,
init_lr,
step_count: 0.0,
current_lr: init_lr,
}
}
}

/// Cosine-annealing schedule driven by the recursive update rule, so each
/// step derives the new rate from the previous one.
impl LRScheduler for CosineAnnealingLR {
    /// Only the number of completed steps is persisted; the matching learning
    /// rate is re-derived from it when the record is loaded.
    type Record = usize;

    /// Advances the schedule by one step and returns the new learning rate.
    fn step(&mut self) -> LearningRate {
        self.step_count += 1.0;
        use std::f64::consts::PI;

        /// Computes the rate for `step_count` from the rate `lr` of the
        /// previous step.
        fn cosine_annealing_lr(
            init_lr: LearningRate,
            lr: LearningRate,
            step_count: f64,
            t_max: f64,
            eta_min: f64,
        ) -> LearningRate {
            if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
                // The previous step sat at the trough of the cosine, where the
                // ratio used below would divide by `1 + cos(pi) == 0`. Restart
                // additively from the previous rate instead; dropping `lr`
                // here would restart from 0 rather than from `eta_min`.
                lr + (init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
            } else {
                // Scale the distance to `eta_min` by the ratio of the cosine
                // envelope at this step to the envelope at the previous step.
                (1.0 + f64::cos(PI * step_count / t_max))
                    / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))
                    * (lr - eta_min)
                    + eta_min
            }
        }

        self.current_lr = cosine_annealing_lr(
            self.init_lr,
            self.current_lr,
            self.step_count,
            self.t_max,
            self.eta_min,
        );
        self.current_lr
    }

    /// Returns the number of steps taken so far.
    fn to_record(&self) -> Self::Record {
        self.step_count as usize
    }

    /// Restores the scheduler to the state it had after `record` steps.
    ///
    /// `step` derives each rate from the previous one, so restoring only the
    /// step counter would continue the recursion from a stale `current_lr`.
    /// The rate is therefore recomputed from the closed-form cosine expression
    /// that the recursion follows.
    fn load_record(mut self, record: Self::Record) -> Self {
        use std::f64::consts::PI;

        self.step_count = record as LearningRate;
        self.current_lr = self.eta_min
            + (self.init_lr - self.eta_min)
                * (1.0 + f64::cos(PI * self.step_count / self.t_max))
                / 2.0;
        self
    }
}

/// Runs the scheduler through one full cosine period (down and back up) and
/// checks the learning rate sampled every 20 000 steps, plus the final value.
#[test]
fn test_lr_scheduler() {
    const TOTAL_STEPS: usize = 200_000;
    const SAMPLE_EVERY: usize = 20_000;

    let mut scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);
    let mut sampled = Vec::new();
    for step in 0..TOTAL_STEPS {
        // Sample before stepping so the first entry is the initial rate.
        if step % SAMPLE_EVERY == 0 {
            sampled.push(scheduler.current_lr);
        }
        scheduler.step();
    }
    // Rate after the very last step.
    sampled.push(scheduler.current_lr);

    let expected = vec![
        0.1,
        0.09045084971874785,
        0.06545084971874875,
        0.034549150281253875,
        0.009549150281252989,
        0.0,
        0.009549150281252692,
        0.03454915028125239,
        0.06545084971874746,
        0.09045084971874952,
        0.10000000000000353,
    ];
    assert_eq!(sampled, expected);
}
8 changes: 8 additions & 0 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ impl FSRSDataset {
Self::new()
}

/// Returns the number of items held by the wrapped dataset.
pub fn len(&self) -> usize {
self.dataset.len()
}

/// Returns whether the wrapped dataset reports itself as empty.
pub fn is_empty(&self) -> bool {
self.dataset.is_empty()
}

fn new() -> Self {
let dataset = InMemDataset::<FSRSItem>::new(anki_to_fsrs());
Self { dataset }
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod convertor;
mod cosine_annealing;
pub mod dataset;
pub mod model;
pub mod training;
Expand Down
8 changes: 7 additions & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::cosine_annealing::CosineAnnealingLR;
use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
use crate::model::{Model, ModelConfig};
use crate::weight_clipper::weight_clipper;
Expand Down Expand Up @@ -127,6 +128,11 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
.num_workers(config.num_workers)
.build(FSRSDataset::test());

let lr_scheduler = CosineAnnealingLR::init(
(FSRSDataset::train().len() * config.num_epochs) as f64,
config.learning_rate,
);

let learner = LearnerBuilder::new(artifact_dir)
// .metric_train_plot(AccuracyMetric::new())
// .metric_valid_plot(AccuracyMetric::new())
Expand All @@ -138,7 +144,7 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
.build(
config.model.init::<B>(),
config.optimizer.init(),
config.learning_rate,
lr_scheduler,
);

let mut model_trained = learner.fit(dataloader_train, dataloader_test);
Expand Down

0 comments on commit 87d11f0

Please sign in to comment.