Skip to content

Commit

Permalink
Add init_shape API for states.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Dec 21, 2024
1 parent b5575d4 commit 2b57f87
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 15 deletions.
16 changes: 10 additions & 6 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use web_rwkv::{
loader::{Loader, Lora},
model::{Bundle, ContextAutoLimits, ModelBuilder, ModelInfo, ModelVersion, Quant, State},
softmax::softmax_one,
v4, v5, v6, v7, TokioRuntime,
v4, v5, v6, v7, Runtime, TokioRuntime,
},
tensor::{TensorCpu, TensorInit, TensorShape},
tokenizer::Tokenizer,
Expand Down Expand Up @@ -258,30 +258,34 @@ async fn main() -> Result<()> {
None => builder,
};

let (runtime, state): (_, Box<dyn State>) = match info.version {
let (runtime, state): (Box<dyn Runtime>, Box<dyn State>) = match info.version {
ModelVersion::V4 => {
let model = builder.build_v4().await?;
let bundle = v4::Bundle::<f16>::new(model, 1);
let state = bundle.state();
(TokioRuntime::new(bundle).await, Box::new(state))
let runtime = TokioRuntime::new(bundle).await;
(Box::new(runtime), Box::new(state))
}
ModelVersion::V5 => {
let model = builder.build_v5().await?;
let bundle = v5::Bundle::<f16>::new(model, 1);
let state = bundle.state();
(TokioRuntime::new(bundle).await, Box::new(state))
let runtime = TokioRuntime::new(bundle).await;
(Box::new(runtime), Box::new(state))
}
ModelVersion::V6 => {
let model = builder.build_v6().await?;
let bundle = v6::Bundle::<f16>::new(model, 1);
let state = bundle.state();
(TokioRuntime::new(bundle).await, Box::new(state))
let runtime = TokioRuntime::new(bundle).await;
(Box::new(runtime), Box::new(state))
}
ModelVersion::V7 => {
let model = builder.build_v7().await?;
let bundle = v7::Bundle::<f16>::new(model, 1);
let state = bundle.state();
(TokioRuntime::new(bundle).await, Box::new(state))
let runtime = TokioRuntime::new(bundle).await;
(Box::new(runtime), Box::new(state))
}
};

Expand Down
4 changes: 3 additions & 1 deletion src/runtime/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
context::{Context, ContextBuilder},
impl_deserialize_seed,
num::Scalar,
tensor::{kind::ReadWrite, TensorCpu, TensorError, TensorGpu, TensorGpuView},
tensor::{kind::ReadWrite, shape::Shape, TensorCpu, TensorError, TensorGpu, TensorGpuView},
};

#[wasm_bindgen]
Expand Down Expand Up @@ -87,6 +87,8 @@ pub trait AsAny {
pub trait State {
/// Batch number of this state.
fn num_batch(&self) -> usize;
/// Shape of the initialized one-batch CPU state.
fn init_shape(&self) -> Shape;
/// Initialize a one-batch state on CPU.
fn init(&self) -> TensorCpu<f32>;
/// The part of the state that is used in an `att` layer.
Expand Down
9 changes: 7 additions & 2 deletions src/runtime/v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ impl super::model::State for State {
self.data.shape()[2]
}

#[inline]
fn init_shape(&self) -> Shape {
let info = &self.info;
[info.num_emb, 5 * info.num_layer, 1, 1].into()
}

fn init(&self) -> TensorCpu<f32> {
let info = &self.info;
let data = (0..info.num_layer)
Expand All @@ -163,8 +169,7 @@ impl super::model::State for State {
})
.collect_vec()
.concat();
let shape = Shape::new(info.num_emb, 5 * info.num_layer, 1, 1);
TensorCpu::from_data(shape, data).unwrap()
TensorCpu::from_data(self.init_shape(), data).unwrap()
}

fn att(&self, layer: usize) -> Result<TensorGpuView<f32>, TensorError> {
Expand Down
9 changes: 7 additions & 2 deletions src/runtime/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,15 @@ impl super::model::State for State {
self.data[0].shape()[2]
}

fn init(&self) -> TensorCpu<f32> {
#[inline]
fn init_shape(&self) -> Shape {
let info = &self.info;
let head_size = info.num_emb / info.num_head;
let shape = Shape::new(info.num_emb, head_size + 2, info.num_layer, 1);
[info.num_emb, head_size + 2, info.num_layer, 1].into()
}

fn init(&self) -> TensorCpu<f32> {
let shape = self.init_shape();
let data = vec![0.0; shape.len()];
TensorCpu::from_data(shape, data).unwrap()
}
Expand Down
9 changes: 7 additions & 2 deletions src/runtime/v6.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,15 @@ impl super::model::State for State {
self.data[0].shape()[2]
}

fn init(&self) -> TensorCpu<f32> {
#[inline]
fn init_shape(&self) -> Shape {
let info = &self.info;
let head_size = info.num_emb / info.num_head;
let shape = Shape::new(info.num_emb, head_size + 2, info.num_layer, 1);
[info.num_emb, head_size + 2, info.num_layer, 1].into()
}

fn init(&self) -> TensorCpu<f32> {
let shape = self.init_shape();
let data = vec![0.0; shape.len()];
TensorCpu::from_data(shape, data).unwrap()
}
Expand Down
9 changes: 7 additions & 2 deletions src/runtime/v7.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,15 @@ impl super::model::State for State {
self.data[0].shape()[2]
}

fn init(&self) -> TensorCpu<f32> {
#[inline]
fn init_shape(&self) -> Shape {
let info = &self.info;
let head_size = info.num_emb / info.num_head;
let shape = Shape::new(info.num_emb, head_size + 2, info.num_layer, 1);
[info.num_emb, head_size + 2, info.num_layer, 1].into()
}

fn init(&self) -> TensorCpu<f32> {
let shape = self.init_shape();
let data = vec![0.0; shape.len()];
TensorCpu::from_data(shape, data).unwrap()
}
Expand Down

0 comments on commit 2b57f87

Please sign in to comment.