add flash2 support for huggingface accelerate (#1111)
Qing Lan authored Oct 5, 2023
1 parent b3b58a0 commit 6186dac
Showing 6 changed files with 44 additions and 14 deletions.
30 changes: 23 additions & 7 deletions engines/python/setup/djl_python/huggingface.py
@@ -54,6 +54,11 @@
"LlamaForCausalLM"
}

# https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#efficient-inference-on-a-single-gpu
FLASH_2_SUPPORTED_MODELS = {
"LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
}

PEFT_MODEL_TASK_TO_CLS = {
"SEQ_CLS": AutoModelForSequenceClassification,
"SEQ_2_SEQ_LM": AutoModelForSeq2SeqLM,
@@ -79,6 +84,14 @@ def get_torch_dtype_from_str(dtype: str):
raise ValueError(f"Invalid data type: {dtype}")


def enable_flash():
if torch.cuda.is_available():
major, _ = torch.cuda.get_device_capability()
if major >= 8:
return True
return False


def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
model_config):
if rolling_batch_type == "auto":
@@ -106,17 +119,15 @@ def __init__(self, tokenizer, stop_seq):
self.tokenizer = tokenizer
self.stop_seq = stop_seq

-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
         decoded_input_ids = self.tokenizer.decode(input_ids[0][-len(self.stop_seq):])

         matches = re.search(self.stop_seq, decoded_input_ids)

-        if(matches is not None):
-            return True
-        else:
-            return False
+        if matches is not None:
+            return True
+
+        return False

class HuggingFaceService(object):

@@ -135,6 +146,7 @@ def __init__(self):
self.model_config = None
self.peft_config = None
self.stopping_criteria_list = None
self.disable_flash_attn = None

def initialize(self, properties: dict):
# model_id can point to huggingface model_id or local directory.
@@ -189,6 +201,8 @@ def initialize(self, properties: dict):
properties.get("dtype"))
if "revision" in properties:
kwargs["revision"] = properties.get('revision')
self.disable_flash_attn = properties.get(
"disable_flash_attn", "false").lower() == 'true'
self.rolling_batch_type = properties.get("rolling_batch", None)

self._read_model_config(model_id_or_path,
@@ -222,7 +236,7 @@ def initialize(self, properties: dict):
model_id_or_path=model_id_or_path,
kwargs=kwargs)

if("stop_sequence" in properties):
if "stop_sequence" in properties:
self.load_stopping_criteria_list(properties["stop_sequence"])
self.initialized = True

@@ -245,7 +259,7 @@ def load_stopping_criteria_list(self, stop_sequence):
Input: (str) stop_sequence - currently just one stop sequence supported
Output: none (loads into member variable)
"""
-        if(self.tokenizer is None):
+        if self.tokenizer is None:
return

stop_seq_list = self.parse_stop_sequence_input(stop_sequence)
@@ -490,6 +504,8 @@ def _init_model_and_tokenizer(self, model_id_or_path: str, **kwargs):
model_cls = AutoModelForSeq2SeqLM
else:
model_cls = AutoModelForCausalLM
if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash() and not self.disable_flash_attn:
kwargs['use_flash_attention_2'] = True

if self.peft_config is not None:
base_model = model_cls.from_pretrained(
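To make the new gating easier to follow, here is a minimal standalone sketch of what the commit adds to huggingface.py, assuming transformers>=4.34 (which accepts the use_flash_attention_2 flag on from_pretrained); the hard-coded architecture string, disable flag, and model id are illustrative stand-ins for values that would normally come from the model config and the serving properties:

import torch
from transformers import AutoModelForCausalLM

# Mirrors FLASH_2_SUPPORTED_MODELS above; flash-attention 2 kernels require
# an Ampere-or-newer GPU (compute capability >= 8.0).
FLASH_2_SUPPORTED_MODELS = {
    "LlamaForCausalLM", "RWForCausalLM", "FalconForCausalLM"
}


def enable_flash():
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
    return False


architecture = "LlamaForCausalLM"  # normally read from the model config
disable_flash_attn = False         # normally read from the serving properties
kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}

if architecture in FLASH_2_SUPPORTED_MODELS and enable_flash() and not disable_flash_attn:
    kwargs["use_flash_attention_2"] = True

# "meta-llama/Llama-2-7b-hf" is only an illustrative model id.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", **kwargs)

The same gate is mirrored in the rolling-batch scheduler changes below, so both the Accelerate path and the scheduler path request flash-attention 2 only when the GPU, architecture, and user settings all allow it.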
@@ -14,6 +14,7 @@
from seq_scheduler.lm_block import HuggingfaceBlock, BloomBlock, FalconBlock
from seq_scheduler.search_config import SearchConfig
from seq_scheduler.seq_batch_scheduler import SeqBatchScheduler
from djl_python.huggingface import FLASH_2_SUPPORTED_MODELS, enable_flash
from collections import namedtuple, defaultdict
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
@@ -39,6 +40,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
super().__init__(device, **kwargs)
self._init_model_and_tokenizer(model_id_or_path,
device=device,
properties=properties,
multi_gpu=properties.get(
'multi_gpu', None),
**kwargs)
@@ -96,6 +98,7 @@ def _init_model_and_tokenizer(self,
model_id_or_path,
device=None,
multi_gpu=None,
properties=None,
**kwargs):
if "waiting_steps" in kwargs:
kwargs.pop("waiting_steps")
@@ -120,6 +123,11 @@ def _init_model_and_tokenizer(self,
if 'device_map' in kwargs:
device_map = kwargs.pop('device_map')

if architectures[0] in FLASH_2_SUPPORTED_MODELS and enable_flash():
if properties.get(
"disable_flash_attn", "false").lower() != 'true':
kwargs['use_flash_attention_2'] = True

if "lmi_dist_sharding" == multi_gpu:
if 'neox' in model_id_or_path:
try:
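The new disable_flash_attn property gives users an explicit opt-out even when the GPU and architecture qualify. A hypothetical serving.properties sketch (the option. prefix mapping into the Python properties dict is an assumption about djl-serving's configuration convention, and the model id is illustrative):

engine=Python
option.model_id=tiiuae/falcon-7b
option.dtype=fp16
# Skip flash-attention 2 even on a supported GPU/architecture (hypothetical example).
option.disable_flash_attn=true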
2 changes: 1 addition & 1 deletion serving/docker/deepspeed.Dockerfile
@@ -25,7 +25,7 @@ ARG lmi_dist_wheel="https://publish.djl.ai/lmi_dist/lmi_dist-nightly-py3-none-an
ARG seq_scheduler_wheel="https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl"
ARG peft_wheel="https://publish.djl.ai/peft/peft-0.5.0alpha-py3-none-any.whl"
ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
ARG accelerate_version=0.23.0
ARG diffusers_version=0.16.0
ARG bitsandbytes_version=0.41.1
5 changes: 3 additions & 2 deletions serving/docker/fastertransformer.Dockerfile
@@ -20,9 +20,10 @@ ARG ft_wheel="https://publish.djl.ai/fastertransformer/fastertransformer-0.24.0-
ARG tb_wheel="https://publish.djl.ai/tritonserver/r23.04/tritontoolkit-23.4-py3-none-any.whl"
ARG peft_wheel="https://publish.djl.ai/peft/peft-0.5.0alpha-py3-none-any.whl"
ARG seq_scheduler_wheel="https://publish.djl.ai/seq_scheduler/seq_scheduler-0.1.0-py3-none-any.whl"
ARG flash_attn_2_wheel="https://publish.djl.ai/flash_attn/flash_attn_2-2.0.1-cp39-cp39-linux_x86_64.whl"
ARG ompi_version=4.1.4
ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
ARG accelerate_version=0.23.0
ARG bitsandbytes_version=0.41.1
ARG optimum_version=1.13.2
@@ -69,7 +70,7 @@ RUN apt-get update && apt-get install -y wget git libnuma-dev zlib1g-dev rapidjs
pip3 install ${torch_wheel} ${ft_wheel} ${tb_wheel} ${peft_wheel} ${seq_scheduler_wheel} safetensors protobuf==${protobuf_version} && \
pip3 install transformers==${transformers_version} accelerate==${accelerate_version} \
bitsandbytes==${bitsandbytes_version} optimum==${optimum_version} auto-gptq==${auto_gptq_version} \
-    scipy einops && \
+    scipy einops ${flash_attn_2_wheel} && \
pip3 install cmake sentencepiece bfloat16 tiktoken && \
pip3 cache purge && \
apt-get clean -y && rm -rf /var/lib/apt/lists/* && \
2 changes: 1 addition & 1 deletion serving/docker/pytorch-inf2.Dockerfile
@@ -18,7 +18,7 @@ ARG transformers_neuronx_version=0.7.84
ARG neuronx_distributed_version=0.4.0
ARG neuronx_cc_version=2.10.*
ARG protobuf_version=3.20.3
-ARG transformers_version=4.33.2
+ARG transformers_version=4.34.0
ARG accelerate_version=0.23.0
ARG diffusers_version=0.16.0
EXPOSE 8080
11 changes: 8 additions & 3 deletions tests/integration/instant_benchmark.py
@@ -114,12 +114,16 @@ def build_running_script(template, job, instance, container):
with open(template) as f:
template = json.load(f)
job_template = template[job]
-    job_template['awscurl'] = bytes.fromhex(job_template['awscurl']).decode("utf-8")
+    job_template['awscurl'] = bytes.fromhex(
+        job_template['awscurl']).decode("utf-8")
write_model_artifacts(job_template['properties'],
job_template['requirements'])

command_str = f"./launch_container.sh {container} $PWD/models {machine_translation(instance)}"
-    bash_command = ['echo "Start Launching container..."', command_str, job_template['awscurl']]
+    bash_command = [
+        'echo "Start Launching container..."', command_str,
+        job_template['awscurl']
+    ]
with open("instant_benchmark.sh", "w") as f:
f.write('\n'.join(bash_command))

@@ -133,7 +137,8 @@ def build_running_script(template, job, instance, container):
command = f"echo \"template={json.dumps(json.dumps(json.dumps(result)))}\" >> $GITHUB_OUTPUT"
sp.call(command, shell=True)
elif args.template and args.job and args.instance and args.container:
-        build_running_script(args.template, args.job, args.instance, args.container)
+        build_running_script(args.template, args.job, args.instance,
+                             args.container)
else:
parser.print_help()
raise ValueError("args not supported")
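As a side note on the reformatted benchmark helper: the template stores the awscurl command hex-encoded, and build_running_script decodes it with bytes.fromhex(...).decode("utf-8") before writing it into instant_benchmark.sh. A small round-trip sketch (the command string is purely illustrative):

# Encode/decode round trip matching the template's hex-encoded awscurl field.
cmd = 'awscurl -X POST http://127.0.0.1:8080/invocations -d "hello"'
encoded = cmd.encode("utf-8").hex()
decoded = bytes.fromhex(encoded).decode("utf-8")
assert decoded == cmd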
