From 6186dac9edfd9fdbdce23499c2c8ad267e83a597 Mon Sep 17 00:00:00 2001
From: Qing Lan
Date: Thu, 5 Oct 2023 16:20:35 -0700
Subject: [PATCH] add flash2 support for huggingface accelerate (#1111)

---
 .../python/setup/djl_python/huggingface.py    | 30 ++++++++++++++-----
 .../rolling_batch/scheduler_rolling_batch.py  |  8 +++++
 serving/docker/deepspeed.Dockerfile           |  2 +-
 serving/docker/fastertransformer.Dockerfile   |  5 ++--
 serving/docker/pytorch-inf2.Dockerfile        |  2 +-
 tests/integration/instant_benchmark.py        | 11 +++++--
 6 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 1fc8af589..49b927f8d 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -54,6 +54,11 @@
     "LlamaForCausalLM"
 }
 
+# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
+FLASH_2_SUPPORTED_MODELS = {
+    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
+}
+
 PEFT_MODEL_TASK_TO_CLS = {
     "SEQ_CLS": AutoModelForSequenceClassification,
     "SEQ_2_SEQ_LM": AutoModelForSeq2SeqLM,
@@ -79,6 +84,14 @@ def get_torch_dtype_from_str(dtype: str):
     raise ValueError(f"Invalid data type: {dtype}")
 
 
+def enable_flash():
+    if torch.cuda.is_available():
+        major, _ = torch.cuda.get_device_capability()
+        if major >= 8:
+            return True
+    return False
+
+
 def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
                                      model_config):
     if rolling_batch_type == "auto":
@@ -106,17 +119,15 @@ def __init__(self, tokenizer, stop_seq):
         self.tokenizer = tokenizer
         self.stop_seq = stop_seq
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
         decoded_input_ids = self.tokenizer.decode(
             input_ids[0][-len(self.stop_seq):])
 
         matches = re.search(self.stop_seq, decoded_input_ids)
 
-        if(matches is not None):
+        if matches is not None:
             return True
-        else:
-            return False
-        return True
+        return False
 
 
 class HuggingFaceService(object):
@@ -135,6 +146,7 @@ def __init__(self):
         self.model_config = None
         self.peft_config = None
         self.stopping_criteria_list = None
+        self.disable_flash_attn = None
 
     def initialize(self, properties: dict):
         # model_id can point to huggingface model_id or local directory.
@@ -189,6 +201,8 @@ def initialize(self, properties: dict):
                 properties.get("dtype"))
         if "revision" in properties:
             kwargs["revision"] = properties.get('revision')
+        self.disable_flash_attn = properties.get(
+            "disable_flash_attn", "false").lower() == 'true'
         self.rolling_batch_type = properties.get("rolling_batch", None)
 
         self._read_model_config(model_id_or_path,
@@ -222,7 +236,7 @@ def initialize(self, properties: dict):
                 model_id_or_path=model_id_or_path,
                 kwargs=kwargs)
 
-        if("stop_sequence" in properties):
+        if "stop_sequence" in properties:
             self.load_stopping_criteria_list(properties["stop_sequence"])
         self.initialized = True
 
@@ -245,7 +259,7 @@ def load_stopping_criteria_list(self, stop_sequence):
         Input: (str) stop_sequence - currently just one stop sequence supported
         Output: none (loads into member variable)
         """
-        if(self.tokenizer is None):
+        if self.tokenizer is None:
             return
 
         stop_seq_list = self.parse_stop_sequence_input(stop_sequence)
@@ -490,6 +504,8 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
             model_cls = AutoModelForSeq2SeqLM
         else:
             model_cls = AutoModelForCausalLM
+        if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash() and not self.disable_flash_attn:
+            kwargs['use_flash_attention_2'] = True
 
         if self.peft_config is not None:
             base_model = model_cls.from_pretrained(
diff --git a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
index e242a1774..a24eed0a0 100644
--- a/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/scheduler_rolling_batch.py
@@ -14,6 +14,7 @@
 from seq_scheduler.lm_block import HuggingfaceBlock, BloomBlock, FalconBlock
 from seq_scheduler.search_config import SearchConfig
 from seq_scheduler.seq_batch_scheduler import SeqBatchScheduler
+from djl_python.huggingface import FLASH_2_SUPPORTED_MODELS, enable_flash
 from collections import namedtuple, defaultdict
 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
@@ -39,6 +40,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         super().__init__(device, **kwargs)
         self._init_model_and_tokenizer(model_id_or_path,
                                        device=device,
+                                       properties=properties,
                                        multi_gpu=properties.get(
                                            'multi_gpu', None),
                                        **kwargs)
@@ -96,6 +98,7 @@ def _init_model_and_tokenizer(self,
                                   model_id_or_path,
                                   device=None,
                                   multi_gpu=None,
+                                  properties=None,
                                   **kwargs):
         if "waiting_steps" in kwargs:
             kwargs.pop("waiting_steps")
@@ -120,6 +123,11 @@ def _init_model_and_tokenizer(self,
         if 'device_map' in kwargs:
             device_map = kwargs.pop('device_map')
 
+        if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash():
+            if properties.get(
+                    "disable_flash_attn", "false").lower() != 'true':
+                kwargs['use_flash_attention_2'] = True
+
         if "lmi_dist_sharding" == multi_gpu:
             if 'neox' in model_id_or_path:
                 try:
diff --git a/serving/docker/deepspeed.Dockerfile b/serving/docker/deepspeed.Dockerfile
index cab7f2e6b..cfec3b778 100644
--- a/serving/docker/deepspeed.Dockerfile
+++ b/serving/docker/deepspeed.Dockerfile
@@ -25,7 +25,7 @@ ARG lmi_dist_wheel="https://publish.djl.ai/lmi_dist/lmi_dist-nightly-py3-none-an
 ARG seq_scheduler_wheel="https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl"
 ARG peft_wheel="https://publish.djl.ai/peft/peft-0.5.0alpha-py3-none-any.whl"
 ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
 ARG accelerate_version=0.23.0
 ARG diffusers_version=0.16.0
 ARG bitsandbytes_version=0.41.1
diff --git a/serving/docker/fastertransformer.Dockerfile b/serving/docker/fastertransformer.Dockerfile
index 276c765b6..0ae63e282 100644
--- a/serving/docker/fastertransformer.Dockerfile
+++ b/serving/docker/fastertransformer.Dockerfile
@@ -20,9 +20,10 @@ ARG ft_wheel="https://publish.djl.ai/fastertransformer/fastertransformer-0.24.0-
 ARG tb_wheel="https://publish.djl.ai/tritonserver/r23.04/tritontoolkit-23.4-py3-none-any.whl"
 ARG peft_wheel="https://publish.djl.ai/peft/peft-0.5.0alpha-py3-none-any.whl"
 ARG seq_scheduler_wheel="https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl"
+ARG flash_attn_2_wheel="https://publish.djl.ai/flash_attn/flash_attn_2-2.0.1-cp39-cp39-linux_x86_64.whl"
 ARG ompi_version=4.1.4
 ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
 ARG accelerate_version=0.23.0
 ARG bitsandbytes_version=0.41.1
 ARG optimum_version=1.13.2
@@ -69,7 +70,7 @@ RUN apt-get update && apt-get install -y wget git libnuma-dev zlib1g-dev rapidjs
     pip3 install ${torch_wheel} ${ft_wheel} ${tb_wheel} ${peft_wheel} ${seq_scheduler_wheel} safetensors protobuf==${protobuf_version} && \
     pip3 install transformers==${transformers_version} accelerate==${accelerate_version} \
     bitsandbytes==${bitsandbytes_version} optimum==${optimum_version} auto-gptq==${auto_gptq_version} \
-    scipy einops && \
+    scipy einops ${flash_attn_2_wheel} && \
     pip3 install cmake sentencepiece bfloat16 tiktoken && \
     pip3 cache purge && \
     apt-get clean -y && rm -rf /var/lib/apt/lists/* && \
diff --git a/serving/docker/pytorch-inf2.Dockerfile b/serving/docker/pytorch-inf2.Dockerfile
index b26728583..f0b87818a 100644
--- a/serving/docker/pytorch-inf2.Dockerfile
+++ b/serving/docker/pytorch-inf2.Dockerfile
@@ -18,7 +18,7 @@ ARG transformers_neuronx_version=0.7.84
 ARG neuronx_distributed_version=0.4.0
 ARG neuronx_cc_version=2.10.*
 ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
 ARG accelerate_version=0.23.0
 ARG diffusers_version=0.16.0
 EXPOSE 8080
diff --git a/tests/integration/instant_benchmark.py b/tests/integration/instant_benchmark.py
index 79e0d31bd..6bcc5e984 100644
--- a/tests/integration/instant_benchmark.py
+++ b/tests/integration/instant_benchmark.py
@@ -114,12 +114,16 @@ def build_running_script(template, job, instance, container):
     with open(template) as f:
         template = json.load(f)
     job_template = template[job]
-    job_template['awscurl'] = bytes.fromhex(job_template['awscurl']).decode("utf-8")
+    job_template['awscurl'] = bytes.fromhex(
+        job_template['awscurl']).decode("utf-8")
     write_model_artifacts(job_template['properties'],
                           job_template['requirements'])
 
     command_str = f"./launch_container.sh {container} $PWD/models {machine_translation(instance)}"
-    bash_command = ['echo "Start Launching container..."', command_str, job_template['awscurl']]
+    bash_command = [
+        'echo "Start Launching container..."', command_str,
+        job_template['awscurl']
+    ]
     with open("instant_benchmark.sh", "w") as f:
         f.write('\n'.join(bash_command))
 
@@ -133,7 +137,8 @@ def build_running_script(template, job, instance, container):
         command = f"echo \"template={json.dumps(json.dumps(json.dumps(result)))}\" >> $GITHUB_OUTPUT"
         sp.call(command, shell=True)
     elif args.template and args.job and args.instance and args.container:
-        build_running_script(args.template, args.job, args.instance, args.container)
+        build_running_script(args.template, args.job, args.instance,
+                             args.container)
     else:
         parser.print_help()
         raise ValueError("args not supported")
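
For reference, the gating this patch adds around Hugging Face's flash-attention-2 path can be exercised outside of DJL Serving with stock torch and transformers (>= 4.34), plus a flash-attn 2 wheel. The sketch below mirrors the same three checks from the patch: architecture allow-list, compute capability >= 8 (Ampere or newer), and an explicit opt-out flag. It is only an illustration; the model id and the DISABLE_FLASH_ATTN environment variable are placeholders, not part of this patch.

# Standalone sketch of the flash-attention-2 gating introduced above.
# Assumptions: transformers>=4.34, a CUDA build of torch, and flash-attn 2
# installed; the model id and DISABLE_FLASH_ATTN env var are illustrative only.
import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Same allow-list as the patch: architectures known to work with flash attention 2.
FLASH_2_SUPPORTED_MODELS = {
    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
}


def enable_flash():
    # Flash attention 2 kernels require compute capability >= 8 (Ampere or newer).
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
    return False


def load_model(model_id: str, disable_flash_attn: bool = False):
    kwargs = {"torch_dtype": torch.float16}
    config = AutoConfig.from_pretrained(model_id)
    architectures = config.architectures or []
    if (architectures and architectures[0] in FLASH_2_SUPPORTED_MODELS
            and enable_flash() and not disable_flash_attn):
        # transformers 4.34 accepts this kwarg in from_pretrained.
        kwargs["use_flash_attention_2"] = True
    return AutoModelForCausalLM.from_pretrained(model_id, **kwargs)


if __name__ == "__main__":
    disable = os.environ.get("DISABLE_FLASH_ATTN", "false").lower() == "true"
    model = load_model("meta-llama/Llama-2-7b-hf", disable_flash_attn=disable)
    print(type(model).__name__)

Within DJL Serving itself the opt-out is read from the model properties as disable_flash_attn, so serving a supported model with something like option.disable_flash_attn=true in serving.properties should fall back to the regular attention path; the exact property plumbing can vary by frontend configuration, so treat that key placement as configuration-dependent.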