From f7443df1a5141fbd6e74456f72ccfa144d5e512d Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 1 Aug 2024 13:40:47 +0000 Subject: [PATCH 1/3] fix FA2 patching Signed-off-by: Yu Chin Fabian Lim --- .../src/fms_acceleration_ilab/flash_attn.py | 181 ++++++++++++------ .../framework_plugin_padding_free.py | 95 +++++++-- 2 files changed, 193 insertions(+), 83 deletions(-) diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py index ce471510..07499c6b 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import inspect +from functools import partial import torch -from transformers.utils import is_flash_attn_2_available +from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal from types import MethodType +from typing import Optional if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func # pylint: disable=import-error + from flash_attn import flash_attn_func, flash_attn_varlen_func # pylint: disable=import-error + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length): query = query.view(-1, query.size(-2), query.size(-1)) @@ -28,81 +33,133 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length) cu_seq_lens = torch.cat(( indices_q[position_ids==0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32) - )) + )) max_length = position_ids.max()+1 return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) +# model id -> position_ids +POSITION_IDS_CACHE = {} + +# - needed to store position ids when first come into model +# will pass these to the flash attention function +def build2( + model: torch.nn.Module, model_id: str, +): + # forward + old_forward = model.forward + + # the model will get out the position + def forward(self, *args, **kwargs): + # store position ids + POSITION_IDS_CACHE[model_id] = kwargs['position_ids'] + return old_forward(*args, **kwargs) + + return forward + def build_fa_forward( - attention: torch.nn.Module, causal: bool = True, dropout: float = None + attention: torch.nn.Module, model_id: str, ): - # assert not hasattr(self, '_position_ids'), "cannot patch fa attention" - position_ids: torch.Tensor = None + # this is really a dummpy replace old_forward = attention.forward - if dropout is not None: - attention.dropout = torch.nn.Dropout(p=dropout) - def forward(self, *args, **kwargs): - nonlocal position_ids - position_ids = kwargs['position_ids'] out, *others = old_forward(*args, **kwargs) - if dropout is not None: - out = self.dropout(out) return out, *others - def _flash_attention_forward( - self, + _flash_attn = partial( + _flash_attention_forward_with_posids, model_id + ) + + # do this replace of a method with a static + attention._flash_attention_forward = _flash_attn + + # return the forward + return forward + +# FIXME: it is difficult to keep up with all the different versions +# - this is a generic version that accepts +def _flash_attention_forward_with_posids( + model_id: str, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: 
torch.Tensor, + query_length: int, + is_causal: bool = True, # make this optional to support < 4.43 + dropout: float = 0.0, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: bool = None, + **kwargs, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + """ + position_ids = POSITION_IDS_CACHE[model_id] + + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + causal = is_causal and query_length != 1 + + # for supporting < 4.43 + use_sliding_windows = kwargs.get("use_sliding_windows") + if use_sliding_windows is None: + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + try: + if is_flash_attn_greater_or_equal("2.4.1"): + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + except: + # FIXME: is_flash_attn_greater_or_equal expects a version + # object for < 4.43 + pass + + if softcap is not None: + flash_kwargs["softcap"] = softcap + + assert attention_mask is None, "should not be using attention mask" + assert position_ids is not None, "should be expecting position ids" + batch_size = query_states.size(0) + ( query_states, key_states, value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - **kwargs, - ): - # if not self._flash_attn_uses_top_left_mask: - # causal = self.is_causal - # else: - # # TODO: Remove the `query_length != 1` - # # check once Flash Attention for RoCm is bumped to 2.1. - # # For details, please see the comment in LlamaFlashAttention2 __init__. 
- # causal = self.is_causal and query_length != 1 - - assert attention_mask is None, "should not be using attention mask" - assert position_ids is not None, "should be expecting position ids" - batch_size = query_states.size(0) - ( - query_states, - key_states, - value_states, - _, - cu_seq_lens, - max_seq_lens, - ) = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids, query_length - ) + _, + cu_seq_lens, + max_seq_lens, + ) = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids, query_length + ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - return attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) - # do this replace - attention._flash_attention_forward = MethodType(_flash_attention_forward, attention) + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - # return the forward - return forward + return attn_output \ No newline at end of file diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py index 6486791d..637ed7d2 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py @@ -59,6 +59,13 @@ def augmentation( train_args: TrainingArguments, modifiable_args: Tuple[LoraConfig], ): + # guarded + from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel + ModelPatcher, + ModelPatcherRule, + ModelPatcherTrigger, + ) + from functools import partial # pylint: disable=import-outside-toplevel # This check is done here to only patch the attention forward # if below a specific transformer version (4.43.3) that already @@ -68,20 +75,69 @@ def augmentation( # such as attention dropout, the version check should be shifted # into `build_fa_forward` to manage the forward replacement inside # the function. 
- if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION): - # guarded - from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel - ModelPatcher, - ModelPatcherRule, - ModelPatcherTrigger, + + try: + # if this is importable, it means + # https://github.com/huggingface/transformers/pull/31629/files + # has been merged, and there is no more need to + from transformers import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module + warnings.warn(f"transformers version supports \ + padding free natively in various models.") + return model, modifiable_args + + except: + pass + + import torch.nn + + def _is_backbone(module: torch.nn.Module): + return any( + isinstance(mod, torch.nn.Embedding) + for mod in module.children() ) - from .flash_attn import build_fa_forward # pylint: disable=import-outside-toplevel - from functools import partial # pylint: disable=import-outside-toplevel - # TODO: have a generic version of this rule - # - do regex on RMSNorm class name - # - check on the tensors required for fast_rms_layernorm - model_type = model.config.model_type + # patch top level + model_type = model.config.model_type + from .flash_attn import build2 + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"{model_type}-top-pad-free", + # trigger=ModelPatcherTrigger(check=model.__class__), + trigger=ModelPatcherTrigger(check=_is_backbone), + forward_builder=partial( + build2, + model_id=id(model), + ), + ), + ) + + try: + # if this can be imported, then we need to patch it + # - because it does not have logic to handle the flattened batch + # - replace with our version that has the new logic + # - do this only once + + # pylint: disable=import-outside-toplevel + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + from .flash_attn import _flash_attention_forward_with_posids + + ModelPatcher.register( + ModelPatcherRule( + rule_id=f"flash_attn_forward", + import_and_maybe_reload=( + "transformers.modeling_flash_attention_utils._flash_attention_forward", + partial(_flash_attention_forward_with_posids, id(model)), + model.__module__, + # "transformers.models.mistral.modeling_mistral", + ) + ), + ) + except: + + # finally we need to patch the flash attention + # as they do not yet accept the position ids + from .flash_attn import build_fa_forward # pylint: disable=import-outside-toplevel def is_flash_attn_2(module): if ( module.__class__.__name__.endswith("FlashAttention2") @@ -95,13 +151,10 @@ def is_flash_attn_2(module): trigger=ModelPatcherTrigger(check=is_flash_attn_2), forward_builder=partial( build_fa_forward, - causal=True, + model_id=id(model), ), ), ) - else: - warnings.warn(f"transformers version is equal or later \ - than {TRANSFORMERS_VERSION}, attention forward will not be replaced.") return model, modifiable_args @@ -122,12 +175,12 @@ def _patch_dataloader( - we replace the collate function in the dataloader to flatten the batch into a long sequence with special tokens to define the attention computation boundaries """ - # Check if transformers already supports a collator that flattens the batch - # Otherwise, use the locally implemented DataCollatorWithFlattening - if version.parse(transformers_version) < version.parse(TRANSFORMERS_VERSION): - from .ilab_utils import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel - else: + try: + # Check if transformers already supports a collator that flattens the batch from transformers import 
DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module + except: + # Otherwise, use the locally implemented DataCollatorWithFlattening + from .ilab_utils import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel # hijack the dataloader in accelerator.prepare to replace the collate_fn _old_prepare = accelerator.prepare From a19b115cf002984fac2edcca097c1b5127c614e9 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 1 Aug 2024 15:43:29 +0000 Subject: [PATCH 2/3] clean up Signed-off-by: Yu Chin Fabian Lim --- .../src/fms_acceleration_ilab/flash_attn.py | 6 +- .../framework_plugin_padding_free.py | 55 +++++++++---------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py index 07499c6b..e9038ccf 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py @@ -17,7 +17,6 @@ from functools import partial import torch from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal -from types import MethodType from typing import Optional if is_flash_attn_2_available(): @@ -42,7 +41,7 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length) # - needed to store position ids when first come into model # will pass these to the flash attention function -def build2( +def build_backbone_forward( model: torch.nn.Module, model_id: str, ): # forward @@ -122,8 +121,9 @@ def _flash_attention_forward_with_posids( deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" flash_kwargs["deterministic"] = deterministic except: - # FIXME: is_flash_attn_greater_or_equal expects a version + # FIXME: is_flash_attn_greater_or_equal expects a packaging.version # object for < 4.43 + # - we just assume that this deterministic flag is not impt pass if softcap is not None: diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py index 637ed7d2..0244f4ea 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py @@ -14,7 +14,6 @@ # Standard from typing import Dict, Tuple -from packaging import version import warnings # Third Party @@ -22,7 +21,6 @@ from peft import LoraConfig from transformers import ( TrainingArguments, - __version__ as transformers_version, DataCollatorForSeq2Seq, ) from accelerate import Accelerator @@ -30,9 +28,6 @@ from types import MethodType from torch.utils.data import DataLoader -# This is the version where padding-free was merged into transformers -TRANSFORMERS_VERSION = "4.44" - class PaddingFreeAccelerationPlugin(AccelerationPlugin): require_packages = ["flash_attn"] @@ -68,54 +63,54 @@ def augmentation( from functools import partial # pylint: disable=import-outside-toplevel # This check is done here to only patch the attention forward - # if below a specific transformer version (4.43.3) that already - # addresses padding free + # the PR was merged here # https://github.com/huggingface/transformers/pull/31629 - # Subsequently, when additional features are added to the patch - # such as attention dropout, the version check should be shifted - # into `build_fa_forward` to manage the forward replacement inside - # the 
function. try: - # if this is importable, it means - # https://github.com/huggingface/transformers/pull/31629/files + # if this is importable, it means the PR # has been merged, and there is no more need to from transformers import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module - warnings.warn(f"transformers version supports \ - padding free natively in various models.") + + # - if import successful this will print and return + warnings.warn( + "transformers version supports padding free natively in various models." + ) return model, modifiable_args except: pass - import torch.nn - + # Otherwise patching is required: + # 1. a custom forward has to be registered on the backbone + # to intercept the position ids def _is_backbone(module: torch.nn.Module): return any( isinstance(mod, torch.nn.Embedding) for mod in module.children() ) - # patch top level + # - patch backbone model_type = model.config.model_type - from .flash_attn import build2 + from .flash_attn import build_backbone_forward ModelPatcher.register( ModelPatcherRule( - rule_id=f"{model_type}-top-pad-free", - # trigger=ModelPatcherTrigger(check=model.__class__), + rule_id=f"{model_type}-backbone-pad-free", trigger=ModelPatcherTrigger(check=_is_backbone), forward_builder=partial( - build2, + build_backbone_forward, model_id=id(model), ), ), ) + # Next, the flash attention function needs to be patched + # how it is patched depends on the transformers version try: - # if this can be imported, then we need to patch it - # - because it does not have logic to handle the flattened batch - # - replace with our version that has the new logic - # - do this only once + # Case I: + # if transformers.modeling_flash_attention_utils + # can be imported, then we patch the flash attention function + # here. 
This is required because + # - this is an old version that does not have logic to handle the flattened batch # pylint: disable=import-outside-toplevel from transformers.modeling_flash_attention_utils import _flash_attention_forward @@ -129,14 +124,14 @@ def _is_backbone(module: torch.nn.Module): "transformers.modeling_flash_attention_utils._flash_attention_forward", partial(_flash_attention_forward_with_posids, id(model)), model.__module__, - # "transformers.models.mistral.modeling_mistral", ) ), ) except: - - # finally we need to patch the flash attention - # as they do not yet accept the position ids + # Case II: the flash attention functions are methods + # attached to the model classes + # - for similar reasons as Case I, they need to be patched on the + # FA2 modules from .flash_attn import build_fa_forward # pylint: disable=import-outside-toplevel def is_flash_attn_2(module): if ( From f5d16681d6e9d088dd324e143195e0ba9d6af22d Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 1 Aug 2024 15:45:58 +0000 Subject: [PATCH 3/3] fmt Signed-off-by: Yu Chin Fabian Lim --- .../src/fms_acceleration_ilab/flash_attn.py | 99 ++++++++++++------- .../framework_plugin_padding_free.py | 70 ++++++++----- .../src/fms_acceleration_ilab/ilab_utils.py | 2 +- 3 files changed, 109 insertions(+), 62 deletions(-) diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py index e9038ccf..26e26d01 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/flash_attn.py @@ -16,33 +16,59 @@ import inspect from functools import partial import torch + +# pylint: disable=no-name-in-module from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal from typing import Optional if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func # pylint: disable=import-error - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + # pylint: disable=import-error + from flash_attn import ( + flash_attn_func, + flash_attn_varlen_func, + ) + + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) + def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length): query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) value = value.view(-1, value.size(-2), value.size(-1)) position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - cu_seq_lens = torch.cat(( - indices_q[position_ids==0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32) - )) - max_length = position_ids.max()+1 - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + indices_q = torch.arange( + position_ids.size(0), device=position_ids.device, dtype=torch.int32 + ) + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor( + position_ids.size(), device=position_ids.device, dtype=torch.int32 + ), + ) + ) + max_length = position_ids.max() + 1 + return ( + query, + key, + value, + indices_q, + (cu_seq_lens, cu_seq_lens), + (max_length, max_length), + ) + # model id -> position_ids POSITION_IDS_CACHE = {} + # - needed to store position ids when first come into model # will pass these to the flash attention function def 
build_backbone_forward( - model: torch.nn.Module, model_id: str, + model: torch.nn.Module, + model_id: str, ): # forward old_forward = model.forward @@ -50,24 +76,25 @@ def build_backbone_forward( # the model will get out the position def forward(self, *args, **kwargs): # store position ids - POSITION_IDS_CACHE[model_id] = kwargs['position_ids'] + POSITION_IDS_CACHE[model_id] = kwargs["position_ids"] return old_forward(*args, **kwargs) return forward - + + def build_fa_forward( - attention: torch.nn.Module, model_id: str, + attention: torch.nn.Module, + model_id: str, ): - # this is really a dummpy replace + # this is really a dummpy replace old_forward = attention.forward + def forward(self, *args, **kwargs): out, *others = old_forward(*args, **kwargs) return out, *others - _flash_attn = partial( - _flash_attention_forward_with_posids, model_id - ) + _flash_attn = partial(_flash_attention_forward_with_posids, model_id) # do this replace of a method with a static attention._flash_attention_forward = _flash_attn @@ -75,8 +102,9 @@ def forward(self, *args, **kwargs): # return the forward return forward + # FIXME: it is difficult to keep up with all the different versions -# - this is a generic version that accepts +# - this is a generic version that accepts def _flash_attention_forward_with_posids( model_id: str, query_states: torch.Tensor, @@ -84,9 +112,8 @@ def _flash_attention_forward_with_posids( value_states: torch.Tensor, attention_mask: torch.Tensor, query_length: int, - is_causal: bool = True, # make this optional to support < 4.43 + is_causal: bool = True, # make this optional to support < 4.43 dropout: float = 0.0, - position_ids: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, sliding_window: Optional[int] = None, use_top_left_mask: bool = False, @@ -94,16 +121,12 @@ def _flash_attention_forward_with_posids( deterministic: bool = None, **kwargs, ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - """ + # get the position ids out here position_ids = POSITION_IDS_CACHE[model_id] - + if not use_top_left_mask: causal = is_causal else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. causal = is_causal and query_length != 1 # for supporting < 4.43 @@ -111,18 +134,24 @@ def _flash_attention_forward_with_posids( if use_sliding_windows is None: # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). 
use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + _flash_supports_window_size + and sliding_window is not None + and key_states.shape[1] > sliding_window ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + flash_kwargs = ( + {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + ) try: if is_flash_attn_greater_or_equal("2.4.1"): if deterministic is None: - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + deterministic = ( + os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + ) flash_kwargs["deterministic"] = deterministic - except: - # FIXME: is_flash_attn_greater_or_equal expects a packaging.version - # object for < 4.43 + except AttributeError: + # FIXME: is_flash_attn_greater_or_equal expects a + # packaging.version object for < 4.43 # - we just assume that this deterministic flag is not impt pass @@ -160,6 +189,8 @@ def _flash_attention_forward_with_posids( **flash_kwargs, ) - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + attn_output = attn_output.view( + batch_size, -1, attn_output.size(-2), attn_output.size(-1) + ) - return attn_output \ No newline at end of file + return attn_output diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py index 0244f4ea..c5725959 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/framework_plugin_padding_free.py @@ -28,6 +28,7 @@ from types import MethodType from torch.utils.data import DataLoader + class PaddingFreeAccelerationPlugin(AccelerationPlugin): require_packages = ["flash_attn"] @@ -55,12 +56,12 @@ def augmentation( modifiable_args: Tuple[LoraConfig], ): # guarded - from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel + from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel ModelPatcher, ModelPatcherRule, ModelPatcherTrigger, ) - from functools import partial # pylint: disable=import-outside-toplevel + from functools import partial # pylint: disable=import-outside-toplevel # This check is done here to only patch the attention forward # the PR was merged here @@ -68,8 +69,11 @@ def augmentation( try: # if this is importable, it means the PR - # has been merged, and there is no more need to - from transformers import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module + # has been merged, and there is no more need to + # pylint: disable=import-outside-toplevel,no-name-in-module,unused-import + from transformers import ( + DataCollatorWithFlattening, + ) # - if import successful this will print and return warnings.warn( @@ -77,21 +81,20 @@ def augmentation( ) return model, modifiable_args - except: + except ImportError: pass # Otherwise patching is required: # 1. 
a custom forward has to be registered on the backbone # to intercept the position ids def _is_backbone(module: torch.nn.Module): - return any( - isinstance(mod, torch.nn.Embedding) - for mod in module.children() - ) + return any(isinstance(mod, torch.nn.Embedding) for mod in module.children()) # - patch backbone model_type = model.config.model_type + # pylint: disable=import-outside-toplevel from .flash_attn import build_backbone_forward + ModelPatcher.register( ModelPatcherRule( rule_id=f"{model_type}-backbone-pad-free", @@ -107,36 +110,39 @@ def _is_backbone(module: torch.nn.Module): # how it is patched depends on the transformers version try: # Case I: - # if transformers.modeling_flash_attention_utils + # if transformers.modeling_flash_attention_utils # can be imported, then we patch the flash attention function - # here. This is required because + # here. This is required because # - this is an old version that does not have logic to handle the flattened batch # pylint: disable=import-outside-toplevel - from transformers.modeling_flash_attention_utils import _flash_attention_forward + from transformers.modeling_flash_attention_utils import ( + _flash_attention_forward, + ) from .flash_attn import _flash_attention_forward_with_posids ModelPatcher.register( ModelPatcherRule( - rule_id=f"flash_attn_forward", + rule_id="flash_attn_forward", import_and_maybe_reload=( "transformers.modeling_flash_attention_utils._flash_attention_forward", partial(_flash_attention_forward_with_posids, id(model)), model.__module__, - ) + ), ), ) - except: - # Case II: the flash attention functions are methods + except ImportError: + # Case II: the flash attention functions are methods # attached to the model classes # - for similar reasons as Case I, they need to be patched on the # FA2 modules - from .flash_attn import build_fa_forward # pylint: disable=import-outside-toplevel + from .flash_attn import ( + build_fa_forward, + ) # pylint: disable=import-outside-toplevel + def is_flash_attn_2(module): - if ( - module.__class__.__name__.endswith("FlashAttention2") - ): + if module.__class__.__name__.endswith("FlashAttention2"): return True return False @@ -161,8 +167,8 @@ def get_callbacks_and_ready_for_train( return [] def _patch_dataloader( - self, - accelerator: Accelerator, + self, + accelerator: Accelerator, ): """ Hijacks the accelorator prepare inside `Trainer.train` @@ -172,22 +178,31 @@ def _patch_dataloader( """ try: # Check if transformers already supports a collator that flattens the batch - from transformers import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel,no-name-in-module - except: + # pylint: disable=import-outside-toplevel,no-name-in-module + from transformers import ( + DataCollatorWithFlattening, + ) + except ImportError: # Otherwise, use the locally implemented DataCollatorWithFlattening - from .ilab_utils import DataCollatorWithFlattening # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from .ilab_utils import ( + DataCollatorWithFlattening, + ) # hijack the dataloader in accelerator.prepare to replace the collate_fn _old_prepare = accelerator.prepare + def prepare(self, *args, device_placement=None): if len(args) > 1 or not isinstance(args[0], DataLoader): return _old_prepare(*args, device_placement=device_placement) dataloader = args[0] if not isinstance(dataloader.collate_fn, DataCollatorForSeq2Seq): - raise TypeError("The padding-free plugin currently only works with a \ + raise TypeError( + "The padding-free plugin 
currently only works with a \ `DataCollatorForSeq2Seq` collate_fn, \ - otherwise the collation can be unreliable") + otherwise the collation can be unreliable" + ) # Replace the collate_fn in dataloader dataloader.collate_fn = DataCollatorWithFlattening() @@ -196,6 +211,7 @@ def prepare(self, *args, device_placement=None): accelerator.prepare = MethodType(prepare, accelerator) + # register AccelerationPlugin.register_plugin( PaddingFreeAccelerationPlugin, diff --git a/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py b/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py index c8529669..330bf5eb 100644 --- a/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py +++ b/plugins/instruct-lab/src/fms_acceleration_ilab/ilab_utils.py @@ -16,6 +16,7 @@ import warnings from transformers import DefaultDataCollator, default_data_collator + @dataclass class DataCollatorWithFlattening(DefaultDataCollator): """ @@ -51,4 +52,3 @@ def __call__(self, features, return_tensors=None): else: ret["labels"] += [-100] + feature["input_ids"][1:] return default_data_collator([ret], return_tensors) - \ No newline at end of file
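
For context on the padding-free mechanism these patches wire up, below is a minimal sketch (not part of the patches; the sample inputs are made up) of the two ideas involved: the flattening collator packs a batch into one long row whose position_ids restart at 0 for each packed example, and the patched flash-attention path recovers the per-sequence boundaries from those position_ids, mirroring prepare_fa2_from_position_ids above.

    import torch

    # Hypothetical packed batch: two examples of lengths 3 and 2 flattened into
    # a single row, with position_ids restarting at 0 at each example boundary
    # (this is the shape a DataCollatorWithFlattening-style collator produces).
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])

    flat = position_ids.flatten()
    indices = torch.arange(flat.size(0), dtype=torch.int32)

    # Every position_id == 0 marks the start of a packed example; appending the
    # total token count closes the last one. These are the cumulative sequence
    # lengths that flash_attn_varlen_func expects.
    cu_seq_lens = torch.cat((
        indices[flat == 0],
        torch.tensor(flat.size(), dtype=torch.int32),
    ))
    max_seqlen = int(flat.max()) + 1

    print(cu_seq_lens.tolist())  # [0, 3, 5]
    print(max_seqlen)            # 3

These two values are what _flash_attention_forward_with_posids passes on as cu_seqlens_q/cu_seqlens_k and max_seqlen_q/max_seqlen_k, which is why the backbone forward patch only needs to cache position_ids rather than an attention mask.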