⬆️ Update dependencies (#120)
* chore(fsdp): remove a warning that is no longer relevant

The warning applied to an old version of torch xla (2.3.0), which is no
longer supported.

* chore: update jetstream dependency to v0.2.4

* chore(docker): increase ulimit to avoid error

When building the Docker container, the build sometimes fails with a
"too many open files" error. Increasing the ulimit makes the error
disappear.

* fix(docker): "AS" keyword should be uppercase to avoid a warning

* chore: update TGI dependency to v2.4.1

Also align the Dockerfile with TGI's.

* chore(docker): update accelerate to v1.1.1

* fix(jetstream): correct Gemma and Mixtral config handling

The config object for these models was used by the Jetstream code, but it
does not fully match HF's config definitions. This change creates a class
that inherits from both config classes and makes the adjustments necessary
to avoid errors.
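
In practice this is done with multiple inheritance: one config class derives from both the HF config and the Jetstream config, so both code paths read the same object. A minimal sketch of the pattern, mirroring the MixtralConfigHf class added in the diff below (the Gemma case is analogous):

```python
from jetstream_pt.third_party.mixtral import config as mixtral_config
from transformers import MixtralConfig


class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
    """A single config object that both HF and Jetstream code can consume."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Jetstream's ModelArgs derives some fields in __post_init__, which is
        # not triggered by the HF config constructor, so call it explicitly.
        self.__post_init__()

    @property
    def n_layer(self):
        # Expose Jetstream's field name as an alias of the HF one.
        return self.num_hidden_layers
```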

* chore(optimum): remove AutoModelForCausalLM from optimum.tpu

It is still possible to import it from optimum.tpu.modeling; removing the
top-level export reduces the chance of importing transformers and torch xla
before xla2.
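
For callers, this only changes the import path (the TGI generator diff below is updated the same way):

```python
# Previously exported at the package root:
# from optimum.tpu import AutoModelForCausalLM

# Now imported from the modeling submodule:
from optimum.tpu.modeling import AutoModelForCausalLM

# Illustrative usage; the model id is just an example.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
```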

* chore: update torch and torch_xla to v2.5.1

* chore(jetstream): token selector operations are done in torch

Score tensors are converted from jax to torch and back when calling the
logits processor. This will be required by newer versions of transformers.
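
Roughly, the per-step scores produced on the Jetstream (jax) side are bridged to torch so transformers' logits processors can run unchanged, then converted back for sampling. A sketch of that round-trip, assuming a simple NumPy-based bridge; the actual implementation may use torch_xla2 interop helpers instead:

```python
import jax.numpy as jnp
import numpy as np
import torch


def apply_logits_processor(scores_jax, logits_processor, input_ids_torch):
    # jax -> torch: transformers logits processors expect torch tensors.
    scores_torch = torch.from_numpy(np.asarray(scores_jax))
    processed = logits_processor(input_ids_torch, scores_torch)
    # torch -> jax: hand the adjusted scores back to the Jetstream sampling code.
    return jnp.asarray(processed.cpu().numpy())
```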

* chore(dependencies): update transformers to v4.46.3

* chore: update safetensors to v0.4.5

This keeps it consistent with the accelerate dependencies and updates to a
newer version.

* review(mixtral): use properties in config to avoid aliasing ambiguity

Instead of assigning separate attributes for Jetstream's config class,
properties are added so that the same underlying data is accessed, avoiding
ambiguity.
tengomucho authored Nov 29, 2024
1 parent ffa990d commit a1919c2
Showing 20 changed files with 99 additions and 84 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-pytorch-xla-tpu-tgi-jetstream.yml
@@ -22,7 +22,7 @@ jobs:
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
env:
PJRT_DEVICE: TPU
@@ -16,7 +16,7 @@ jobs:
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
env:
PJRT_DEVICE: TPU
2 changes: 1 addition & 1 deletion .github/workflows/test-pytorch-xla-tpu-tgi-nightly.yml
@@ -17,7 +17,7 @@ jobs:
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
env:
PJRT_DEVICE: TPU
2 changes: 1 addition & 1 deletion .github/workflows/test-pytorch-xla-tpu-tgi.yml
@@ -20,7 +20,7 @@ jobs:
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
env:
PJRT_DEVICE: TPU
2 changes: 1 addition & 1 deletion .github/workflows/test-pytorch-xla-tpu.yml
@@ -20,7 +20,7 @@ jobs:
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.1_3.10_tpuvm
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
env:
PJRT_DEVICE: TPU
4 changes: 2 additions & 2 deletions .github/workflows/tpu-tgi-release.yml
@@ -74,7 +74,7 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
TGI_VERSION="v2.4.1"
- name: Generate artifact attestation for TGI
@@ -95,7 +95,7 @@
labels: ${{ steps.meta-ie.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
TGI_VERSION=0ff6ff60ada291840beed63d8bf458d6f9606f7f
TGI_VERSION="v2.4.1"
target: inference-endpoint


5 changes: 3 additions & 2 deletions Makefile
@@ -19,8 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))

.PHONY: build_dist style style_check clean

# Ths is essentially v2.3.0 plus a fix to support v2 proto interface
TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
TGI_VERSION ?= 690702b1ce9a27ce5bdf2a9dd3a80277ecea12cd

rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))

@@ -47,6 +46,7 @@ tpu-tgi:
docker build --rm -f text-generation-inference/docker/Dockerfile \
--build-arg VERSION=$(VERSION) \
--build-arg TGI_VERSION=$(TGI_VERSION) \
--ulimit nofile=100000:100000 \
-t huggingface/optimum-tpu:$(VERSION)-tgi .
docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest

@@ -55,6 +55,7 @@ tpu-tgi-ie:
--target inference-endpoint \
--build-arg VERSION=$(VERSION) \
--build-arg TGI_VERSION=$(TGI_VERSION) \
--ulimit nofile=100000:100000 \
-t huggingface/optimum-tpu:$(VERSION)-tgi .
docker tag huggingface/optimum-tpu:$(VERSION)-tgi huggingface/optimum-tpu:latest-ie

1 change: 0 additions & 1 deletion optimum/tpu/__init__.py
@@ -14,5 +14,4 @@

from .jetstream_pt_support import jetstream_pt_available # isort:skip
from .fsdp_v2 import get_fsdp_config, use_fsdp_v2
from .modeling import AutoModelForCausalLM
from .version import VERSION, __version__
4 changes: 2 additions & 2 deletions optimum/tpu/cli.py
@@ -9,8 +9,8 @@
import typer


TORCH_VER = "2.4.0"
JETSTREAM_PT_VER = "02927c9f563082421abe8eedceabe8aedd7ec2f9"
TORCH_VER = "2.5.1"
JETSTREAM_PT_VER = "jetstream-v0.2.4"
DEFAULT_DEPS_PATH = os.path.join(Path.home(), ".jetstream-deps")

app = typer.Typer()
11 changes: 0 additions & 11 deletions optimum/tpu/fsdp_v2.py
@@ -17,8 +17,6 @@
"""
from typing import Any, Dict, List, Union

from transformers.utils import logging


PreTrainedModel = Any
# NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type
@@ -92,15 +90,6 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
from .modeling_gemma import GemmaForCausalLM

if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
logger = logging.get_logger(__name__)
from torch_xla import __version__ as xla_version

if xla_version == "2.3.0":
logger.warning_once(
"Fine-tuning Gemma on Pytorch XLA 2.3.0 might raise some issues. In case of any "
"issues consider using the nightly version, and report the issue on the optimum-tpu "
"GitHub repository: https://github.com/huggingface/optimum-tpu/issues/new."
)
cls_to_wrap = "GemmaDecoderLayer"
matched_model = True
elif model_type == "llama":
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -42,9 +42,9 @@ keywords = [
]

dependencies = [
"transformers == 4.41.1",
"torch == 2.4.0",
"torch-xla[tpu] == 2.4.0",
"transformers == 4.46.3",
"torch == 2.5.1",
"torch-xla[tpu] == 2.5.1",
'typer == 0.6.1',
"loguru == 0.6.0",
"sentencepiece == 0.2.0",
@@ -63,7 +63,7 @@ quality = ["black", "ruff", "isort"]
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt",
"torch-xla[pallas] == 2.4.0"
"torch-xla[pallas] == 2.5.1"
]

[project.urls]
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
# This is not a complete list of dependencies, but it allows to install torch without CUDA support
--index-url https://download.pytorch.org/whl/cpu
torch==2.4.0
torch==2.5.1
17 changes: 9 additions & 8 deletions text-generation-inference/docker/Dockerfile
@@ -8,12 +8,12 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.80.1-bookworm AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner
FROM chef AS planner
COPY --from=tgi /tgi/Cargo.toml Cargo.toml
COPY --from=tgi /tgi/Cargo.lock Cargo.lock
COPY --from=tgi /tgi/rust-toolchain.toml rust-toolchain.toml
@@ -101,12 +101,12 @@ RUN apt-get update -y \
RUN pip install --upgrade pip

# Install HuggingFace packages
ARG TRANSFORMERS_VERSION='4.41.1'
ARG ACCELERATE_VERSION='0.27.2'
ARG SAFETENSORS_VERSION='0.4.2'
ARG TRANSFORMERS_VERSION='4.46.3'
ARG ACCELERATE_VERSION='1.1.1'
ARG SAFETENSORS_VERSION='0.4.5'

# TGI base env
ENV HUGGINGFACE_HUB_CACHE=/data \
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
PORT=80 \
VERSION=${VERSION}
@@ -134,7 +134,7 @@ RUN pip install dist/text_generation_server*.tar.gz


# TPU compatible image for Inference Endpoints
FROM tpu_base as inference-endpoint
FROM tpu_base AS inference-endpoint

COPY text-generation-inference/docker/entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh
@@ -145,4 +145,5 @@ ENTRYPOINT ["./entrypoint.sh"]
FROM tpu_base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
# This is commented out in the original TGI Dockerfile
# CMD ["--json-output"]
2 changes: 1 addition & 1 deletion text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@
pkg_name := text_generation_server
BUILDDIR ?= $(CURDIR)/build
VERSION ?= 0.0.1
TGI_VERSION ?= 0ff6ff60ada291840beed63d8bf458d6f9606f7f
TGI_VERSION ?= "v2.4.1"
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
pkg_dir := $(BUILDDIR)/$(pkg_name)
4 changes: 2 additions & 2 deletions text-generation-inference/server/pyproject.toml
@@ -14,8 +14,8 @@ dependencies = [
'grpcio-reflection == 1.62.1',
'grpc-interceptor == 0.15.2',
'typer == 0.6.1',
'safetensors == 0.4.2',
'transformers == 4.41.1',
'safetensors == 0.4.5',
'transformers == 4.46.3',
'loguru == 0.6.0',
"sentencepiece == 0.2.0",
"numpy<2.0",
@@ -16,23 +16,23 @@
from transformers.generation import GenerationConfig

import optimum.tpu.xla_logger as logger
from optimum.tpu import AutoModelForCausalLM
from optimum.tpu.generation import TokenSelector
from optimum.tpu.modeling import AutoModelForCausalLM
from optimum.tpu.static_cache_xla import StaticCacheXla
from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox

from .generator_base import Generator
from .pb.generate_pb2 import (
Batch,
CachedBatch,
FinishReason,
GeneratedText,
Generation,
InfoResponse,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
Tokens,
Batch,
CachedBatch,
FinishReason,
GeneratedText,
Generation,
InfoResponse,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
Tokens,
)


@@ -314,6 +314,9 @@ def __init__(
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
# The token selector will use the model's generation mixin internal variables to select the next token, and it
# expects special tokens to be initialized in the model.
model._prepare_special_tokens(generation_config=model.generation_config, device=model.device)
# Slots are empty to begin with, they will be populated as new batches arrive
self.slots = []
self.batch_id = 0
@@ -262,6 +262,10 @@ def __init__(
tokenizer.truncation_side = "left"
self.tokenizer = tokenizer
self.special_tokens = self.tokenizer.all_special_ids
# The token selector will use the model's generation mixin internal variables to select the next token, and it
# expects special tokens to be initialized in the model.
model = self.engine.pt_model
model._prepare_special_tokens(generation_config=model.generation_config, device='cpu')
# Slots number is static, it cannot grow over the size of the batch
self.slots = [Slot(i, tokenizer) for i in range(self.model.config.batch_size)]
self.batch_id = 0
@@ -4,6 +4,15 @@
from transformers import GemmaConfig, GenerationConfig, GenerationMixin


class GemmaConfigHf(GemmaConfig, gemma_config.GemmaConfig):
"""This class is used to support both the HF GemmaConfig and the Jetstream Pytorch GemmaConfig at the same time.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = None


class GemmaModelHf(GemmaModel, GenerationMixin):
"""Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device.
@@ -16,24 +25,8 @@ def __init__(
device,
env,
):
self.config = config
self.generation_config = GenerationConfig.from_model_config(config)

args = gemma_config.GemmaConfig(
vocab_size=config.vocab_size,
max_position_embeddings=config.max_position_embeddings,
num_hidden_layers=config.num_hidden_layers,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
head_dim=config.head_dim,
rms_norm_eps=config.rms_norm_eps,
dtype="bfloat16",
quant=False, # No quantization support for now
tokenizer=None,
)

args = GemmaConfigHf(**config.to_dict())
args.device = device
super().__init__(args, env)

@@ -1,33 +1,52 @@

from jetstream_pt.third_party.mixtral import config as mixtral_config
from jetstream_pt.third_party.mixtral.model import Transformer
from transformers import GenerationConfig, GenerationMixin, MixtralConfig


class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
"""This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__post_init__()

@property
def block_size(self):
return self.max_position_embeddings

@property
def n_layer(self):
return self.num_hidden_layers

@property
def n_head(self):
return self.num_attention_heads

@property
def dim(self):
return self.hidden_size

@property
def n_local_heads(self):
return self.num_local_experts or self.num_attention_heads

@property
def num_activated_experts(self):
return self.num_experts_per_tok


class MixtralModelHf(Transformer, GenerationMixin):
"""Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
"""
"""Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device."""

def __init__(
self,
config: MixtralConfig,
device,
env,
):
self.config = config
self.generation_config = GenerationConfig.from_model_config(config)

args = mixtral_config.ModelArgs(
block_size=config.max_position_embeddings,
vocab_size=config.vocab_size,
n_layer=config.num_hidden_layers,
n_head=config.num_attention_heads,
dim=config.hidden_size,
intermediate_size=config.intermediate_size,
n_local_heads=config.num_local_experts or config.num_attention_heads,
num_activated_experts=config.num_experts_per_tok,
device=device,
)
args = MixtralConfigHf(**config.to_dict())
args.device = device
super().__init__(args, env)

