From ef2434253802a81b7d7a250efe7b1d878f837d7e Mon Sep 17 00:00:00 2001
From: kallewoof
Date: Thu, 14 Dec 2023 07:15:34 +0900
Subject: [PATCH] fix: switch to using the HuggingFace Transformers NEFT implementation (#941)

* fix: switch to using the HuggingFace Transformers NEFT implementation

* linter

* add support for noisy_embedding_alpha with a warning about it being renamed

* restore pre/posttrain_hooks

* move validation of NEFT noise alpha into validate_config()

* linter
---
 README.md                                  |  2 +-
 src/axolotl/core/trainer_builder.py        |  6 ++
 src/axolotl/monkeypatch/neft_embeddings.py | 65 ----------------------
 src/axolotl/train.py                       |  7 +--
 src/axolotl/utils/config.py                | 14 +++++
 5 files changed, 23 insertions(+), 71 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/neft_embeddings.py

diff --git a/README.md b/README.md
index 1dfa524b16..44fd7d57f3 100644
--- a/README.md
+++ b/README.md
@@ -774,7 +774,7 @@ max_grad_norm:
 # Augmentation techniques
 # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
 # currently only supported on Llama and Mistral
-noisy_embedding_alpha:
+neftune_noise_alpha:
 
 # Whether to bettertransformers
 flash_optimum:
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 1b037420c5..ccd9d37c0d 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -712,6 +712,12 @@ def build(self, total_num_steps):
             training_arguments_kwargs
         )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+
+        if self.cfg.neftune_noise_alpha is not None:
+            training_arguments_kwargs[
+                "neftune_noise_alpha"
+            ] = self.cfg.neftune_noise_alpha
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
diff --git a/src/axolotl/monkeypatch/neft_embeddings.py b/src/axolotl/monkeypatch/neft_embeddings.py
deleted file mode 100644
index 524d48f8ff..0000000000
--- a/src/axolotl/monkeypatch/neft_embeddings.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""
-patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
-"""
-import torch
-from peft import PeftModel
-from transformers import PreTrainedModel
-
-
-def patch_neft(alpha, model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    embeddings.noisy_embedding_alpha = alpha
-    old_forward = embeddings.forward
-
-    # This hack seems to be needed to properly use a custom forward pass
-    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
-    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
-        embeddings, embeddings.__class__
-    )
-    setattr(embeddings, "forward", bound_method)
-
-    embeddings._old_forward = old_forward  # pylint: disable=protected-access
-    return model
-
-
-def unpatch_neft(model):
-    embeddings = None
-    if isinstance(model, PreTrainedModel):
-        embeddings = model.get_input_embeddings()
-    if isinstance(model, PeftModel):
-        embeddings = model.base_model.get_input_embeddings()
-    if not embeddings:
-        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
-    if hasattr(embeddings, "_old_forward"):
-        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings._old_forward  # pylint: disable=protected-access
-        del embeddings.noisy_embedding_alpha
-
-
-def neft_forward(self, inputs: torch.Tensor):
-    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
-
-    if self.training:
-        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
-        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
-        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-            -mag_norm, mag_norm
-        )
-
-    return embeddings
-
-
-def pretrain_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
-
-
-def post_train_hook(cfg, trainer):
-    if cfg.noisy_embedding_alpha:
-        unpatch_neft(trainer.model)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index b65d1455fa..169fc51272 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -16,7 +16,6 @@
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
-from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
@@ -180,21 +179,19 @@ def terminate_handler(_, __, model):
     return model, tokenizer
 
 
-def pretrain_hooks(cfg, trainer):
+def pretrain_hooks(_cfg, _trainer):
     """
     Run hooks right before kicking off the training
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.pretrain_hook(cfg, trainer)
 
 
-def post_train_hooks(cfg, trainer):
+def post_train_hooks(_cfg, _trainer):
     """
     Run hooks right after training completes
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.post_train_hook(cfg, trainer)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index b04c207ddb..1b4ce92465 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -434,6 +434,20 @@ def validate_config(cfg):
             "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
         )
 
+    if cfg.noisy_embedding_alpha is not None:
+        # Deprecated, use neftune_noise_alpha
+        LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
+        if cfg.neftune_noise_alpha is None:
+            cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
+        else:
+            # User is providing both; bail and have them sort out their settings
+            raise ValueError(
+                "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
+            )
+
+    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
+        raise ValueError("neftune_noise_alpha must be > 0.0")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
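
For reference, the behaviour that neftune_noise_alpha controls (implemented by the deleted monkeypatch above and by the Transformers NEFTune support this patch delegates to) is the rule from the paper: during training, uniform noise bounded by alpha / sqrt(seq_len * hidden_dim) is added to the output of the input-embedding layer. Below is a minimal sketch of the pass-through and of that rule, assuming a transformers release recent enough to accept neftune_noise_alpha in TrainingArguments; the output_dir value, the alpha of 5.0, and the neftune_noise helper are illustrative and not part of the patch.

import math

import torch
from transformers import TrainingArguments

# Mirrors what trainer_builder.py now does: hand the configured alpha to the
# HF training arguments and let the upstream Trainer manage the embedding hook.
args = TrainingArguments(
    output_dir="out",         # illustrative path
    neftune_noise_alpha=5.0,  # paper default mentioned in the README
)

# The noise rule itself, equivalent to the removed neft_forward: uniform noise
# in [-mag, mag] with mag = alpha / sqrt(seq_len * hidden_dim), applied to an
# embedding output of shape (batch, seq_len, hidden_dim) during training only.
def neftune_noise(embeddings: torch.Tensor, alpha: float, training: bool = True) -> torch.Tensor:
    if not training:
        return embeddings
    mag = alpha / math.sqrt(embeddings.size(1) * embeddings.size(2))
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag, mag)

The new check in validate_config() rejects non-positive values, which is why the sketch uses a positive alpha; leaving neftune_noise_alpha unset disables the feature entirely.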