diff --git a/README.md b/README.md
index f22ccb5939..57447b36de 100644
--- a/README.md
+++ b/README.md
@@ -672,6 +672,11 @@ adam_epsilon:
 # Gradient clipping max norm
 max_grad_norm:
 
+# Augmentation techniques
+# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
+# currently only supported on Llama and Mistral
+noisy_embedding_alpha:
+
 # Whether to bettertransformers
 flash_optimum:
 # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
new file mode 100644
index 0000000000..654ca3ba82
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.llama.modeling_llama
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+        transformers.models.llama.modeling_llama.LlamaModel.post_init
+    )
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
new file mode 100644
index 0000000000..ed5f259650
--- /dev/null
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -0,0 +1,40 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
+import torch
+import transformers.models.mistral.modeling_mistral
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(input_ids):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(input_ids)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
+            return orig_embed(input_ids)
+
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
+        return new_func
+
+    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+        transformers.models.mistral.modeling_mistral.MistralModel.post_init
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c60f00c2b..c133e9eb61 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -180,6 +180,26 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
+    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.llama_embeddings_hijack import (
+            replace_llama_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_llama_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
+    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.mistral_embeddings_hijack import (
+            replace_mistral_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_mistral_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
             replace_llama_rope_with_xpos_rope,
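
For reference, the noise these patches add follows the NEFTune rule from the paper linked above: uniform noise bounded by noise_alpha / sqrt(seq_len * hidden_dim), so longer sequences and wider models get proportionally smaller per-element perturbations. Below is a minimal standalone sketch of that scaling; the toy embedding size, sequence length, and the use of math.sqrt (instead of the tensor arithmetic in the patch) are illustrative assumptions, not part of the change.

import math

import torch
import torch.nn as nn


def neft_noise(embeds: torch.Tensor, noise_alpha: float = 5.0) -> torch.Tensor:
    # uniform noise in [-mag_norm, +mag_norm], with mag_norm = alpha / sqrt(seq_len * dim)
    mag_norm = noise_alpha / math.sqrt(embeds.size(1) * embeds.size(2))
    return embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)


embed = nn.Embedding(32, 8)                # toy vocab size / hidden size
input_ids = torch.randint(0, 32, (2, 16))  # (batch=2, seq_len=16)

clean = embed(input_ids)
noisy = neft_noise(clean, noise_alpha=5.0)

# every element is perturbed by at most alpha / sqrt(16 * 8) ~= 0.44
print((noisy - clean).abs().max().item(), 5.0 / math.sqrt(16 * 8))

Enabling the feature only requires setting noisy_embedding_alpha (for example to the paper default of 5) in the axolotl config for a Llama- or Mistral-derived model; load_model then applies the matching monkeypatch before training starts.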
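
Mechanically, both hijack modules work the same way: post_init is wrapped so that embed_tokens.forward is replaced by a closure that injects noise only while model.training is True, leaving generation untouched. The sketch below shows that wrapping on a toy module; TinyModel, wrap_embed_with_noise, and their sizes are invented for illustration, whereas the real patch targets LlamaModel / MistralModel inside transformers.

import math

import torch
import torch.nn as nn


def wrap_embed_with_noise(model: nn.Module, embed: nn.Embedding, noise_alpha: float = 5.0):
    orig_forward = embed.forward

    def noised_forward(input_ids):
        out = orig_forward(input_ids)
        if model.training:  # noise only in training mode, never at generation time
            mag_norm = noise_alpha / math.sqrt(out.size(1) * out.size(2))
            out = out + torch.zeros_like(out).uniform_(-mag_norm, mag_norm)
        return out

    embed.forward = noised_forward  # same forward-replacement trick the hijack modules use


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(32, 8)

    def forward(self, input_ids):
        return self.embed_tokens(input_ids)


model = TinyModel()
wrap_embed_with_noise(model, model.embed_tokens, noise_alpha=5.0)
ids = torch.randint(0, 32, (1, 16))

model.train()
noisy = model(ids)
model.eval()
clean = model(ids)
print(torch.allclose(noisy, clean))  # False: noise was applied only while training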