From 8cc732c062f1b7a52477caec2a7eaa83bfce0915 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 15 Sep 2023 12:28:09 -0400
Subject: [PATCH] fix linting

---
 .mypy.ini                                     |  3 ++
 src/axolotl/models/phi/__init__.py            |  8 +++-
 .../monkeypatch/phi_attn_hijack_flash.py      | 40 -------------------
 src/axolotl/utils/callbacks.py                |  2 +-
 4 files changed, 10 insertions(+), 43 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/phi_attn_hijack_flash.py

diff --git a/.mypy.ini b/.mypy.ini
index 9a6e56bb80..478765a39d 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -23,6 +23,9 @@ ignore_missing_imports = True
 [mypy-peft]
 ignore_missing_imports = True
 
+[mypy-wandb]
+ignore_missing_imports = True
+
 [mypy-bitsandbytes]
 ignore_missing_imports = True
 
diff --git a/src/axolotl/models/phi/__init__.py b/src/axolotl/models/phi/__init__.py
index f0a31e356e..0619f648df 100644
--- a/src/axolotl/models/phi/__init__.py
+++ b/src/axolotl/models/phi/__init__.py
@@ -1,2 +1,6 @@
-from .configuration_mixformer_sequential import MixFormerSequentialConfig
-from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM
+"""
+MixFormers model architecture used for phi models
+"""
+
+from .configuration_mixformer_sequential import MixFormerSequentialConfig  # noqa
+from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM  # noqa
diff --git a/src/axolotl/monkeypatch/phi_attn_hijack_flash.py b/src/axolotl/monkeypatch/phi_attn_hijack_flash.py
deleted file mode 100644
index 3e68d392c7..0000000000
--- a/src/axolotl/monkeypatch/phi_attn_hijack_flash.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""
-Flash attention monkey patch for phi mixformers model
-"""
-
-import importlib
-import logging
-
-from flash_attn.flash_attn_interface import (
-    flash_attn_kvpacked_func,
-    flash_attn_qkvpacked_func,
-)
-from transformers import AutoConfig
-
-LOG = logging.getLogger("axolotl")
-
-
-def replace_phi_attn_with_flash_attn(model_name: str):
-    # this is a wonky hack to get the remotely loaded module
-    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(
-        ".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
-    )
-    modeling_phi = importlib.import_module(module_name)
-    modeling_phi.SelfAttention.forward = flash_self_attn_forward
-    modeling_phi.CrossAttention.forward = flash_cross_attn_forward
-    modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = ["ParallelBlock"]
-
-
-def flash_self_attn_forward(self, qkv, causal=None, key_padding_mask=None):
-    causal = self.causal if causal is None else causal
-    return flash_attn_qkvpacked_func(
-        qkv, dropout_p=self.drop.p, softmax_scale=self.softmax_scale, causal=causal
-    )
-
-
-def flash_cross_attn_forward(self, q, kv, causal=None, key_padding_mask=None):
-    causal = self.causal if causal is None else causal
-    return flash_attn_kvpacked_func(
-        q, kv, dropout_p=self.drop.p, softmax_scale=self.softmax_scale, causal=causal
-    )
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index bf8c4145bd..9fdb5af7e3 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -11,7 +11,6 @@
 import pandas as pd
 import torch
 import torch.distributed as dist
-import wandb
 from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
@@ -25,6 +24,7 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+import wandb
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.distributed import (
     barrier,