
Commit ef24342

fix: switch to using the HuggingFace Transformers NEFT implementation (#941)

* fix: switch to using the HuggingFace Transformers NEFT implementation

* linter

* add support for noisy_embedding_alpha with a warning about it being renamed

* restore pre/posttrain_hooks

* move validation of NEFT noise alpha into validate_config()

* linter
kallewoof authored Dec 13, 2023
1 parent 5ea3aa3 commit ef24342
Showing 5 changed files with 23 additions and 71 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -774,7 +774,7 @@ max_grad_norm:
 # Augmentation techniques
 # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
 # currently only supported on Llama and Mistral
-noisy_embedding_alpha:
+neftune_noise_alpha:
 
 # Whether to bettertransformers
 flash_optimum:
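
Note: for background on the renamed option, NEFT (NEFTune, the paper linked above) adds uniform noise to the embedding outputs during fine-tuning, with magnitude alpha / sqrt(sequence_length * hidden_dim). The sketch below only illustrates that scaling; it is not the axolotl or Transformers code, and the function name is made up.

import torch


def add_neftune_noise(embeddings: torch.Tensor, noise_alpha: float) -> torch.Tensor:
    """Illustrative NEFTune-style noise; embeddings is (batch, seq_len, hidden_dim)."""
    seq_len, hidden_dim = embeddings.shape[1], embeddings.shape[2]
    # Noise is uniform in [-mag, +mag] with mag = alpha / sqrt(seq_len * hidden_dim)
    mag = noise_alpha / (seq_len * hidden_dim) ** 0.5
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag, mag)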
src/axolotl/core/trainer_builder.py (6 changes: 6 additions & 0 deletions)

@@ -712,6 +712,12 @@ def build(self, total_num_steps):
             training_arguments_kwargs
         )
         training_arguments_kwargs["model_type"] = self.cfg.model_config_type
+
+        if self.cfg.neftune_noise_alpha is not None:
+            training_arguments_kwargs[
+                "neftune_noise_alpha"
+            ] = self.cfg.neftune_noise_alpha
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
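
Note: the added block simply forwards the configured value to the Hugging Face TrainingArguments, which is where Transformers picks it up and enables NEFTune in the Trainer. A minimal sketch of that hand-off, assuming a transformers release whose TrainingArguments exposes neftune_noise_alpha; the Cfg stand-in and output_dir are made up for the example.

from transformers import TrainingArguments


# "Cfg" is a made-up stand-in for axolotl's config object (a DictDefault built
# from the YAML); only the one field used here is modelled.
class Cfg:
    neftune_noise_alpha = 5.0


cfg = Cfg()
training_arguments_kwargs = {"output_dir": "./out"}

# Only forward the value when it is actually set, so configs without the key
# keep the library default (NEFTune disabled).
if cfg.neftune_noise_alpha is not None:
    training_arguments_kwargs["neftune_noise_alpha"] = cfg.neftune_noise_alpha

args = TrainingArguments(**training_arguments_kwargs)
print(args.neftune_noise_alpha)  # 5.0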
src/axolotl/monkeypatch/neft_embeddings.py (65 changes: 0 additions & 65 deletions)

This file was deleted.

src/axolotl/train.py (7 changes: 2 additions & 5 deletions)

@@ -16,7 +16,6 @@

 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
-from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
@@ -180,21 +179,19 @@ def terminate_handler(_, __, model):
     return model, tokenizer
 
 
-def pretrain_hooks(cfg, trainer):
+def pretrain_hooks(_cfg, _trainer):
     """
     Run hooks right before kicking off the training
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.pretrain_hook(cfg, trainer)
 
 
-def post_train_hooks(cfg, trainer):
+def post_train_hooks(_cfg, _trainer):
     """
     Run hooks right after training completes
     :param cfg:
     :param trainer:
     :return:
     """
-    neft_embeddings.post_train_hook(cfg, trainer)
src/axolotl/utils/config.py (14 changes: 14 additions & 0 deletions)

@@ -434,6 +434,20 @@ def validate_config(cfg):
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)

if cfg.noisy_embedding_alpha is not None:
# Deprecated, use neftune_noise_alpha
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
if cfg.neftune_noise_alpha is None:
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
else:
# User is providing both; bail and have them sort out their settings
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)

if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
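
Note: the net effect of the new checks is that a config still using the old noisy_embedding_alpha key is warned and mapped onto neftune_noise_alpha, setting both keys is an error, and a non-positive alpha is rejected. A standalone sketch of that behaviour, using a plain namespace in place of axolotl's DictDefault config; validate() here is a hypothetical mirror of the added validate_config() code.

from types import SimpleNamespace
import logging

LOG = logging.getLogger(__name__)


def validate(cfg):
    # Hypothetical mirror of the checks added to validate_config() above.
    if cfg.noisy_embedding_alpha is not None:
        LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        if cfg.neftune_noise_alpha is None:
            cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
        else:
            raise ValueError("remove the deprecated noisy_embedding_alpha setting")
    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
        raise ValueError("neftune_noise_alpha must be > 0.0")


# Old-style config: the deprecated key is carried over with a warning.
cfg = SimpleNamespace(noisy_embedding_alpha=5, neftune_noise_alpha=None)
validate(cfg)
print(cfg.neftune_noise_alpha)  # 5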
