Skip to content

Commit

Permalink
refactor: Optimize batch shuffling implementation for better performance (#252)
Browse files Browse the repository at this point in the history

* improve performance

* refactor ShuffleDataLoader

* add more assertion

* bump version
L-M-Sherlock authored Oct 30, 2024
1 parent 0c8a7ae commit 7477d2b
Showing 5 changed files with 134 additions and 469 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.4.2"
version = "1.4.3"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
554 changes: 104 additions & 450 deletions src/batch_shuffle.rs

Large diffs are not rendered by default.

24 changes: 17 additions & 7 deletions src/dataset.rs
Original file line number Diff line number Diff line change
@@ -107,15 +107,25 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t = Tensor::from_data(
Data::new(delta_t, Shape { dims: [pad_size] }).convert(),
Data::new(
delta_t,
Shape {
dims: [1, pad_size],
},
)
.convert(),
&self.device,
)
.unsqueeze();
);
let rating = Tensor::from_data(
Data::new(rating, Shape { dims: [pad_size] }).convert(),
Data::new(
rating,
Shape {
dims: [1, pad_size],
},
)
.convert(),
&self.device,
)
.unsqueeze();
);
(delta_t, rating)
})
.unzip();
@@ -156,7 +166,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
}

pub(crate) struct FSRSDataset {
items: Vec<FSRSItem>,
pub(crate) items: Vec<FSRSItem>,
}

impl Dataset<FSRSItem> for FSRSDataset {
21 changes: 11 additions & 10 deletions src/training.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use crate::batch_shuffle::BatchShuffledDataLoaderBuilder;
use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader};
use crate::cosine_annealing::CosineAnnealingLR;
use crate::dataset::{prepare_training_data, FSRSBatcher, FSRSDataset, FSRSItem};
use crate::dataset::{prepare_training_data, FSRSDataset, FSRSItem};
use crate::error::Result;
use crate::model::{Model, ModelConfig};
use crate::parameter_clipper::parameter_clipper;
use crate::pre_training::{pretrain, smooth_and_fill};
use crate::{FSRSError, DEFAULT_PARAMETERS, FSRS};
use burn::backend::Autodiff;

use burn::data::dataloader::DataLoaderBuilder;
use burn::lr_scheduler::LrScheduler;
use burn::module::AutodiffModule;
use burn::nn::loss::Reduction;
@@ -325,17 +324,19 @@ fn train<B: AutodiffBackend>(

// Training data
let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
let batcher_train = FSRSBatcher::<B>::new(device.clone());
let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build(
let batch_dataset = BatchTensorDataset::<B>::new(
FSRSDataset::from(train_set),
config.batch_size,
config.seed,
device.clone(),
);
let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed);

let batcher_valid = FSRSBatcher::new(device);
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.build(FSRSDataset::from(test_set.clone()));
let batch_dataset = BatchTensorDataset::<B::InnerBackend>::new(
FSRSDataset::from(test_set.clone()),
config.batch_size,
device,
);
let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed);

let mut lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate);
let interrupter = TrainingInterrupter::new();

0 comments on commit 7477d2b

Please sign in to comment.