🍃 Added Mixtral on TGI / Jetstream Pytorch (#103)
* refactor(Jetstream Pt): move modeling to separate directory

* feat(Jetstream Pt): Add initial support for mixtral

Note that serving Mixtral-8x7B is hard for now because of the large
amount of resources it requires.

* ci(slow tests): better split; add Mixtral testing

Tests are split to avoid a memory problem that appears when serving
different models in subsequent tests with Jetstream/Pytorch.
This also clarifies the steps of the workflow and adds a Mixtral test
to the nightly workflow.

* review: remove commented import
tengomucho authored Oct 11, 2024
1 parent f42a228 commit 16596de
Showing 9 changed files with 73 additions and 10 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -28,11 +28,19 @@ jobs:
- name: Checkout
uses: actions/checkout@v4

- name: Run TGI tests - Jetstream Pytorch (also slow tests)
- name: Build and install Jetstream Pytorch TGI
run: |
make jetstream_requirements tgi_server test_installs
find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
- name: Run TGI Jetstream Pytorch - Llama
run: |
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and Llama"
- name: Run TGI Jetstream Pytorch - Gemma
run: |
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and gemma"
- name: Run TGI Jetstream Pytorch - Mixtral greedy
run: |
JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and not LLama"
pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and Mixtral and greedy"
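The workflow now runs each model's slow tests in a separate step so memory held while serving one model does not linger into the next. As a rough sketch (the paths and pytest options mirror the "Mixtral greedy" step above; the token value is a placeholder), the same selection can be reproduced locally from the repository root:

import os
import pytest

# Enable the Jetstream/Pytorch backend and authenticate against the Hub
# (placeholder token; in CI this comes from a repository secret).
os.environ["JETSTREAM_PT"] = "1"
os.environ["HF_TOKEN"] = "<your-hf-token>"

# Same selection expression as the "Mixtral greedy" step above.
pytest.main([
    "-sv", "text-generation-inference/tests",
    "--runslow",
    "-k", "jetstream and slow and Mixtral and greedy",
])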
7 changes: 6 additions & 1 deletion text-generation-inference/server/pyproject.toml
@@ -22,7 +22,12 @@ dependencies = [
]

[tool.setuptools]
packages = ["text_generation_server", "text_generation_server.pb", "text_generation_server.jetstream_pt_support"]
packages = [
"text_generation_server",
"text_generation_server.pb",
"text_generation_server.jetstream_pt_support",
"text_generation_server.jetstream_pt_support.models",
]

[tool.setuptools.dynamic]
version = {attr = "text_generation_server.version.__version__"}
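Setuptools only ships the packages listed above, so the new jetstream_pt_support.models subpackage has to be added here or it would be missing from the built wheel. A quick post-install smoke check (a sketch, assuming the wheel built by the workflow is installed):

import importlib

# Fails with ModuleNotFoundError if the subpackage was not listed in pyproject.toml.
models = importlib.import_module("text_generation_server.jetstream_pt_support.models")
print(models.MixtralModel)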
@@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
"""
config = AutoConfig.from_pretrained(model_path)
# For now few models are supported
supported_models = ["llama", "gemma"]
supported_models = ["llama", "gemma", "mixtral"]
if config.model_type not in supported_models:
return False
if jetstream_pt_available():
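With "mixtral" added to supported_models, the compatibility check now reports Mixtral checkpoints as eligible for the Jetstream Pytorch backend whenever the Jetstream dependencies are importable. A minimal usage sketch (the module path is assumed from this repository's layout):

from text_generation_server.jetstream_pt_support.compatibility import model_can_use_jetstream_pt

if model_can_use_jetstream_pt("mistralai/Mixtral-8x7B-v0.1"):
    print("Serving with the Jetstream Pytorch backend")
else:
    print("Falling back to the default TGI backend")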
@@ -19,8 +19,7 @@
from transformers import AutoConfig

from .compatibility import model_can_use_jetstream_pt
from .gemma_model_hf import GemmaModelHf as GemmaModel
from .llama_model_exportable_hf import TransformerHf as LlamaModel
from .models import GemmaModel, LlamaModel, MixtralModel


def _get_head_dim(config: "PretrainedConfig") -> int:
@@ -38,6 +37,8 @@ def load_model_info(config: "PretrainedConfig") -> Any:
model_class = LlamaModel
elif config.model_type == "gemma":
model_class = GemmaModel
elif config.model_type == "mixtral":
model_class = MixtralModel
else:
raise ValueError(f"Unsupported model type {config.model_type}")
model_info = fetch_models.ModelInfo(
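load_model_info now dispatches on the HF config's model_type, picking one of the three Jetstream model classes and rejecting anything else. The equivalent selection logic written out as a small sketch (class names as re-exported by the new models package):

from transformers import AutoConfig

from text_generation_server.jetstream_pt_support.models import (
    GemmaModel,
    LlamaModel,
    MixtralModel,
)

MODEL_CLASSES = {"llama": LlamaModel, "gemma": GemmaModel, "mixtral": MixtralModel}

config = AutoConfig.from_pretrained("dacorvo/Mixtral-tiny")
model_class = MODEL_CLASSES.get(config.model_type)
if model_class is None:
    raise ValueError(f"Unsupported model type {config.model_type}")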
@@ -0,0 +1,3 @@
from .gemma_model_hf import GemmaModelHf as GemmaModel
from .llama_model_exportable_hf import TransformerHf as LlamaModel
from .mixtral_model_hf import MixtralModelHf as MixtralModel
@@ -1,8 +1,6 @@

from jetstream_pt.third_party.gemma import config as gemma_config
from jetstream_pt.third_party.gemma.model import GemmaModel

#.model_exportable import Transformer, model_args
from transformers import GemmaConfig, GenerationConfig, GenerationMixin


@@ -0,0 +1,38 @@

from jetstream_pt.third_party.mixtral import config as mixtral_config
from jetstream_pt.third_party.mixtral.model import Transformer
from transformers import GenerationConfig, GenerationMixin, MixtralConfig


class MixtralModelHf(Transformer, GenerationMixin):
"""Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
"""

def __init__(
self,
config: MixtralConfig,
device,
env,
):
self.config = config
self.generation_config = GenerationConfig.from_model_config(config)

args = mixtral_config.ModelArgs(
block_size=config.max_position_embeddings,
vocab_size=config.vocab_size,
n_layer=config.num_hidden_layers,
n_head=config.num_attention_heads,
dim=config.hidden_size,
intermediate_size=config.intermediate_size,
n_local_heads=config.num_local_experts or config.num_attention_heads,
num_activated_experts=config.num_experts_per_tok,
device=device,
)
super().__init__(args, env)


@classmethod
def from_config(cls, config, env):
device = "meta"
model = cls(config, device, env)
return model
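MixtralModelHf translates the familiar HF MixtralConfig hyperparameters into Jetstream Pytorch ModelArgs and, through GenerationMixin, keeps a GenerationConfig around for HF-style generation defaults. A hedged usage sketch (the environment object is built by the Jetstream Pytorch engine setup, which is not shown here):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("dacorvo/Mixtral-tiny")
engine_env = ...  # placeholder: the jetstream_pt environment created by the engine loader
model = MixtralModelHf.from_config(config, engine_env)
# The "meta" device means no weights are allocated yet; the engine loads them later.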
14 changes: 12 additions & 2 deletions text-generation-inference/tests/test_decode.py
@@ -122,8 +122,13 @@ def _test_decode_single(params):
sequence_length=128,
expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
),
DecodeTestParams(
model_id="mistralai/Mixtral-8x7B-v0.1",
sequence_length=1024,
expected_text="\n\nGeorge Orwell, 1984\n\nThe clocks are striking thirteen",
),
],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b"],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
)
def test_decode_single_jetstream_pytorch_slow(params, do_sample):
if not jetstream_pt_available():
@@ -146,8 +151,13 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
sequence_length=1024,
expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never",
),
DecodeTestParams(
model_id="dacorvo/Mixtral-tiny", # This is a random tiny model, just to test model can be loaded.
sequence_length=512,
expected_text="манaminationVariableßer Rog malesazine longふ Toy Champions enero Facereverse▲verbose prosecut literally disappearedअ",
),
],
ids=["TinyLLama-v0", "gemma-2b"],
ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
)
def test_decode_single_jetstream_pytorch(params, do_sample):
if not jetstream_pt_available():
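Two Mixtral cases are added: the full mistralai/Mixtral-8x7B-v0.1 checkpoint in the slow suite (exercised by the nightly "Mixtral greedy" step) and the tiny random dacorvo/Mixtral-tiny checkpoint in the regular suite, which only verifies that the model loads and decodes. The lightweight case can be checked on modest hardware with a sketch like (paths assume the repository root):

import os
import pytest

os.environ["JETSTREAM_PT"] = "1"  # use the Jetstream/Pytorch backend

# Select only the tiny Mixtral decode test; without --runslow the 8x7B case is skipped.
pytest.main([
    "-sv", "text-generation-inference/tests/test_decode.py",
    "-k", "jetstream and Mixtral-tiny",
])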
