Skip to content

Commit

Permalink
Merge pull request #14 from atoma-network/20240327-add-candle-features
Browse files: browse the repository at this point in the history
integrates candle features
  • Loading branch information
jorgeantonio21 authored Apr 2, 2024
2 parents 48e7845 + a159880 commit 2929815
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ version = "0.1.0"
[workspace.dependencies]
async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
ed25519-consensus = "2.1.0"
Expand Down
8 changes: 8 additions & 0 deletions atoma-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
[dependencies]
async-trait.workspace = true
candle.workspace = true
candle-flash-attn = { workspace = true, optional = true }
candle-nn.workspace = true
candle-transformers.workspace = true
ed25519-consensus.workspace = true
Expand All @@ -16,3 +17,10 @@ thiserror.workspace = true
tokenizers.workspace = true
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true

[features]
accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
6 changes: 4 additions & 2 deletions atoma-inference/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ pub enum ModelConfig {
Llama(LlamaConfig),
Llama2(Llama2Config),
Mamba(MambaConfig),
Mixtral8x7b(MixtralConfig),
Mistral(MistralConfig),
StableDiffusion(StableDiffusionConfig),
Mixtral8x7b(Box<MixtralConfig>),
StableDiffusion(Box<StableDiffusionConfig>),
}

impl From<ModelType> for ModelConfig {
Expand All @@ -74,6 +74,8 @@ pub enum ModelCache {

pub trait ModelApi {
fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self;

#[allow(clippy::too_many_arguments)]
fn run(
&self,
input: String,
Expand Down

0 comments on commit 2929815

Please sign in to comment.