🍃 Added Mixtral on TGI / Jetstream Pytorch (#103)
* refactor(Jetstream Pt): move modeling to a separate directory
* feat(Jetstream Pt): add initial support for Mixtral. Note that for now, serving Mixtral-8x7B is very hard due to the large amount of resources it requires.
* ci(slow tests): better split; add Mixtral testing. Tests are split to avoid a memory problem that appears when serving different models in subsequent tests with Jetstream/Pytorch. This clarifies the steps of the workflow and adds a Mixtral test to the nightly workflow.
* review: remove commented import
1 parent f42a228, commit 16596de
Showing 9 changed files with 73 additions and 10 deletions.
...eneration-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py (3 changes: 3 additions & 0 deletions)
@@ -0,0 +1,3 @@
```python
from .gemma_model_hf import GemmaModelHf as GemmaModel
from .llama_model_exportable_hf import TransformerHf as LlamaModel
from .mixtral_model_hf import MixtralModelHf as MixtralModel
```
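The package re-exports each backend class under a uniform alias, so callers can select a model without knowing the underlying Jetstream class names. A hypothetical sketch of how such a dispatch might look (the aliases are real; the mapping and helper below are illustrative, not from the commit):

```python
# Hypothetical dispatch sketch; the mapping is not taken from the repository.
from text_generation_server.jetstream_pt_support import models

MODEL_CLASSES = {
    "gemma": models.GemmaModel,
    "llama": models.LlamaModel,
    "mixtral": models.MixtralModel,
}

def model_class_for(model_type: str):
    # model_type matches the Hugging Face config's `model_type` field.
    return MODEL_CLASSES[model_type]
```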
...er/jetstream_pt_support/gemma_model_hf.py → ...tream_pt_support/models/gemma_model_hf.py (2 changes: 0 additions & 2 deletions)
File renamed without changes.
...n-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py (38 changes: 38 additions & 0 deletions)
@@ -0,0 +1,38 @@
```python
from jetstream_pt.third_party.mixtral import config as mixtral_config
from jetstream_pt.third_party.mixtral.model import Transformer
from transformers import GenerationConfig, GenerationMixin, MixtralConfig


class MixtralModelHf(Transformer, GenerationMixin):
    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device."""

    def __init__(
        self,
        config: MixtralConfig,
        device,
        env,
    ):
        self.config = config
        self.generation_config = GenerationConfig.from_model_config(config)

        args = mixtral_config.ModelArgs(
            block_size=config.max_position_embeddings,
            vocab_size=config.vocab_size,
            n_layer=config.num_hidden_layers,
            n_head=config.num_attention_heads,
            dim=config.hidden_size,
            intermediate_size=config.intermediate_size,
            n_local_heads=config.num_local_experts or config.num_attention_heads,
            num_activated_experts=config.num_experts_per_tok,
            device=device,
        )
        super().__init__(args, env)

    @classmethod
    def from_config(cls, config, env):
        device = "meta"
        model = cls(config, device, env)
        return model
```
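A minimal usage sketch, not part of the commit: the wrapper translates a Hugging Face `MixtralConfig` into Jetstream Pytorch `ModelArgs` and builds the model on the "meta" device, so parameter shapes exist without allocating memory. The `env` argument is assumed to be the Jetstream Pytorch engine environment created elsewhere by the serving code.

```python
# Hypothetical sketch, assuming `env` is the Jetstream Pytorch environment
# built elsewhere by the TGI server; only the config handling is shown.
from transformers import MixtralConfig

config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# from_config instantiates the model on the "meta" device: no weights are
# materialized until real checkpoints are loaded later.
model = MixtralModelHf.from_config(config, env)
```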