From 16596de248a8fc7a9e7e0d3667b055c9a120a770 Mon Sep 17 00:00:00 2001
From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com>
Date: Fri, 11 Oct 2024 09:59:56 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=83=20Added=20Mixtral=20on=20TGI=20/?=
 =?UTF-8?q?=20Jetstream=20Pytorch=20(#103)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor(Jetstream Pt): move modeling to separate directory

* feat(Jetstream Pt): Add initial support for Mixtral

Note that for now serving Mixtral-8X7B is very hard because of the large
amount of resources it requires.

* ci(slow tests): better split; add Mixtral testing

Tests are split to avoid a memory problem that appears when serving
different models in subsequent tests with Jetstream/Pytorch. The steps of
this workflow are also clarified, and a Mixtral test is added to the
nightly workflow.

* review: remove commented import
---
 ...-pytorch-xla-tpu-tgi-nightly-jetstream.yml | 12 +++++-
 .../server/pyproject.toml                     |  7 +++-
 .../jetstream_pt_support/compatibility.py     |  2 +-
 .../jetstream_pt_support/engine_loader.py     |  5 ++-
 .../jetstream_pt_support/models/__init__.py   |  3 ++
 .../{ => models}/gemma_model_hf.py            |  2 -
 .../{ => models}/llama_model_exportable_hf.py |  0
 .../models/mixtral_model_hf.py                | 38 +++++++++++++++++++
 .../tests/test_decode.py                      | 14 ++++++-
 9 files changed, 73 insertions(+), 10 deletions(-)
 create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
 rename text-generation-inference/server/text_generation_server/jetstream_pt_support/{ => models}/gemma_model_hf.py (96%)
 rename text-generation-inference/server/text_generation_server/jetstream_pt_support/{ => models}/llama_model_exportable_hf.py (100%)
 create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py

diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
index f85dc1c0..84d44db7 100644
--- a/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
+++ b/.github/workflows/test-pytorch-xla-tpu-tgi-nightly-jetstream.yml
@@ -28,11 +28,19 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v4

-      - name: Run TGI tests - Jetstream Pytorch (also slow tests)
+      - name: Build and install Jetstream Pytorch TGI
         run: |
           make jetstream_requirements tgi_server test_installs
           find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
+      - name: Run TGI Jetstream Pytorch - Llama
+        run: |
           JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
           pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and Llama"
+      - name: Run TGI Jetstream Pytorch - Gemma
+        run: |
+          JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
+          pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and gemma"
+      - name: Run TGI Jetstream Pytorch - Mixtral greedy
+        run: |
           JETSTREAM_PT=1 HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} python -m \
-          pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and not LLama"
+          pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and Mixtral and greedy"
diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml
index 5a3d4070..a10727b8 100644
--- a/text-generation-inference/server/pyproject.toml
+++ b/text-generation-inference/server/pyproject.toml
@@ -22,7 +22,12 @@ dependencies = [
 ]

 [tool.setuptools]
-packages = ["text_generation_server", "text_generation_server.pb", "text_generation_server.jetstream_pt_support"]
+packages = [
+    "text_generation_server",
+    "text_generation_server.pb",
+    "text_generation_server.jetstream_pt_support",
+    "text_generation_server.jetstream_pt_support.models",
+]

 [tool.setuptools.dynamic]
 version = {attr = "text_generation_server.version.__version__"}
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
index dc8b7406..d1fc325d 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
@@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
     """
     config = AutoConfig.from_pretrained(model_path)
     # For now few models are supported
-    supported_models = ["llama", "gemma"]
+    supported_models = ["llama", "gemma", "mixtral"]
     if config.model_type not in supported_models:
         return False
     if jetstream_pt_available():
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index c89e722f..f36e4c82 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -19,8 +19,7 @@ from transformers import AutoConfig

 from .compatibility import model_can_use_jetstream_pt
-from .gemma_model_hf import GemmaModelHf as GemmaModel
-from .llama_model_exportable_hf import TransformerHf as LlamaModel
+from .models import GemmaModel, LlamaModel, MixtralModel


 def _get_head_dim(config: "PretrainedConfig") -> int:
@@ -38,6 +37,8 @@ def load_model_info(config: "PretrainedConfig") -> Any:
         model_class = LlamaModel
     elif config.model_type == "gemma":
         model_class = GemmaModel
+    elif config.model_type == "mixtral":
+        model_class = MixtralModel
     else:
         raise ValueError(f"Unsupported model type {config.model_type}")
     model_info = fetch_models.ModelInfo(
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
new file mode 100644
index 00000000..9855bde6
--- /dev/null
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
@@ -0,0 +1,3 @@
+from .gemma_model_hf import GemmaModelHf as GemmaModel
+from .llama_model_exportable_hf import TransformerHf as LlamaModel
+from .mixtral_model_hf import MixtralModelHf as MixtralModel
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/gemma_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
similarity index 96%
rename from text-generation-inference/server/text_generation_server/jetstream_pt_support/gemma_model_hf.py
rename to text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
index e4760113..e61788d8 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/gemma_model_hf.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/gemma_model_hf.py
@@ -1,8 +1,6 @@
 from jetstream_pt.third_party.gemma import config as gemma_config
 from jetstream_pt.third_party.gemma.model import GemmaModel
-
-#.model_exportable import Transformer, model_args
 from transformers import GemmaConfig, GenerationConfig, GenerationMixin


diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/llama_model_exportable_hf.py
similarity index 100%
rename from text-generation-inference/server/text_generation_server/jetstream_pt_support/llama_model_exportable_hf.py
rename to text-generation-inference/server/text_generation_server/jetstream_pt_support/models/llama_model_exportable_hf.py
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
new file mode 100644
index 00000000..0e476a9a
--- /dev/null
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/mixtral_model_hf.py
@@ -0,0 +1,38 @@
+
+from jetstream_pt.third_party.mixtral import config as mixtral_config
+from jetstream_pt.third_party.mixtral.model import Transformer
+from transformers import GenerationConfig, GenerationMixin, MixtralConfig
+
+
+class MixtralModelHf(Transformer, GenerationMixin):
+    """Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
+    """
+
+    def __init__(
+        self,
+        config: MixtralConfig,
+        device,
+        env,
+    ):
+        self.config = config
+        self.generation_config = GenerationConfig.from_model_config(config)
+
+        args = mixtral_config.ModelArgs(
+            block_size=config.max_position_embeddings,
+            vocab_size=config.vocab_size,
+            n_layer=config.num_hidden_layers,
+            n_head=config.num_attention_heads,
+            dim=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            n_local_heads=config.num_local_experts or config.num_attention_heads,
+            num_activated_experts=config.num_experts_per_tok,
+            device=device,
+        )
+        super().__init__(args, env)
+
+
+    @classmethod
+    def from_config(cls, config, env):
+        device = "meta"
+        model = cls(config, device, env)
+        return model
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index ab29fe43..362535b0 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -122,8 +122,13 @@ def _test_decode_single(params):
             sequence_length=128,
             expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
         ),
+        DecodeTestParams(
+            model_id="mistralai/Mixtral-8x7B-v0.1",
+            sequence_length=1024,
+            expected_text="\n\nGeorge Orwell, 1984\n\nThe clocks are striking thirteen",
+        ),
     ],
-    ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b"],
+    ids=["Llama-2-7b-hf", "Meta-Llama-3-8B", "gemma-7b", "Mixtral-8x7B"],
 )
 def test_decode_single_jetstream_pytorch_slow(params, do_sample):
     if not jetstream_pt_available():
@@ -146,8 +151,13 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
             sequence_length=1024,
             expected_text="\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never",
         ),
+        DecodeTestParams(
+            model_id="dacorvo/Mixtral-tiny", # This is a random tiny model, just to test model can be loaded.
+            sequence_length=512,
+            expected_text="манaminationVariableßer Rog malesazine longふ Toy Champions enero Facereverse▲verbose prosecut literally disappearedअ",
+        ),
     ],
-    ids=["TinyLLama-v0", "gemma-2b"],
+    ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
     if not jetstream_pt_available():
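
Note: the new nightly "Mixtral greedy" step can be reproduced locally with something like the commands below. This is only a sketch based on the workflow change above: HF_TOKEN stands in for the HF_TOKEN_OPTIMUM_TPU_CI secret used in CI, and a TPU host with enough memory to load Mixtral-8x7B is assumed.

    # Build and install the Jetstream Pytorch TGI server, as in the CI job
    make jetstream_requirements tgi_server test_installs
    find text-generation-inference/ -name "text_generation_server-*whl" -exec python -m pip install {} \;
    # Run only the slow Mixtral greedy decode test
    JETSTREAM_PT=1 HF_TOKEN=<your Hugging Face token> python -m \
      pytest -sv text-generation-inference/tests --runslow -k "jetstream and slow and Mixtral and greedy"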