Skip to content

Commit

Permalink
chore: move modeltype
Browse files Browse the repository at this point in the history
  • Loading branch information
Cifko committed Apr 24, 2024
1 parent c55c334 commit 22b6c7c
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 329 deletions.
3 changes: 1 addition & 2 deletions atoma-inference/src/model_thread.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::mpsc, thread::JoinHandle};

use atoma_types::{ModelThreadError, Request, Response};
use atoma_types::{ModelThreadError, ModelType, Request, Response};
use ed25519_consensus::VerificationKey as PublicKey;
use futures::stream::FuturesUnordered;
use tokio::sync::oneshot;
Expand All @@ -12,7 +12,6 @@ use crate::models::{
mixtral::MixtralModel, quantized::QuantizedModel, stable_diffusion::StableDiffusion,
},
config::{ModelConfig, ModelsConfig},
types::ModelType,
ModelId, ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/falcon.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::ModelError;
use atoma_types::{ModelError, ModelType};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
Expand All @@ -15,7 +15,7 @@ use tracing::{debug, error, info};
use crate::models::{
candle::hub_load_safetensors,
config::ModelConfig,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/llama.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::ModelError;
use atoma_types::{ModelError, ModelType};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
Expand All @@ -16,7 +16,7 @@ use tracing::info;
use crate::models::{
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/mamba.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{bail, ModelError};
use atoma_types::{bail, ModelError, ModelType};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
Expand All @@ -16,7 +16,7 @@ use crate::models::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/mistral.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::str::FromStr;

use atoma_types::{bail, ModelError};
use atoma_types::{bail, ModelError, ModelType};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
Expand All @@ -15,7 +15,7 @@ use tracing::info;
use crate::models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/mixtral.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use atoma_types::{bail, ModelError};
use atoma_types::{bail, ModelError, ModelType};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::{
Expand All @@ -13,7 +13,7 @@ use tracing::info;
use crate::models::{
candle::{device, hub_load_safetensors},
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};

Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/candle/quantized.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr};

use atoma_types::ModelError;
use atoma_types::{ModelError, ModelType};
use candle::{
quantized::{ggml_file, gguf_file},
DType, Device, Tensor,
Expand All @@ -17,7 +17,7 @@ use crate::models::{
candle::device,
config::ModelConfig,
token_output_stream::TokenOutputStream,
types::{LlmLoadData, ModelType, TextModelInput, TextModelOutput},
types::{LlmLoadData, TextModelInput, TextModelOutput},
ModelTrait,
};
use candle_transformers::models::quantized_llama as model;
Expand Down
72 changes: 2 additions & 70 deletions atoma-inference/src/models/candle/stable_diffusion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{path::PathBuf, str::FromStr, time::Instant};

use atoma_types::{bail, ModelError};
use atoma_types::{bail, ModelError, ModelType};
use candle_transformers::models::stable_diffusion::{
self, clip::ClipTextTransformer, unet_2d::UNet2DConditionModel, vae::AutoEncoderKL,
StableDiffusionConfig,
Expand All @@ -12,9 +12,7 @@ use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{debug, info};

use crate::models::{
candle::save_image, config::ModelConfig, types::ModelType, ModelId, ModelTrait,
};
use crate::models::{candle::save_image, config::ModelConfig, ModelId, ModelTrait};

use super::{convert_to_image, device, save_tensor_to_file};

Expand Down Expand Up @@ -404,72 +402,6 @@ impl ModelTrait for StableDiffusion {
}
}

impl ModelType {
    /// Shared selector for Stable Diffusion weight files.
    ///
    /// Returns `f16_file` or `f32_file` depending on `use_f16` for the four
    /// supported Stable Diffusion variants.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not one of the Stable Diffusion model types —
    /// these helpers are only meaningful inside the stable-diffusion loader.
    fn sd_weights_file(
        &self,
        f16_file: &'static str,
        f32_file: &'static str,
        use_f16: bool,
    ) -> &'static str {
        match self {
            Self::StableDiffusionV1_5
            | Self::StableDiffusionV2_1
            | Self::StableDiffusionXl
            | Self::StableDiffusionTurbo => {
                if use_f16 {
                    f16_file
                } else {
                    f32_file
                }
            }
            _ => panic!("Invalid stable diffusion model type"),
        }
    }

    /// Repository-relative path of the UNet weights (fp16 or fp32 variant).
    fn unet_file(&self, use_f16: bool) -> &'static str {
        self.sd_weights_file(
            "unet/diffusion_pytorch_model.fp16.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            use_f16,
        )
    }

    /// Repository-relative path of the VAE weights (fp16 or fp32 variant).
    fn vae_file(&self, use_f16: bool) -> &'static str {
        self.sd_weights_file(
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
            use_f16,
        )
    }

    /// Repository-relative path of the primary CLIP text encoder weights.
    fn clip_file(&self, use_f16: bool) -> &'static str {
        self.sd_weights_file(
            "text_encoder/model.fp16.safetensors",
            "text_encoder/model.safetensors",
            use_f16,
        )
    }

    /// Repository-relative path of the secondary CLIP text encoder weights
    /// (used by the SDXL-family models; path exists only in those repos —
    /// NOTE(review): callers are expected to request it only for XL/Turbo).
    fn clip2_file(&self, use_f16: bool) -> &'static str {
        self.sd_weights_file(
            "text_encoder_2/model.fp16.safetensors",
            "text_encoder_2/model.safetensors",
            use_f16,
        )
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
Expand Down
4 changes: 2 additions & 2 deletions atoma-inference/src/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::PathBuf;

use self::{config::ModelConfig, types::ModelType};
use atoma_types::ModelError;
use self::config::ModelConfig;
use atoma_types::{ModelError, ModelType};
use ed25519_consensus::VerificationKey as PublicKey;
use serde::{de::DeserializeOwned, Serialize};

Expand Down
Loading

0 comments on commit 22b6c7c

Please sign in to comment.