Fixing gemma2. (#2135)
* Fixing gemma2.

* Adding new model.
Narsil authored Jun 27, 2024
1 parent 0e4ab6d commit 3ea8259
Showing 7 changed files with 622 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/supported_models.md
@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [Gemma2](https://huggingface.co/google/gemma2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
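As a usage note (not part of this commit): once a TGI server is running and serving one of the models listed above, it can be queried from Python with `huggingface_hub`'s `InferenceClient`. The endpoint URL and prompt below are illustrative assumptions.

```python
# Sketch: querying a running TGI instance from Python.
# Assumes a server is already listening at http://localhost:8080
# (e.g. one serving a Gemma2 checkpoint); both are illustrative assumptions.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")
output = client.text_generation(
    "Explain flash attention in one sentence.",
    max_new_tokens=64,
)
print(output)
```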
30 changes: 30 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -68,6 +68,9 @@
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_gemma2 import (
        FlashGemma2,
    )
    from text_generation_server.models.pali_gemma import (
        PaliGemma,
    )
@@ -102,6 +105,7 @@
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)
    __all__.append(FlashGemma)
    __all__.append(FlashGemma2)
    __all__.append(FlashCohere)

MAMBA_AVAILABLE = True
@@ -143,6 +147,11 @@ class ModelType(enum.Enum):
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
"url": "https://huggingface.co/google/gemma2-9b",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
@@ -630,6 +639,27 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    elif model_type == GEMMA2:
        if FLASH_ATTENTION:
            return FlashGemma2(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                speculator=speculator,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == COHERE:
        if FLASH_ATTENTION:
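The new branch follows the file's existing dispatch pattern: prefer the flash-attention implementation, reject sharded serving when flash attention is unavailable, and otherwise fall back to the generic `CausalLM` path. A condensed self-contained sketch of that three-way fallback; the stub classes, flag, and error template are illustrative assumptions, not the real TGI objects:

```python
FLASH_ATTENTION = True  # in TGI this is detected at import time; assumed here
FLASH_ATT_ERROR_MESSAGE = "{} requires flash attention"  # illustrative stand-in

class FlashGemma2Stub:
    """Stand-in for the flash-attention Gemma2 implementation."""
    def __init__(self, model_id: str):
        self.model_id = model_id

class CausalLMStub:
    """Stand-in for the generic unsharded fallback."""
    def __init__(self, model_id: str):
        self.model_id = model_id

def load_gemma2(model_id: str, sharded: bool = False):
    if FLASH_ATTENTION:
        return FlashGemma2Stub(model_id)  # fast path
    elif sharded:
        # Tensor-parallel sharding is only wired up for the flash path.
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
    else:
        return CausalLMStub(model_id)  # slower but universal fallback

print(type(load_gemma2("google/gemma2-9b")).__name__)  # FlashGemma2Stub
```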