chore: move errors to atoma-types #51 #53

Closed
wants to merge 2 commits into from
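The atoma-types side of the change is not included in the hunks below. Judging from the new `use atoma_types::...` imports, the relocated error types are presumably re-exported at the crate root, along these lines (a sketch under that assumption, not part of this PR's visible diff):

// atoma-types/src/lib.rs, presumed shape only; inferred from the imports in
// the hunks below. The module name `error` is a guess.
pub mod error;

// Re-exported at the crate root so downstream crates can write, e.g.,
// `use atoma_types::{ApiError, ModelError, ModelServiceError, ModelThreadError};`.
pub use error::{ApiError, ModelError, ModelServiceError, ModelThreadError};

// The relocated `bail!` macro is `#[macro_export]`, so it is reachable as
// `atoma_types::bail` without an explicit re-export, matching the new
// `use atoma_types::{bail, ModelError};` imports below.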
9 changes: 6 additions & 3 deletions Cargo.toml
@@ -10,7 +10,7 @@ members = [
"atoma-node",
"atoma-json-rpc",
"atoma-storage",
"atoma-types"
"atoma-types",
]

[workspace.package]
@@ -36,10 +36,13 @@ ethers = "2.0.14"
futures = "0.3.30"
futures-util = "0.3.30"
hf-hub = "0.3.2"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
image = { version = "0.25.0", default-features = false, features = [
"jpeg",
"png",
] }
serde = "1.0.197"
serde_json = "1.0.114"
sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk"}
sui-sdk = { git = "https://github.com/mystenlabs/sui", package = "sui-sdk" }
# solana-client = "1.18.9"
# solana-sdk = "1.18.8"
rand = "0.8.5"
3 changes: 2 additions & 1 deletion atoma-inference/src/apis/hugging_face.rs
@@ -1,6 +1,7 @@
use std::path::PathBuf;

use async_trait::async_trait;
use atoma_types::ApiError;
use hf_hub::{
api::sync::{Api, ApiBuilder},
Repo, RepoType,
@@ -9,7 +10,7 @@ use tracing::error;

use crate::models::ModelId;

use super::{ApiError, ApiTrait};
use super::ApiTrait;

#[async_trait]
impl ApiTrait for Api {
17 changes: 1 addition & 16 deletions atoma-inference/src/apis/mod.rs
@@ -1,26 +1,11 @@
pub mod hugging_face;
use hf_hub::api::sync::ApiError as HuggingFaceError;

use std::path::PathBuf;

use thiserror::Error;
use atoma_types::ApiError;

use crate::models::ModelId;

#[derive(Debug, Error)]
pub enum ApiError {
#[error("Api Error: `{0}`")]
ApiError(String),
#[error("HuggingFace API error: `{0}`")]
HuggingFaceError(HuggingFaceError),
}

impl From<HuggingFaceError> for ApiError {
fn from(error: HuggingFaceError) -> Self {
Self::HuggingFaceError(error)
}
}

pub trait ApiTrait: Send {
fn fetch(&self, model_id: ModelId, revision: String) -> Result<Vec<PathBuf>, ApiError>;
fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
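The `ApiError` removed above is now imported from `atoma_types`. Assuming it was moved over essentially verbatim, the relocated definition reads roughly as follows (a sketch; the actual atoma-types hunk is not shown here):

// atoma-types (presumed): ApiError as it likely looks after being lifted out
// of atoma-inference/src/apis/mod.rs. "Moved verbatim" is an assumption.
use hf_hub::api::sync::ApiError as HuggingFaceError;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ApiError {
    #[error("Api Error: `{0}`")]
    ApiError(String),
    #[error("HuggingFace API error: `{0}`")]
    HuggingFaceError(HuggingFaceError),
}

impl From<HuggingFaceError> for ApiError {
    fn from(error: HuggingFaceError) -> Self {
        Self::HuggingFaceError(error)
    }
}

If so, atoma-types picks up hf-hub and thiserror as dependencies, which may be worth flagging in review unless the HuggingFace variant is boxed or feature-gated.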
7 changes: 2 additions & 5 deletions atoma-inference/src/main.rs
@@ -1,8 +1,5 @@
use atoma_inference::{
jrpc_server,
models::config::ModelsConfig,
service::{ModelService, ModelServiceError},
};
use atoma_inference::{jrpc_server, models::config::ModelsConfig, service::ModelService};
use atoma_types::ModelServiceError;
use ed25519_consensus::SigningKey as PrivateKey;

#[tokio::main]
50 changes: 10 additions & 40 deletions atoma-inference/src/model_thread.rs
@@ -1,56 +1,26 @@
use std::{
collections::HashMap, fmt::Debug, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle,
};
use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle};

use atoma_types::{Request, Response};
use atoma_types::{ModelThreadError, Request, Response};
use ed25519_consensus::VerificationKey as PublicKey;
use futures::stream::FuturesUnordered;
use thiserror::Error;
use tokio::sync::oneshot::{self, error::RecvError};
use tokio::sync::oneshot;
use tracing::{debug, error, info, warn};

use crate::{
apis::ApiError,
models::{
candle::{
falcon::FalconModel, llama::LlamaModel, mamba::MambaModel, mistral::MistralModel,
mixtral::MixtralModel, quantized::QuantizedModel, stable_diffusion::StableDiffusion,
},
config::{ModelConfig, ModelsConfig},
types::ModelType,
ModelError, ModelId, ModelTrait,
use crate::models::{
candle::{
falcon::FalconModel, llama::LlamaModel, mamba::MambaModel, mistral::MistralModel,
mixtral::MixtralModel, quantized::QuantizedModel, stable_diffusion::StableDiffusion,
},
config::{ModelConfig, ModelsConfig},
types::ModelType,
ModelId, ModelTrait,
};

pub struct ModelThreadCommand {
pub(crate) request: Request,
pub(crate) sender: oneshot::Sender<Response>,
}

#[derive(Debug, Error)]
pub enum ModelThreadError {
#[error("Model thread shutdown: `{0}`")]
ApiError(ApiError),
#[error("Model thread shutdown: `{0}`")]
ModelError(ModelError),
#[error("Core thread shutdown: `{0}`")]
Shutdown(RecvError),
#[error("Serde error: `{0}`")]
SerdeError(#[from] serde_json::Error),
}

impl From<ModelError> for ModelThreadError {
fn from(error: ModelError) -> Self {
Self::ModelError(error)
}
}

impl From<ApiError> for ModelThreadError {
fn from(error: ApiError) -> Self {
Self::ApiError(error)
}
}

pub struct ModelThreadHandle {
sender: mpsc::Sender<ModelThreadCommand>,
join_handle: std::thread::JoinHandle<Result<(), ModelThreadError>>,
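`ModelThreadError` and, presumably, its `From<ApiError>`/`From<ModelError>` conversions now live in `atoma_types`, so the `?`-based error propagation in this file keeps working unchanged. A small, hypothetical illustration (the function name is invented for the example):

// Illustrative only; not a hunk from this PR.
use atoma_types::{ModelError, ModelThreadError};

fn run_model_step(step: impl FnOnce() -> Result<(), ModelError>) -> Result<(), ModelThreadError> {
    // `?` relies on `From<ModelError> for ModelThreadError`, assumed to have
    // moved into atoma-types together with the enum itself.
    step()?;
    Ok(())
}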
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/falcon.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::ModelError;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -15,7 +16,7 @@ use crate::models::{
candle::hub_load_safetensors,
config::ModelConfig,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
ModelTrait,
};

use super::device;
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/llama.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::ModelError;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -16,7 +17,7 @@ use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
ModelTrait,
};

use super::{device, hub_load_safetensors};
16 changes: 7 additions & 9 deletions atoma-inference/src/models/candle/mamba.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{bail, ModelError};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -11,15 +12,12 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;

use crate::{
bail,
models::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
},
use crate::models::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelTrait,
};

pub struct MambaModel {
14 changes: 6 additions & 8 deletions atoma-inference/src/models/candle/mistral.rs
@@ -1,5 +1,6 @@
use std::str::FromStr;

use atoma_types::{bail, ModelError};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -11,14 +12,11 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;

use crate::{
bail,
models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
},
use crate::models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelTrait,
};

pub struct MistralModel {
14 changes: 6 additions & 8 deletions atoma-inference/src/models/candle/mixtral.rs
@@ -1,3 +1,4 @@
use atoma_types::{bail, ModelError};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
@@ -9,14 +10,11 @@ use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;
use tracing::info;

use crate::{
bail,
models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
},
use crate::models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelTrait,
};

pub struct MixtralModel {
5 changes: 1 addition & 4 deletions atoma-inference/src/models/candle/mod.rs
@@ -1,15 +1,12 @@
use std::{fs::File, io::Write, path::PathBuf};

use atoma_types::{bail, ModelError};
use candle::{
utils::{cuda_is_available, metal_is_available},
DType, Device, Tensor,
};
use tracing::info;

use crate::bail;

use super::ModelError;

pub mod falcon;
pub mod llama;
pub mod mamba;
3 changes: 2 additions & 1 deletion atoma-inference/src/models/candle/quantized.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr};

use atoma_types::ModelError;
use candle::{
quantized::{ggml_file, gguf_file},
DType, Device, Tensor,
@@ -17,7 +18,7 @@ use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
ModelError, ModelTrait,
ModelTrait,
};
use candle_transformers::models::quantized_llama as model;

8 changes: 3 additions & 5 deletions atoma-inference/src/models/candle/stable_diffusion.rs
@@ -1,5 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{bail, ModelError};
use candle_transformers::models::stable_diffusion::{
self, clip::ClipTextTransformer, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL,
StableDiffusionConfig,
@@ -11,11 +12,8 @@ use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{debug, info};

use crate::{
bail,
models::{
candle::save_image, config::ModelConfig, types::ModelType, ModelError, ModelId, ModelTrait,
},
use crate::models::{
candle::save_image, config::ModelConfig, types::ModelType, ModelId, ModelTrait,
};

use super::{convert_to_image, device, save_tensor_to_file};
43 changes: 2 additions & 41 deletions atoma-inference/src/models/mod.rs
@@ -1,11 +1,9 @@
use std::path::PathBuf;

use ::candle::{DTypeParseError, Error as CandleError};
use self::{config::ModelConfig, types::ModelType};
use atoma_types::ModelError;
use ed25519_consensus::VerificationKey as PublicKey;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;

use self::{config::ModelConfig, types::ModelType};

pub mod candle;
pub mod config;
@@ -45,40 +43,3 @@ pub trait Response: Send + 'static {

fn from_model_output(model_output: Self::ModelOutput) -> Self;
}

#[derive(Debug, Error)]
pub enum ModelError {
#[error("Deserialize error: `{0}`")]
DeserializeError(#[from] serde_json::Error),
#[error("{0}")]
Msg(String),
#[error("Candle error: `{0}`")]
CandleError(#[from] CandleError),
#[error("Config error: `{0}`")]
Config(String),
#[error("Image error: `{0}`")]
ImageError(#[from] image::ImageError),
#[error("Io error: `{0}`")]
IoError(#[from] std::io::Error),
#[error("Error: `{0}`")]
BoxedError(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("ApiError error: `{0}`")]
ApiError(#[from] hf_hub::api::sync::ApiError),
#[error("DTypeParseError: `{0}`")]
DTypeParseError(#[from] DTypeParseError),
#[error("Invalid model type: `{0}`")]
InvalidModelType(String),
}

#[macro_export]
macro_rules! bail {
($msg:literal $(,)?) => {
return Err(ModelError::Msg(format!($msg).into()))
};
($err:expr $(,)?) => {
return Err(ModelError::Msg(format!($err).into()).bt())
};
($fmt:expr, $($arg:tt)*) => {
return Err(ModelError::Msg(format!($fmt, $($arg)*).into()).bt())
};
}
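`ModelError` and the `bail!` macro removed above are now supplied by `atoma_types`, matching the updated imports in mamba.rs, mistral.rs, mixtral.rs and token_output_stream.rs. A minimal, hypothetical usage sketch of the relocated macro:

// Illustrative only; not part of this PR's diff. Both names come from the
// new `use atoma_types::{bail, ModelError};` imports shown above.
use atoma_types::{bail, ModelError};

fn validate_prompt(prompt: &str) -> Result<(), ModelError> {
    if prompt.is_empty() {
        // Per the macro arm removed above, this expands to
        // `return Err(ModelError::Msg(format!("prompt must not be empty").into()))`.
        bail!("prompt must not be empty");
    }
    Ok(())
}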
2 changes: 1 addition & 1 deletion atoma-inference/src/models/token_output_stream.rs
@@ -1,4 +1,4 @@
use crate::{bail, models::ModelError};
use atoma_types::{bail, ModelError};

/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
3 changes: 2 additions & 1 deletion atoma-inference/src/models/types.rs
@@ -1,12 +1,13 @@
use std::{fmt::Display, path::PathBuf, str::FromStr};

use atoma_types::ModelError;
use candle::{DType, Device};
use ed25519_consensus::VerificationKey as PublicKey;
use serde::{Deserialize, Serialize};

use crate::models::{ModelId, Request, Response};

use super::{candle::stable_diffusion::StableDiffusionInput, ModelError};
use super::candle::stable_diffusion::StableDiffusionInput;

pub type NodeId = PublicKey;
