Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update burn to v0.15.0 #251

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,841 changes: 1,150 additions & 691 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ keywords = ["spaced-repetition", "algorithm", "fsrs", "machine-learning"]
license = "BSD-3-Clause"
readme = "README.md"
repository = "https://github.com/open-spaced-repetition/fsrs-rs"
rust-version = "1.75.0"
rust-version = "1.81.0"
description = "FSRS for Rust, including Optimizer and Scheduler"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies.burn]
version = "0.13.2"
version = "0.15.0"
# git = "https://github.com/tracel-ai/burn.git"
# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
# path = "../burn/burn"
default-features = false
features = ["std", "train", "ndarray"]

[dev-dependencies.burn]
version = "0.13.2"
version = "0.15.0"
# git = "https://github.com/tracel-ai/burn.git"
# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
# path = "../burn/burn"
Expand All @@ -48,7 +48,7 @@ chrono-tz = "0.8.4"
criterion = { version = "0.5.1" }
csv = "1.3.0"
fern = "0.6.0"
rusqlite = { version = "0.30.0" }
rusqlite = { version = "0.32.0" }

[[bench]]
name = "benchmark"
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[toolchain]
# older versions may fail to compile; newer versions may fail the clippy tests
channel = "1.80"
channel = "1.81"
components = ["rustfmt", "clippy"]
8 changes: 4 additions & 4 deletions src/batch_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ mod tests {
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: [7, batch_size]
dims: vec![7, batch_size]
}
);
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: [6, batch_size]
dims: vec![6, batch_size]
}
);

Expand All @@ -156,14 +156,14 @@ mod tests {
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: [19, batch_size]
dims: vec![19, batch_size]
}
);
let batch = iterator.next().unwrap();
assert_eq!(
batch.t_historys.shape(),
Shape {
dims: [9, batch_size]
dims: vec![9, batch_size]
}
);

Expand Down
21 changes: 11 additions & 10 deletions src/convertor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use burn::backend::ndarray::NdArrayDevice;
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::Dataset;
use burn::data::dataset::InMemDataset;
use burn::tensor::Data;
use burn::tensor::cast::ToElement;
use burn::tensor::TensorData;
use chrono::prelude::*;
use chrono_tz::Tz;
use itertools::Itertools;
Expand Down Expand Up @@ -391,15 +392,15 @@ fn conversion_works() {
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);
let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
assert_eq!(res.delta_ts.into_scalar(), 64.0);
assert_eq!(
res.r_historys.squeeze(1).to_data(),
Data::from([3.0, 4.0, 3.0, 3.0, 3.0, 2.0])
);
assert_eq!(
res.t_historys.squeeze(1).to_data(),
Data::from([0.0, 0.0, 5.0, 10.0, 22.0, 56.0])
);
assert_eq!(res.labels.to_data(), Data::from([1]));
res.r_historys
.squeeze::<1>(1)
.to_data()
.assert_approx_eq(&TensorData::from([3.0, 4.0, 3.0, 3.0, 3.0, 2.0]), 5);
res.t_historys
.squeeze::<1>(1)
.to_data()
.assert_approx_eq(&TensorData::from([0.0, 0.0, 5.0, 10.0, 22.0, 56.0]), 5);
assert_eq!(res.labels.into_scalar().to_i32(), 1);
}

#[test]
Expand Down
17 changes: 8 additions & 9 deletions src/cosine_annealing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ impl CosineAnnealingLR {
}
}

impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
type Record = usize;
impl LrScheduler for CosineAnnealingLR {
type Record<B: Backend> = usize;

fn step(&mut self) -> LearningRate {
self.step_count += 1.0;
Expand Down Expand Up @@ -52,11 +52,11 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
self.current_lr
}

fn to_record(&self) -> Self::Record {
fn to_record<B: Backend>(&self) -> Self::Record<B> {
self.step_count as usize
}

fn load_record(mut self, record: Self::Record) -> Self {
fn load_record<B: Backend>(mut self, record: Self::Record<B>) -> Self {
self.step_count = record as LearningRate;
self
}
Expand All @@ -65,23 +65,22 @@ impl<B: Backend> LrScheduler<B> for CosineAnnealingLR {
#[cfg(test)]
mod tests {
use super::*;
use burn::{backend::NdArray, tensor::Data};
type Backend = NdArray<f32>;
use burn::tensor::TensorData;

#[test]
fn lr_scheduler() {
let mut lr_scheduler = CosineAnnealingLR::init(100000.0, 1.0e-1);

let lrs = (0..=200000)
.map(|_| {
LrScheduler::<Backend>::step(&mut lr_scheduler);
LrScheduler::step(&mut lr_scheduler);
lr_scheduler.current_lr
})
.step_by(20000)
.collect::<Vec<_>>();

Data::from(&lrs[..]).assert_approx_eq(
&Data::from([
TensorData::from(&lrs[..]).assert_approx_eq(
&TensorData::from([
0.1,
0.09045084971874785,
0.06545084971874875,
Expand Down
60 changes: 31 additions & 29 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::{HashMap, HashSet};
use burn::data::dataloader::batcher::Batcher;
use burn::{
data::dataset::Dataset,
tensor::{backend::Backend, Data, ElementConversion, Float, Int, Shape, Tensor},
tensor::{backend::Backend, Float, Int, Shape, Tensor, TensorData},
};

use itertools::Itertools;
Expand Down Expand Up @@ -106,24 +106,22 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
item.history().map(|r| (r.delta_t, r.rating)).unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t = Tensor::from_data(
Data::new(
let delta_t = Tensor::<B, 2>::from_floats(
TensorData::new(
delta_t,
Shape {
dims: [1, pad_size],
dims: vec![1, pad_size],
},
)
.convert(),
),
&self.device,
);
let rating = Tensor::from_data(
Data::new(
let rating = Tensor::<B, 2>::from_data(
TensorData::new(
rating,
Shape {
dims: [1, pad_size],
dims: vec![1, pad_size],
},
)
.convert(),
),
&self.device,
);
(delta_t, rating)
Expand All @@ -134,12 +132,12 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
.iter()
.map(|item| {
let current = item.current();
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
let delta_t = Tensor::<B, 1>::from_floats([current.delta_t], &self.device);
let label = match current.rating {
1 => 0.0,
_ => 1.0,
1 => 0,
_ => 1,
};
let label = Tensor::from_data(Data::from([label.elem()]), &self.device);
let label = Tensor::<B, 1, Int>::from_ints([label], &self.device);
(delta_t, label)
})
.unzip();
Expand Down Expand Up @@ -436,29 +434,33 @@ mod tests {
},
];
let batch = batcher.batch(items);
assert_eq!(
batch.t_historys.to_data(),
Data::from([
batch.t_historys.to_data().assert_approx_eq(
&TensorData::from([
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0]
])
[0.0, 0.0, 0.0, 0.0, 0.0, 16.0, 0.0, 0.0],
]),
5,
);
assert_eq!(
batch.r_historys.to_data(),
Data::from([
batch.r_historys.to_data().assert_approx_eq(
&TensorData::from([
[4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0],
[0.0, 3.0, 0.0, 3.0, 3.0, 3.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]
])
[0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0],
]),
5,
);
assert_eq!(
batch.delta_ts.to_data(),
Data::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0])

batch.delta_ts.to_data().assert_approx_eq(
&TensorData::from([5.0, 11.0, 2.0, 6.0, 16.0, 39.0, 1.0, 1.0]),
5,
);
assert_eq!(batch.labels.to_data(), Data::from([1, 1, 1, 1, 1, 1, 0, 1]));
batch
.labels
.to_data()
.assert_approx_eq(&TensorData::from([1, 1, 1, 1, 1, 1, 0, 1]), 5);
}

#[test]
Expand Down
Loading
Loading