update burn to v0.15.0
L-M-Sherlock committed Oct 29, 2024
1 parent c45637d commit 151fddc
Showing 12 changed files with 1,319 additions and 849 deletions.
1,841 changes: 1,150 additions & 691 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -9,21 +9,21 @@ keywords = ["spaced-repetition", "algorithm", "fsrs", "machine-learning"]
license = "BSD-3-Clause"
readme = "README.md"
repository = "https://github.com/open-spaced-repetition/fsrs-rs"
rust-version = "1.75.0"
rust-version = "1.81.0"
description = "FSRS for Rust, including Optimizer and Scheduler"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies.burn]
version = "0.13.2"
version = "0.15.0"
# git = "https://github.com/tracel-ai/burn.git"
# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
# path = "../burn/burn"
default-features = false
features = ["std", "train", "ndarray"]

[dev-dependencies.burn]
version = "0.13.2"
version = "0.15.0"
# git = "https://github.com/tracel-ai/burn.git"
# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
# path = "../burn/burn"
@@ -48,7 +48,7 @@ chrono-tz = "0.8.4"
criterion = { version = "0.5.1" }
csv = "1.3.0"
fern = "0.6.0"
rusqlite = { version = "0.30.0" }
rusqlite = { version = "0.32.0" }

[[bench]]
name = "benchmark"
2 changes: 1 addition & 1 deletion rust-toolchain.toml
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.80"
channel = "1.81"
components = ["rustfmt", "clippy"]
4 changes: 2 additions & 2 deletions src/batch_shuffle.rs
@@ -287,12 +287,12 @@ mod tests {
let item = dataloader.iter().next().unwrap();
assert_eq!(
item.t_historys.shape(),
burn::tensor::Shape { dims: [6, 512] }
burn::tensor::Shape { dims: vec![6, 512] }
);
let item2 = dataloader.iter().next().unwrap();
assert_eq!(
item2.t_historys.shape(),
burn::tensor::Shape { dims: [4, 512] }
burn::tensor::Shape { dims: vec![4, 512] }
);
}

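For reference, a minimal sketch of the Shape change behind this test update (NdArray backend, a made-up 2x3 tensor; not code from this commit): in burn 0.15, Shape::dims is a Vec<usize> rather than a const-generic array, so Shape literals in assertions need vec![] instead of an array.

use burn::backend::NdArray;
use burn::tensor::{Shape, Tensor};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();
    // Hypothetical 2x3 tensor, only used to compare shapes.
    let t = Tensor::<B, 2>::zeros([2, 3], &device);
    // burn 0.13: assert_eq!(t.shape(), Shape { dims: [2, 3] });
    // burn 0.15: dims is a Vec<usize>, so the literal takes vec![].
    assert_eq!(t.shape(), Shape { dims: vec![2, 3] });
}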
21 changes: 11 additions & 10 deletions src/convertor_tests.rs
@@ -7,7 +7,8 @@ use burn::backend::ndarray::NdArrayDevice;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::Dataset;
use burn::data::dataset::InMemDataset;
use burn::tensor::Data;
use burn::tensor::cast::ToElement;
use burn::tensor::TensorData;
use chrono::prelude::*;
use chrono_tz::Tz;
use itertools::Itertools;
@@ -391,15 +392,15 @@ fn conversion_works() {
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);
let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
assert_eq!(res.delta_ts.into_scalar(), 64.0);
assert_eq!(
res.r_historys.squeeze(1).to_data(),
Data::from([3.0, 4.0, 3.0, 3.0, 3.0, 2.0])
);
assert_eq!(
res.t_historys.squeeze(1).to_data(),
Data::from([0.0, 0.0, 5.0, 10.0, 22.0, 56.0])
);
assert_eq!(res.labels.to_data(), Data::from([1]));
res.r_historys
.squeeze::<1>(1)
.to_data()
.assert_approx_eq(&TensorData::from([3.0, 4.0, 3.0, 3.0, 3.0, 2.0]), 5);
res.t_historys
.squeeze::<1>(1)
.to_data()
.assert_approx_eq(&TensorData::from([0.0, 0.0, 5.0, 10.0, 22.0, 56.0]), 5);
assert_eq!(res.labels.into_scalar().to_i32(), 1);
}

#[test]
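As a side note, a minimal sketch of the assertion pattern used above (NdArray backend, made-up values; not part of this commit): burn::tensor::Data is replaced by TensorData, exact assert_eq! comparisons on to_data() become assert_approx_eq with a decimal-precision argument, and single labels are read back through into_scalar() plus the ToElement cast trait.

use burn::backend::NdArray;
use burn::tensor::cast::ToElement;
use burn::tensor::{Int, Tensor, TensorData};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();

    // Hypothetical rating history.
    let r_historys = Tensor::<B, 1>::from_floats([3.0, 4.0, 3.0], &device);
    // 0.13: assert_eq!(r_historys.to_data(), Data::from([3.0, 4.0, 3.0]));
    // 0.15: compare TensorData approximately, to 5 decimal places.
    r_historys
        .to_data()
        .assert_approx_eq(&TensorData::from([3.0, 4.0, 3.0]), 5);

    // 0.13: assert_eq!(labels.to_data(), Data::from([1]));
    // 0.15: extract the single element and convert it with ToElement.
    let labels = Tensor::<B, 1, Int>::from_ints([1], &device);
    assert_eq!(labels.into_scalar().to_i32(), 1);
}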
17 changes: 8 additions & 9 deletions src/cosine_annealing.rs
@@ -20,8 +20,8 @@ impl CosineAnnealingLR {
}
}

impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
type Record = usize;
impl LrScheduler for CosineAnnealingLR {
type Record<B: Backend> = usize;

fn step(&mut self) -> LearningRate {
self.step_count += 1.0;
@@ -52,11 +52,11 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
self.current_lr
}

fn to_record(&self) -> Self::Record {
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.step_count as usize
}

fn load_record(mut self, record: Self::Record) -> Self {
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.step_count = record as LearningRate;
self
}
@@ -65,23 +65,22 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
#[cfg(test)]
mod tests {
use super::*;
use burn::{backend::NdArray, tensor::Data};
type Backend = NdArray<f32>;
use burn::tensor::TensorData;

#[test]
fn lr_scheduler() {
let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);

let lrs = (0..=200000)
.map(|_| {
LrScheduler::<Backend>::step(&mut lr_scheduler);
LrScheduler::step(&mut lr_scheduler);
lr_scheduler.current_lr
})
.step_by(20000)
.collect::<Vec<_>>();

Data::from(&lrs[..]).assert_approx_eq(
&Data::from([
TensorData::from(&lrs[..]).assert_approx_eq(
&TensorData::from([
0.1,
0.09045084971874785,
0.06545084971874875,
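For context, a sketch of the new trait shape with a hypothetical constant-rate scheduler (not code from this commit; import paths assumed to match the ones this crate already uses): in burn 0.15 LrScheduler is no longer generic over the backend. The backend parameter moves onto the Record associated type and the to_record/load_record methods, and step() is called without a backend turbofish.

use burn::lr_scheduler::LrScheduler;
use burn::tensor::backend::Backend;
use burn::LearningRate;

/// Hypothetical scheduler that always yields the same learning rate.
#[derive(Clone, Debug)]
struct ConstantLr {
    lr: LearningRate,
}

// 0.13: impl<B: Backend> LrScheduler<B> for ConstantLr { type Record = usize; ... }
// 0.15: the trait itself is backend-free; only the record carries B.
impl LrScheduler for ConstantLr {
    type Record<B: Backend> = usize;

    fn step(&mut self) -> LearningRate {
        self.lr
    }

    fn to_record<B: Backend>(&self) -> Self::Record<B> {
        0
    }

    fn load_record<B: Backend>(self, _record: Self::Record<B>) -> Self {
        self
    }
}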
60 changes: 37 additions & 23 deletions src/dataset.rs
@@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use burn::data::dataloader::batcher::Batcher;
use burn::{
data::dataset::Dataset,
tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
tensor::{backend::Backend, Float, Int, Shape, Tensor, TensorData},
};

use itertools::Itertools;
@@ -106,13 +106,23 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
item.history().map(|r| (r.delta_t, r.rating)).unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t = Tensor::from_data(
Data::new(delta_t, Shape { dims: [pad_size] }).convert(),
let delta_t = Tensor::<B, 1>::from_floats(
TensorData::new(
delta_t,
Shape {
dims: vec![pad_size],
},
),
&self.device,
)
.unsqueeze();
let rating = Tensor::from_data(
Data::new(rating, Shape { dims: [pad_size] }).convert(),
let rating = Tensor::<B, 1>::from_data(
TensorData::new(
rating,
Shape {
dims: vec![pad_size],
},
),
&self.device,
)
.unsqueeze();
@@ -124,12 +134,12 @@
.iter()
.map(|item| {
let current = item.current();
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
let delta_t = Tensor::<B, 1>::from_floats([current.delta_t], &self.device);
let label = match current.rating {
1 => 0.0,
_ => 1.0,
1 => 0,
_ => 1,
};
let label = Tensor::from_data(Data::from([label.elem()]), &self.device);
let label = Tensor::<B, 1, Int>::from_ints([label], &self.device);
(delta_t, label)
})
.unzip();
@@ -426,29 +436,33 @@ mod tests {
},
];
let batch = batcher.batch(items);
assert_eq!(
batch.t_historys.to_data(),
Data::from([
batch.t_historys.to_data().assert_approx_eq(
&TensorData::from([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0]
])
[0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0],
]),
5,
);
assert_eq!(
batch.r_historys.to_data(),
Data::from([
batch.r_historys.to_data().assert_approx_eq(
&TensorData::from([
[4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0],
[0.0, 3.0, 0.0, 3.0, 3.0, 3.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]
])
[0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0],
]),
5,
);
assert_eq!(
batch.delta_ts.to_data(),
Data::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0])

batch.delta_ts.to_data().assert_approx_eq(
&TensorData::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0]),
5,
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
batch
.labels
.to_data()
.assert_approx_eq(&TensorData::from([1, 1, 1, 1, 1, 1, 0, 1]), 5);
}

#[test]
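As a reference for the batching changes above, a minimal sketch under the 0.15 API (NdArray backend, a hypothetical padded history; not code from this commit): Data::new(..).convert() becomes TensorData::new with a Vec-based Shape, and per-item scalars are built directly with from_floats / from_ints on explicitly typed tensors.

use burn::backend::NdArray;
use burn::tensor::{Int, Shape, Tensor, TensorData};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();

    // Hypothetical padded delta_t history.
    let delta_t: Vec<f32> = vec![0.0, 5.0, 10.0, 0.0];
    let pad_size = delta_t.len();

    // 0.13: Tensor::from_data(Data::new(delta_t, Shape { dims: [pad_size] }).convert(), &device)
    // 0.15: TensorData::new takes the Vec directly; Shape::dims is a Vec<usize>.
    let delta_t = Tensor::<B, 1>::from_floats(
        TensorData::new(
            delta_t,
            Shape {
                dims: vec![pad_size],
            },
        ),
        &device,
    );
    assert_eq!(delta_t.dims(), [pad_size]);

    // 0.13: labels were float tensors built from Data::from([label.elem()])
    // 0.15: integer labels become explicit Int tensors built with from_ints.
    let label = Tensor::<B, 1, Int>::from_ints([1], &device);
    assert_eq!(label.into_scalar(), 1);
}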
53 changes: 27 additions & 26 deletions src/inference.rs
@@ -3,7 +3,8 @@ use std::ops::{Add, Sub};

use crate::model::{Get, MemoryStateTensors, FSRS};
use burn::nn::loss::Reduction;
use burn::tensor::{Data, Shape, Tensor};
use burn::tensor::cast::ToElement;
use burn::tensor::{Shape, Tensor, TensorData};
use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};

use crate::dataset::FSRSBatch;
@@ -45,23 +46,17 @@ pub struct MemoryState {
impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {
fn from(m: MemoryStateTensors<B>) -> Self {
Self {
stability: m.stability.to_data().value[0].elem(),
difficulty: m.difficulty.to_data().value[0].elem(),
stability: m.stability.into_scalar().elem(),
difficulty: m.difficulty.into_scalar().elem(),
}
}
}

impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
fn from(m: MemoryState) -> Self {
Self {
stability: Tensor::from_data(
Data::new(vec![m.stability.elem()], Shape { dims: [1] }),
&B::Device::default(),
),
difficulty: Tensor::from_data(
Data::new(vec![m.difficulty.elem()], Shape { dims: [1] }),
&B::Device::default(),
),
stability: Tensor::from_floats([m.stability], &B::Device::default()),
difficulty: Tensor::from_floats([m.difficulty], &B::Device::default()),
}
}
}
@@ -84,14 +79,14 @@ impl<B: Backend> FSRS<B> {
let (time_history, rating_history) =
item.reviews.iter().map(|r| (r.delta_t, r.rating)).unzip();
let size = item.reviews.len();
let time_history = Tensor::from_data(
Data::new(time_history, Shape { dims: [size] }).convert(),
let time_history = Tensor::<B, 1>::from_data(
TensorData::new(time_history, Shape { dims: vec![size] }),
&self.device(),
)
.unsqueeze()
.transpose();
let rating_history = Tensor::from_data(
Data::new(rating_history, Shape { dims: [size] }).convert(),
let rating_history = Tensor::<B, 1>::from_data(
TensorData::new(rating_history, Shape { dims: vec![size] }),
&self.device(),
)
.unsqueeze()
@@ -147,7 +142,7 @@ impl<B: Backend> FSRS<B> {
let stability = stability.unwrap_or_else(|| {
// get initial stability for new card
let rating = Tensor::from_data(
Data::new(vec![rating.elem()], Shape { dims: [1] }),
TensorData::new(vec![rating], Shape { dims: vec![1] }),
&self.device(),
);
let model = self.model();
@@ -165,7 +160,7 @@
days_elapsed: u32,
) -> Result<NextStates> {
let delta_t = Tensor::from_data(
Data::new(vec![days_elapsed.elem()], Shape { dims: [1] }),
TensorData::new(vec![days_elapsed], Shape { dims: vec![1] }),
&self.device(),
);
let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from);
@@ -175,7 +170,7 @@
let state = MemoryState::from(model.step(
delta_t.clone(),
Tensor::from_data(
Data::new(vec![rating.elem()], Shape { dims: [1] }),
TensorData::new(vec![rating], Shape { dims: vec![1] }),
&self.device(),
),
current_memory_state_tensors.clone(),
@@ -223,15 +218,15 @@ impl<B: Backend> FSRS<B> {
for chunk in items.chunks(512) {
let batch = batcher.batch(chunk.to_vec());
let (_state, retention) = infer::<B>(model, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
let true_val = batch.labels.clone().to_data().convert::<f32>().value;
let pred = retention.clone().to_data().to_vec::<f32>().unwrap();
let true_val = batch.labels.clone().to_data().to_vec::<i64>().unwrap();
all_retention.push(retention);
all_labels.push(batch.labels);
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
let bin = item.r_matrix_index();
let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0));
*pred += p;
*real += y;
*real += y as f32;
*count += 1.0;
});
progress_info.current += chunk.len();
@@ -253,7 +248,7 @@ impl<B: Backend> FSRS<B> {
let all_labels = Tensor::cat(all_labels, 0).float();
let loss = BCELoss::new().forward(all_retention, all_labels, Reduction::Mean);
Ok(ModelEvaluation {
log_loss: loss.to_data().value[0].elem(),
log_loss: loss.into_scalar().to_f32(),
rmse_bins: rmse,
})
}
@@ -293,14 +288,20 @@ impl<B: Backend> FSRS<B> {
let batch = batcher.batch(chunk.to_vec());

let (_state, retention) = infer::<B>(model_self, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
let pred = retention.clone().to_data().to_vec::<f32>().unwrap();
all_predictions_self.extend(pred);

let (_state, retention) = infer::<B>(model_other, batch.clone());
let pred = retention.clone().to_data().convert::<f32>().value;
let pred = retention.clone().to_data().to_vec::<f32>().unwrap();
all_predictions_other.extend(pred);

let true_val = batch.labels.clone().to_data().convert::<f32>().value;
let true_val: Vec<f32> = batch
.labels
.clone()
.to_data()
.convert::<f32>()
.to_vec()
.unwrap();
all_true_val.extend(true_val);
progress_info.current += chunk.len();
if !progress(progress_info) {
@@ -397,7 +398,7 @@ mod tests {
0.121442534,
];
fn assert_approx_eq(a: [f32; 2], b: [f32; 2]) {
Data::from(a).assert_approx_eq(&Data::from(b), 5);
TensorData::from(a).assert_approx_eq(&TensorData::from(b), 5);
}
#[test]
fn test_get_bin() {
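Finally, a sketch of the value-extraction changes above (NdArray backend, made-up predictions; not code from this commit): to_data().convert::<f32>().value becomes the fallible to_data().to_vec::<f32>(), and single-element reads go through into_scalar() with the ToElement cast trait instead of indexing Data::value.

use burn::backend::NdArray;
use burn::tensor::cast::ToElement;
use burn::tensor::Tensor;

type B = NdArray<f32>;

fn main() {
    let device = Default::default();

    // Hypothetical retention predictions for a tiny batch.
    let retention = Tensor::<B, 1>::from_floats([0.9, 0.8, 0.95], &device);

    // 0.13: retention.to_data().convert::<f32>().value
    // 0.15: read values back with to_vec(), which returns a Result.
    let pred: Vec<f32> = retention.clone().to_data().to_vec::<f32>().unwrap();
    assert_eq!(pred.len(), 3);

    // 0.13: loss.to_data().value[0].elem()
    // 0.15: read the single element and cast it with ToElement.
    let loss = retention.mean();
    let log_loss = loss.into_scalar().to_f32();
    assert!(log_loss > 0.0);
}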
(Diffs for the remaining changed files are not rendered.)
