From 1065a8eb5599093b0dbf9f700f8990ef28cd3e45 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:58:48 +0200 Subject: [PATCH] Test CausalLM generate & pipeline (#110) * add tests & standardize generate * fix bugs & pipeline * clean engine before test * cleanup * style * test pipelines * Quality * fix tests * style * run nvidia-smi * fix CI - can't compile for 32000 generated tokens on A10G * fewer new tokens * still fewer tokens or tests won't pass * lets test 7b gemma as well --------- Co-authored-by: Morgan Funtowicz --- .github/workflows/pr_functional_tests.yml | 4 + .github/workflows/pr_integration_tests.yml | 4 + examples/text-generation.py | 4 +- src/optimum/nvidia/hub.py | 2 +- src/optimum/nvidia/models/gemma.py | 1 + src/optimum/nvidia/models/mistral.py | 1 + src/optimum/nvidia/models/whisper.py | 18 +- .../nvidia/pipelines/text_generation.py | 23 +-- src/optimum/nvidia/runtime.py | 182 +++++++++++++----- tests/integration/test_causal_lm.py | 145 ++++++++++++++ tests/integration/test_whisper.py | 21 +- tests/integration/utils_testing.py | 41 ++++ 12 files changed, 354 insertions(+), 92 deletions(-) create mode 100644 tests/integration/test_causal_lm.py create mode 100644 tests/integration/utils_testing.py diff --git a/.github/workflows/pr_functional_tests.yml b/.github/workflows/pr_functional_tests.yml index 0043bbdf..5f1d0d6d 100644 --- a/.github/workflows/pr_functional_tests.yml +++ b/.github/workflows/pr_functional_tests.yml @@ -55,6 +55,10 @@ jobs: run: | python3 -m pip install --upgrade -e .[quality,tests] + - name: Run nvidia-smi + run: | + nvidia-smi + - name: Run optimum-nvidia functional test-suite run: | pytest -n 4 -s -vvvvv -p no:warnings -o log_cli=true --ignore=tests/integration/ tests/ diff --git a/.github/workflows/pr_integration_tests.yml b/.github/workflows/pr_integration_tests.yml index b2d79996..ffc08ac9 100644 --- a/.github/workflows/pr_integration_tests.yml +++ b/.github/workflows/pr_integration_tests.yml @@ -56,6 +56,10 @@ jobs: run: | python3 -m pip install --upgrade -e .[quality,tests] + - name: Run nvidia-smi + run: | + nvidia-smi + - name: Run optimum-nvidia integration test-suite run: | pytest -s -vvvvv -n 1 -p no:warnings -o log_cli=true tests/integration/ \ No newline at end of file diff --git a/examples/text-generation.py b/examples/text-generation.py index 4379a8a4..6b837837 100644 --- a/examples/text-generation.py +++ b/examples/text-generation.py @@ -79,7 +79,5 @@ max_new_tokens=args.max_new_tokens, ) - generated_text = tokenizer.batch_decode( - generated.flatten(0, 1), skip_special_tokens=True - ) + generated_text = tokenizer.batch_decode(generated, skip_special_tokens=True) print(generated_text) diff --git a/src/optimum/nvidia/hub.py b/src/optimum/nvidia/hub.py index b4f9717a..252ef066 100644 --- a/src/optimum/nvidia/hub.py +++ b/src/optimum/nvidia/hub.py @@ -403,7 +403,7 @@ def _save_pretrained(self, save_directory: Path) -> None: "Please open-up an issue at https://github.com/huggingface/optimum-nvidia" ) - self.transformers_config.save_pretrained(save_directory) + self.config.save_pretrained(save_directory) if self.generation_config is not None: self.generation_config.save_pretrained(save_directory) diff --git a/src/optimum/nvidia/models/gemma.py b/src/optimum/nvidia/models/gemma.py index dd3ea867..8f106117 100644 --- a/src/optimum/nvidia/models/gemma.py +++ b/src/optimum/nvidia/models/gemma.py @@ -75,6 +75,7 @@ def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig": 
share_embedding_table=False, max_lora_rank=64, quantization=qconfig, + rotary_base=config.rope_theta, ) trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) diff --git a/src/optimum/nvidia/models/mistral.py b/src/optimum/nvidia/models/mistral.py index 2cdbab85..fc54eb32 100644 --- a/src/optimum/nvidia/models/mistral.py +++ b/src/optimum/nvidia/models/mistral.py @@ -74,6 +74,7 @@ def from_config(config: TransformersPretrainedConfig) -> "TensorRTConfig": max_lora_rank=64, head_size=config.hidden_size / config.num_attention_heads, quantization=qconfig, + rotary_base=config.rope_theta, ) trt_config.mapping.gpus_per_node = min(trt_config.mapping.world_size, 8) diff --git a/src/optimum/nvidia/models/whisper.py b/src/optimum/nvidia/models/whisper.py index aa179f74..9219de82 100644 --- a/src/optimum/nvidia/models/whisper.py +++ b/src/optimum/nvidia/models/whisper.py @@ -689,7 +689,7 @@ def __init__( generation_config = GenerationConfig() self.generation_config = generation_config - self.transformers_config = transformers_config + self.config = transformers_config # Encoder. serialize_path = engines_folders[0] / "rank0.engine" @@ -1078,7 +1078,7 @@ def generate( def raise_unsupported(value: Any, name: str, default: Any = None): if value != default: raise ValueError( - f"TensorRTForSpeechSeq2Seq.generate does not support {name} (got {value}). Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." + f"TensorRTForSpeechSeq2Seq.generate does not support the argument {name} (got {name}={value}). Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." ) raise_unsupported(stopping_criteria, name="stopping_criteria") @@ -1109,7 +1109,7 @@ def raise_unsupported(value: Any, name: str, default: Any = None): ) self._set_token_ids( generation_config=generation_config, - config=self.transformers_config, + config=self.config, kwargs=kwargs, ) self._set_thresholds_and_condition( @@ -1126,9 +1126,7 @@ def raise_unsupported(value: Any, name: str, default: Any = None): batch_size, total_input_frames = self._retrieve_total_input_frames( input_features=inputs, input_stride=input_stride, kwargs=kwargs ) - num_segment_frames = ( - input_stride * self.transformers_config.max_source_positions - ) + num_segment_frames = input_stride * self.config.max_source_positions is_shortform = total_input_frames <= num_segment_frames if not is_shortform: raise ValueError( @@ -1138,7 +1136,7 @@ def raise_unsupported(value: Any, name: str, default: Any = None): init_tokens = self._retrieve_init_tokens( inputs, generation_config=generation_config, - config=self.transformers_config, + config=self.config, num_segment_frames=num_segment_frames, kwargs=kwargs, ) @@ -1171,16 +1169,16 @@ def raise_unsupported(value: Any, name: str, default: Any = None): if ( max_new_tokens + decoder_input_ids.shape[-1] - > self.transformers_config.max_target_positions + > self.config.max_target_positions ): max_new_tokens = kwargs.get("max_new_tokens", 0) raise ValueError( f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " f"is {max_new_tokens}. Thus, the combined length of " f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.transformers_config.max_target_positions}. " + f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. 
" "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less than {self.transformers_config.max_target_positions}." + f"so that their combined length is less than {self.config.max_target_positions}." ) encoder_input_lengths = torch.tensor( diff --git a/src/optimum/nvidia/pipelines/text_generation.py b/src/optimum/nvidia/pipelines/text_generation.py index 6a7ed558..c9e6e821 100644 --- a/src/optimum/nvidia/pipelines/text_generation.py +++ b/src/optimum/nvidia/pipelines/text_generation.py @@ -38,9 +38,6 @@ class TextGenerationPipeline(Pipeline): __slots__ = ( "tokenizer", "_runtime", - "_bos_token_id", - "_eos_token_id", - "_pad_token_id", ) def __init__(self, model: CausalLM, tokenizer: PreTrainedTokenizer): @@ -52,16 +49,15 @@ def __init__(self, model: CausalLM, tokenizer: PreTrainedTokenizer): self.tokenizer = tokenizer self._runtime = model - self._bos_token_id = tokenizer.bos_token_id - self._eos_token_id = tokenizer.eos_token_id - self._pad_token_id = tokenizer.pad_token_id - - def __call__(self, inputs: Union[str, List[str]], **kwargs): + def __call__( + self, inputs: Union[str, List[str]], add_special_tokens: bool = True, **kwargs + ): ( preprocess_params, forward_params, postprocess_params, - ) = self._sanitize_parameters(**kwargs) + ) = self._sanitize_parameters(add_special_tokens=add_special_tokens, **kwargs) + model_inputs = self.preprocess(inputs, **preprocess_params) model_outputs = self._forward(model_inputs, **forward_params) outputs = self.postprocess(model_outputs, **postprocess_params) @@ -147,7 +143,7 @@ def _forward(self, model_inputs, **generate_kwargs): prompt_text = model_inputs.pop("prompt_text") attention_mask = model_inputs.get("attention_mask", None) - max_new_tokens = generate_kwargs.pop("max_new_tokens", -1) + max_new_tokens = generate_kwargs.pop("max_new_tokens", None) min_length = generate_kwargs.pop("min_length", -1) num_beams = generate_kwargs.pop("num_beams", 1) temperature = generate_kwargs.pop("temperature", 1.0) @@ -188,9 +184,6 @@ def _forward(self, model_inputs, **generate_kwargs): repetition_penalty=repetition_penalty, length_penalty=length_penalty, seed=seed, - bos_token_id=self._bos_token_id, - eos_token_id=self._eos_token_id, - pad_token_id=self._pad_token_id, ) return { @@ -243,13 +236,13 @@ def postprocess( for sequence in generated_sequence: # Decode text - beam_text = self.tokenizer.batch_decode( + text = self.tokenizer.decode( sequence, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces, ) - record = {"generated_text": beam_text} + record = {"generated_text": text} records.append(record) return records diff --git a/src/optimum/nvidia/runtime.py b/src/optimum/nvidia/runtime.py index e908f144..5e1aa14a 100644 --- a/src/optimum/nvidia/runtime.py +++ b/src/optimum/nvidia/runtime.py @@ -13,19 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy
 import warnings
 from logging import getLogger
 from os import PathLike
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import tensorrt_llm.bindings as ctrrt
 import torch
 from transformers import GenerationConfig
+from transformers.generation.utils import GenerationMixin
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig
+    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers.generation.logits_process import LogitsProcessorList
+    from transformers.generation.stopping_criteria import StoppingCriteriaList
+    from transformers.generation.streamers import BaseStreamer
 
 
 LOGGER = getLogger(__name__)
@@ -53,7 +58,9 @@ def engine_path(self) -> Path:
         return self._engines_folders_path
 
 
-class CausalLM(CompiledModel):
+class CausalLM(CompiledModel, GenerationMixin):
+    main_input_name = "input_ids"
+
     __slots__ = (
         "_device",
         "_config",
@@ -95,16 +102,28 @@ def __init__(
             max_beam_width=self._config.model_config.max_beam_width,
             max_sequence_length=self._config.model_config.max_seq_len,
         )
+        self._session_config.cuda_graph_mode = use_cuda_graph
 
         # Create the engine
         engine_file = self._config.engine_filename(self._mapping)
-        self._session = ctrrt.GptSession(
-            config=self._session_config,
-            model_config=self._config.model_config,
-            world_config=self._mapping,
-            engine_file=str(engines_folder.joinpath(engine_file)),
-        )
+
+        try:
+            self._session = ctrrt.GptSession(
+                config=self._session_config,
+                model_config=self._config.model_config,
+                world_config=self._mapping,
+                engine_file=str(engines_folder.joinpath(engine_file)),
+            )
+        except RuntimeError as e:
+            if "maxTokensInPagedKvCache" in repr(
+                e
+            ) and "must be large enough to process at least 1 sequence" in repr(e):
+                raise RuntimeError(
+                    f"Could not initialize TensorRT-LLM decoder session, likely due to a large maximum output length set at compilation time (max_output_len={self._config.model_config.max_seq_len}). Please try setting a lower value for `max_output_length` when building the engine. Error: {e}"
+                )
+            else:
+                raise e
 
         # Additional cached properties
         self._use_packed_inputs = self._config.model_config.use_packed_input
@@ -118,61 +137,131 @@ def __init__(
             generation_config = GenerationConfig()
         self.generation_config = generation_config
 
-        self.transformers_config = transformers_config
-
-    @property
-    def config(self) -> ctrrt.GptJsonConfig:
-        return self._config
+        # Required for GenerationMixin compatibility.
+ self.config = transformers_config + @torch.no_grad() def generate( self, - input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - max_new_tokens: int = -1, - min_length: int = -1, - num_beams: int = 1, - temperature: float = 1.0, - top_k: int = 50, - top_p: float = 1.0, - repetition_penalty: float = 1.0, - length_penalty: float = 1.0, - seed: int = 0, - pad_token_id: int = 0, - bos_token_id: int = 1, - eos_token_id: int = 2, - ) -> Tuple[torch.Tensor, torch.Tensor]: + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional["LogitsProcessorList"] = None, + stopping_criteria: Optional["StoppingCriteriaList"] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.LongTensor: + def raise_unsupported(value: Any, name: str, default: Any = None): + if value != default: + raise ValueError( + f"{self.__class__.__name__}.generate does not support the argument {name} (got {name}={value}). Please open an issue at https://github.com/huggingface/optimum-nvidia/issues." + ) + + raise_unsupported(stopping_criteria, name="stopping_criteria") + raise_unsupported(prefix_allowed_tokens_fn, name="prefix_allowed_tokens_fn") + raise_unsupported(synced_gpus, name="synced_gpus") + raise_unsupported(logits_processor, name="logits_processor") + raise_unsupported(assistant_model, name="assistant_model") + raise_unsupported(streamer, name="streamer") + raise_unsupported(negative_prompt_ids, name="negative_prompt_ids") + raise_unsupported( + negative_prompt_attention_mask, name="negative_prompt_attention_mask" + ) + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, + # three conditions must be met + # 1) the generation config must have been created from the model config (`_from_model_config` field); + # 2) the generation config must have seen no modification since its creation (the hash is the same); + # 3) the user must have set generation parameters in the model config. + if ( + self.generation_config._from_model_config + and self.generation_config._original_object_hash + == hash(self.generation_config) + and self.config._has_non_default_generation_parameters() + ): + new_generation_config = GenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." 
+ " Please use and modify the model generation configuration (see" + " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update( + **kwargs + ) # All unused kwargs must be model kwargs + + if ( + generation_config.pad_token_id is None + and generation_config.eos_token_id is not None + ): + if model_kwargs.get("attention_mask", None) is None: + LOGGER.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + LOGGER.warning( + f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation." + ) + generation_config.pad_token_id = eos_token_id + device = self._device + seed = model_kwargs.pop("seed", 42) # If no GenerationConfig is provided, let's allocate one with default settings - generation_config = ctrrt.SamplingConfig(min(num_beams, self.max_beam_width)) - generation_config.random_seed = [seed] - generation_config.temperature = [temperature] - generation_config.top_k = [top_k] - generation_config.top_p = [top_p] - generation_config.repetition_penalty = [repetition_penalty] - generation_config.length_penalty = [length_penalty] - - if min_length > 0: - generation_config.min_length = [min_length] + sampling_config = ctrrt.SamplingConfig( + min(generation_config.num_beams, self.max_beam_width) + ) + sampling_config.random_seed = [seed] + sampling_config.temperature = [generation_config.temperature] + sampling_config.top_k = [generation_config.top_k] + sampling_config.top_p = [generation_config.top_p] + sampling_config.repetition_penalty = [generation_config.repetition_penalty] + sampling_config.length_penalty = [generation_config.length_penalty] + + if generation_config.min_new_tokens is not None: + sampling_config.min_length = [generation_config.min_new_tokens] + + input_ids, _, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) with torch.no_grad(): if not isinstance(input_ids, torch.Tensor): raise TypeError("input_ids should be a PyTorch tensor (torch.Tensor)") + attention_mask = model_kwargs["attention_mask"] input_ids, lengths = self._prepare_inputs(input_ids, attention_mask) if torch.any(torch.gt(lengths, self.max_prompt_length)): raise ValueError( f"Input length {lengths} is bigger than maximum prompt length ({self.max_prompt_length})." 
) + input_length = input_ids.shape[1] trt_inputs = ctrrt.GenerationInput( - end_id=eos_token_id, - pad_id=pad_token_id, + end_id=generation_config.eos_token_id, + pad_id=generation_config.pad_token_id, ids=input_ids.to(device), lengths=lengths.to(device), packed=self._use_packed_inputs, ) + max_new_tokens = generation_config.max_new_tokens if max_new_tokens is None or max_new_tokens < 1: max_new_tokens = self.max_output_length - input_ids.shape[1] @@ -184,9 +273,16 @@ def generate( lengths=torch.empty(0, device=device, dtype=torch.int32), ) - self._session.generate(trt_outputs, trt_inputs, generation_config) + self._session.generate(trt_outputs, trt_inputs, sampling_config) + + total_length = trt_outputs.lengths + output_ids = trt_outputs.ids.flatten(0, 1) + + # For some reason not in line with Transformers in case we finish early with BOS token (missing last BOS token). + if total_length - input_length < max_new_tokens: + total_length += 1 - return trt_outputs.ids, trt_outputs.lengths + return output_ids[:, :total_length], total_length def _prepare_inputs( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None diff --git a/tests/integration/test_causal_lm.py b/tests/integration/test_causal_lm.py new file mode 100644 index 00000000..4a4d9cf6 --- /dev/null +++ b/tests/integration/test_causal_lm.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import pytest +import torch +from transformers import AutoModelForCausalLM as TransformersAutoModelForCausalLM +from transformers import AutoTokenizer +from transformers import pipeline as transformers_pipeline +from utils_testing import clean_cached_engines_for_model + +from optimum.nvidia import AutoModelForCausalLM +from optimum.nvidia.pipelines import pipeline + + +MODEL_MAP = { + "gemma": ["google/gemma-2b-it", "google/gemma-7b-it"], + "llama": "meta-llama/Llama-2-7b-chat-hf", + "mistral": "mistralai/Mistral-7B-Instruct-v0.2", +} + + +@pytest.mark.parametrize("model_type", MODEL_MAP.keys()) +def test_generation(model_type: str): + model_ids = ( + [MODEL_MAP[model_type]] + if isinstance(MODEL_MAP[model_type], str) + else MODEL_MAP[model_type] + ) + + torch_dtype = torch.float16 # TODO: test fp8, int4, int8, fp32 + + # TODO: test batched generation as well. + # TODO: This is flaky depending on the prompt for Mistral / Gemma, maybe see if it is a bug or not. + prompts = ["Today I am in Paris and I would like to eat crepes."] + + max_new_tokens = 15 + + for model_id in model_ids: + # Make sure we remove the potentially already built engines. 
+ clean_cached_engines_for_model(model_id) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + inp = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda") + + torch_model = TransformersAutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch_dtype, attn_implementation="eager" + ) + torch_model = torch_model.eval() + torch_model = torch_model.to("cuda") # TODO: remove? + + kwargs = { + "top_k": 1, + "top_p": 0, + "length_penalty": 1, + "repetition_penalty": 1, + "temperature": 1, + } + torch_generated_ids = torch_model.generate( + **inp, num_beams=1, do_sample=False, max_new_tokens=max_new_tokens, **kwargs + ) + + # Free a bit of memory. + del torch_model + gc.collect() + torch.cuda.empty_cache() + + trt_model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch_dtype, max_output_length=1000 + ) + + trt_generated_ids, _ = trt_model.generate( + **inp, num_beams=1, do_sample=False, max_new_tokens=max_new_tokens, **kwargs + ) + + assert torch.equal(trt_generated_ids, torch_generated_ids) + + +@pytest.mark.parametrize("model_type", MODEL_MAP.keys()) +def test_pipeline(model_type: str): + model_ids = ( + [MODEL_MAP[model_type]] + if isinstance(MODEL_MAP[model_type], str) + else MODEL_MAP[model_type] + ) + + kwargs = { + "top_k": 1, + "top_p": 0, + "length_penalty": 1, + "repetition_penalty": 1, + "temperature": 1, + } + + for model_id in model_ids: + # Make sure we remove the potentially already built engines. + clean_cached_engines_for_model(model_id) + + pipe_torch = transformers_pipeline( + task="text-generation", + model=model_id, + device="cuda", + torch_dtype=torch.float16, + ) + + with torch.no_grad(): + res_torch = pipe_torch( + "Today I am in Paris and I would like to eat crepes.", + add_special_tokens=True, + max_new_tokens=20, + **kwargs, + ) + + # Free a bit of memory. + del pipe_torch + gc.collect() + torch.cuda.empty_cache() + + pipe_trt = pipeline( + task="text-generation", model=model_id, max_output_length=1000 + ) + + with torch.no_grad(): + res_trt = pipe_trt( + "Today I am in Paris and I would like to eat crepes.", + max_new_tokens=20, + **kwargs, + ) + + assert res_torch[0]["generated_text"] == res_trt[0]["generated_text"] diff --git a/tests/integration/test_whisper.py b/tests/integration/test_whisper.py index a98166e3..bbd016c1 100644 --- a/tests/integration/test_whisper.py +++ b/tests/integration/test_whisper.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import shutil import tempfile from glob import glob from pathlib import Path from typing import Optional import datasets -import huggingface_hub import pytest import torch from transformers import AutoProcessor from transformers import ( WhisperForConditionalGeneration as TransformersWhisperForConditionalGeneration, ) +from utils_testing import clean_cached_engines_for_model from optimum.nvidia.models.whisper import WhisperForConditionalGeneration @@ -38,24 +37,6 @@ ] -def clean_cached_engines_for_model(model_id: str): - cache_dir = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE - object_id = model_id.replace("/", "--") - full_model_path = Path(cache_dir, f"models--{object_id}") - if full_model_path.is_dir(): - # Resolve refs (for instance to convert main to the associated commit sha) - revision_file = Path(full_model_path, "refs", "main") - revision = "" - if revision_file.is_file(): - with open(revision_file) as f: - revision = f.read() - cached_path = Path(full_model_path, "snapshots", revision) - - for path in [cached_path / "encoder", cached_path / "decoder"]: - if path.exists() and path.is_dir(): - shutil.rmtree(path) - - @pytest.mark.parametrize("model_id", TEST_MODELS) def test_whisper(model_id: str): # Make sure we remove the potentially already built engines. diff --git a/tests/integration/utils_testing.py b/tests/integration/utils_testing.py new file mode 100644 index 00000000..ab0934ab --- /dev/null +++ b/tests/integration/utils_testing.py @@ -0,0 +1,41 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# http://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +from pathlib import Path + +import huggingface_hub + + +def clean_cached_engines_for_model(model_id: str): + cache_dir = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + object_id = model_id.replace("/", "--") + full_model_path = Path(cache_dir, f"models--{object_id}") + if full_model_path.is_dir(): + # Resolve refs (for instance to convert main to the associated commit sha) + revision_file = Path(full_model_path, "refs", "main") + revision = "" + if revision_file.is_file(): + with open(revision_file) as f: + revision = f.read() + cached_path = Path(full_model_path, "snapshots", revision) + + for path in [ + cached_path / "encoder", + cached_path / "decoder", + cached_path / "engines", + ]: + if path.exists() and path.is_dir(): + shutil.rmtree(path)
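Usage sketch for the interface standardized by this patch: CausalLM.generate now follows the transformers GenerationMixin signature and returns the generated ids together with their lengths, and the text-generation pipeline mirrors the transformers pipeline call. This is a minimal example, not part of the patch itself; the model id, prompt, and length values simply mirror tests/integration/test_causal_lm.py and examples/text-generation.py and are not prescriptive.

from transformers import AutoTokenizer

from optimum.nvidia import AutoModelForCausalLM
from optimum.nvidia.pipelines import pipeline

model_id = "google/gemma-2b-it"  # any entry of MODEL_MAP in the new tests

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Build (or reuse a cached) TensorRT-LLM engine behind the usual from_pretrained API.
model = AutoModelForCausalLM.from_pretrained(model_id, max_output_length=1000)

inputs = tokenizer(
    ["Today I am in Paris and I would like to eat crepes."],
    padding=True,
    return_tensors="pt",
).to("cuda")

# generate() returns the generated ids and their lengths instead of beam-shaped ids.
output_ids, lengths = model.generate(
    **inputs, num_beams=1, do_sample=False, max_new_tokens=20
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

# The pipeline returns [{"generated_text": ...}] records, as in transformers.
pipe = pipeline(task="text-generation", model=model_id, max_output_length=1000)
print(pipe("Today I am in Paris and I would like to eat crepes.", max_new_tokens=20))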