add .to_dtype for input tensor dtype conversion
jorgeantonio21 committed Apr 10, 2024
1 parent 385c757 commit 13f2d2b
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions atoma-inference/src/jrpc_server/mod.rs
@@ -1,6 +1,6 @@
-use std::{net::Shutdown, sync::Arc};
+use std::sync::Arc;
 
-use axum::{extract::State, http::StatusCode, routing::post, Extension, Json, Router};
+use axum::{http::StatusCode, routing::post, Extension, Json, Router};
 use serde_json::{json, Value};
 use tokio::sync::{mpsc, oneshot};
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/mamba.rs
@@ -178,7 +178,7 @@ impl ModelTrait for MambaModel {
         let mut output = String::new();
 
         for &token in tokens.iter() {
-            let input = Tensor::new(&[token], &self.device)?;
+            let input = Tensor::new(&[token], &self.device)?.to_dtype(self.dtype)?;
             let logits = self.model.forward(&input, &mut state)?;
 
             next_logits = Some(logits);
@@ -213,7 +213,7 @@ impl ModelTrait for MambaModel {
                 output.push_str(t.as_str());
             }
 
-            let input = Tensor::new(&[next_token], &self.device)?;
+            let input = Tensor::new(&[next_token], &self.device)?.to_dtype(self.dtype)?;
             next_logits = Some(self.model.forward(&input, &mut state)?);
         }
         let dt = start_gen.elapsed();
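Below is a minimal sketch (not part of this commit) of the pattern applied in the diff above: casting a freshly created token tensor to the model's configured dtype before the forward pass. It assumes candle_core as the tensor backend; the prepare_input helper is hypothetical, and the device/dtype parameters stand in for the self.device and self.dtype fields used in the diff.

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical helper: build a one-element tensor from a token id and cast it
// to the dtype the model weights were loaded in (e.g. F32 or BF16), so the
// forward pass sees inputs with a consistent dtype.
fn prepare_input(token: u32, device: &Device, dtype: DType) -> Result<Tensor> {
    // Tensor::new infers U32 from the Rust array; to_dtype performs the cast.
    Tensor::new(&[token], device)?.to_dtype(dtype)
}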
