From 636ad606468c2214b89bc86ccd1e45759b14030c Mon Sep 17 00:00:00 2001
From: Maxime
Date: Fri, 13 Oct 2023 07:43:11 +0000
Subject: [PATCH 1/6] add noisy embedding

---
 .../monkeypatch/llama_embeddings_hijack.py   | 27 +++++++++++++++++++
 .../monkeypatch/mistral_embeddings_hijack.py | 27 +++++++++++++++++++
 src/axolotl/utils/models.py                  | 16 +++++++++++
 3 files changed, 70 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/llama_embeddings_hijack.py
 create mode 100644 src/axolotl/monkeypatch/mistral_embeddings_hijack.py

diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
new file mode 100644
index 0000000000..7b62efb2b7
--- /dev/null
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -0,0 +1,27 @@
+import torch
+from transformers.utils import logging
+import transformers.models.llama.modeling_llama
+
+logger = logging.get_logger(__name__)
+
+def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(x):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(x)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha/torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
+            else:
+                return orig_embed(x)
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(self.embed_tokens.forward, noise_alpha, self)
+        return new_func
+
+    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(transformers.models.llama.modeling_llama.LlamaModel.post_init)
\ No newline at end of file
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
new file mode 100644
index 0000000000..e8aad081af
--- /dev/null
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -0,0 +1,27 @@
+import torch
+from transformers.utils import logging
+import transformers.models.mistral.modeling_mistral
+
+logger = logging.get_logger(__name__)
+
+def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+    def noised_embed(orig_embed, noise_alpha, model):
+        def new_func(x):
+            # during training, we add noise to the embedding
+            # during generation, we don't add noise to the embedding
+            if model.training:
+                embed_init = orig_embed(x)
+                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
+                mag_norm = noise_alpha/torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
+            else:
+                return orig_embed(x)
+        return new_func
+
+    def post_init(orig_post_init):
+        def new_func(self):
+            orig_post_init(self)
+            self.embed_tokens.forward = noised_embed(self.embed_tokens.forward, noise_alpha, self)
+        return new_func
+
+    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(transformers.models.mistral.modeling_mistral.MistralModel.post_init)
\ No newline at end of file
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c60f00c2b..6f5931b504 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -179,6 +179,22 @@ def load_model(
             LOG.info("patching with flash attention")
             replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
+
+    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.llama_embeddings_hijack import (
+            replace_llama_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_llama_embeddings_with_uniform_distribution(noise_alpha=cfg.noisy_embedding_alpha)
+
+    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+        from axolotl.monkeypatch.mistral_embeddings_hijack import (
+            replace_mistral_embeddings_with_uniform_distribution,
+        )
+
+        LOG.info("patching with noisy embeddings")
+        replace_mistral_embeddings_with_uniform_distribution(noise_alpha=cfg.noisy_embedding_alpha)
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
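The rule added in this first patch is the NEFT/NEFTune scheme from https://arxiv.org/abs/2310.05914: uniform noise drawn from [-alpha/sqrt(L*d), +alpha/sqrt(L*d)], where L is the sequence length and d the embedding dimension. For reference, a minimal self-contained sketch of the same rule outside the monkeypatch; the embedding sizes and alpha below are illustrative assumptions, not axolotl defaults:

import math

import torch
import torch.nn as nn

def neft_noise(embeds: torch.Tensor, noise_alpha: float = 5.0) -> torch.Tensor:
    # embeds: (batch, seq_len, hidden_dim); noise scale = alpha / sqrt(seq_len * hidden_dim)
    mag_norm = noise_alpha / math.sqrt(embeds.size(1) * embeds.size(2))
    # uniform noise in [-mag_norm, +mag_norm], same shape as the embeddings
    return embeds + torch.zeros_like(embeds).uniform_(-mag_norm, mag_norm)

embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64)
input_ids = torch.randint(0, 1000, (2, 16))  # batch of 2, sequence length 16
noisy = neft_noise(embedding(input_ids))     # the patch only does this while model.training is True
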
From 08147f786b4da8081e5474b305eb155964c9c855 Mon Sep 17 00:00:00 2001
From: Maxime
Date: Fri, 13 Oct 2023 11:48:56 +0000
Subject: [PATCH 2/6] fix format

---
 .../monkeypatch/llama_embeddings_hijack.py   | 21 +++++++++++++------
 .../monkeypatch/mistral_embeddings_hijack.py | 21 +++++++++++++------
 src/axolotl/utils/models.py                  | 12 +++++++----
 3 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
index 7b62efb2b7..3e3e0b0380 100644
--- a/src/axolotl/monkeypatch/llama_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -1,9 +1,10 @@
 import torch
-from transformers.utils import logging
 import transformers.models.llama.modeling_llama
+from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
+
 def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
     def noised_embed(orig_embed, noise_alpha, model):
         def new_func(x):
@@ -12,16 +13,24 @@ def new_func(x):
             if model.training:
                 embed_init = orig_embed(x)
                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                mag_norm = noise_alpha/torch.sqrt(dims)
-                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
             else:
                 return orig_embed(x)
+
         return new_func
-
+
     def post_init(orig_post_init):
         def new_func(self):
             orig_post_init(self)
-            self.embed_tokens.forward = noised_embed(self.embed_tokens.forward, noise_alpha, self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
         return new_func
 
-    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(transformers.models.llama.modeling_llama.LlamaModel.post_init)
\ No newline at end of file
+    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
+        transformers.models.llama.modeling_llama.LlamaModel.post_init
+    )
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
index e8aad081af..ffeea914e8 100644
--- a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -1,9 +1,10 @@
 import torch
-from transformers.utils import logging
 import transformers.models.mistral.modeling_mistral
+from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
 
+
 def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
     def noised_embed(orig_embed, noise_alpha, model):
         def new_func(x):
@@ -12,16 +13,24 @@ def new_func(x):
             if model.training:
                 embed_init = orig_embed(x)
                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                mag_norm = noise_alpha/torch.sqrt(dims)
-                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
+                mag_norm = noise_alpha / torch.sqrt(dims)
+                return embed_init + torch.zeros_like(embed_init).uniform_(
+                    -mag_norm, mag_norm
+                )
             else:
                 return orig_embed(x)
+
         return new_func
-
+
     def post_init(orig_post_init):
         def new_func(self):
             orig_post_init(self)
-            self.embed_tokens.forward = noised_embed(self.embed_tokens.forward, noise_alpha, self)
+            self.embed_tokens.forward = noised_embed(
+                self.embed_tokens.forward, noise_alpha, self
+            )
+
         return new_func
 
-    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(transformers.models.mistral.modeling_mistral.MistralModel.post_init)
\ No newline at end of file
+    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
+        transformers.models.mistral.modeling_mistral.MistralModel.post_init
+    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6f5931b504..c133e9eb61 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -179,22 +179,26 @@ def load_model(
             LOG.info("patching with flash attention")
             replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-
+
     if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
         from axolotl.monkeypatch.llama_embeddings_hijack import (
             replace_llama_embeddings_with_uniform_distribution,
         )
 
         LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(noise_alpha=cfg.noisy_embedding_alpha)
-
+        replace_llama_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
+
     if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
         from axolotl.monkeypatch.mistral_embeddings_hijack import (
             replace_mistral_embeddings_with_uniform_distribution,
         )
 
         LOG.info("patching with noisy embeddings")
-        replace_mistral_embeddings_with_uniform_distribution(noise_alpha=cfg.noisy_embedding_alpha)
+        replace_mistral_embeddings_with_uniform_distribution(
+            noise_alpha=cfg.noisy_embedding_alpha
+        )
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
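Because the hijack wraps LlamaModel.post_init, which transformers runs while the model is being constructed, the patch has to be installed before the model is instantiated; that is why load_model applies it ahead of loading the weights, next to the other monkeypatches. A hedged sketch of the equivalent manual usage outside axolotl (the model id and alpha are placeholders, not recommendations):

from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

# Install the patch first: it wraps LlamaModel.post_init, so only models
# built afterwards get the noised embed_tokens.forward.
replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder model id
model.train()   # noise is applied only while model.training is True
# model.eval() switches it off again, so generation is unaffected
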
From 8a2029de420b08342c51d56926596d771df87751 Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Fri, 13 Oct 2023 13:53:32 +0200
Subject: [PATCH 3/6] Update README.md

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index f22ccb5939..98b535679d 100644
--- a/README.md
+++ b/README.md
@@ -672,6 +672,11 @@ adam_epsilon:
 # Gradient clipping max norm
 max_grad_norm:
 
+# Augmentation techniques
+# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
+# currently only supported on Llama and Mistral
+noisy_embedding_alpha:
+
 # Whether to bettertransformers
 flash_optimum:
 # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:

From ef6016c825db89d697c8fda7d86efb25560753ba Mon Sep 17 00:00:00 2001
From: Maxime <672982+maximegmd@users.noreply.github.com>
Date: Fri, 13 Oct 2023 13:59:50 +0200
Subject: [PATCH 4/6] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 98b535679d..57447b36de 100644
--- a/README.md
+++ b/README.md
@@ -675,7 +675,7 @@ max_grad_norm:
 # Augmentation techniques
 # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
 # currently only supported on Llama and Mistral
-noisy_embedding_alpha: 
+noisy_embedding_alpha:
 
 # Whether to bettertransformers
 flash_optimum:
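With the README entry above in place, turning the feature on is a one-line change to the training config, for example (5 is the paper default mentioned in the comment; leaving the key unset trains without noise):

noisy_embedding_alpha: 5
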
From 779466d3ffec12945d036900e04108e1c4767fc5 Mon Sep 17 00:00:00 2001
From: Maxime
Date: Fri, 13 Oct 2023 12:06:34 +0000
Subject: [PATCH 5/6] linter issues

---
 src/axolotl/monkeypatch/llama_embeddings_hijack.py   | 7 +++----
 src/axolotl/monkeypatch/mistral_embeddings_hijack.py | 7 +++----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
index 3e3e0b0380..7647dc6743 100644
--- a/src/axolotl/monkeypatch/llama_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -7,18 +7,17 @@
 
 def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
     def noised_embed(orig_embed, noise_alpha, model):
-        def new_func(x):
+        def new_func(input_ids):
             # during training, we add noise to the embedding
             # during generation, we don't add noise to the embedding
             if model.training:
-                embed_init = orig_embed(x)
+                embed_init = orig_embed(input_ids)
                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
                 mag_norm = noise_alpha / torch.sqrt(dims)
                 return embed_init + torch.zeros_like(embed_init).uniform_(
                     -mag_norm, mag_norm
                 )
-            else:
-                return orig_embed(x)
+            return orig_embed(input_ids)
 
         return new_func
 
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
index ffeea914e8..9cfb0c6484 100644
--- a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -7,18 +7,17 @@
 
 def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
     def noised_embed(orig_embed, noise_alpha, model):
-        def new_func(x):
+        def new_func(input_ids):
             # during training, we add noise to the embedding
             # during generation, we don't add noise to the embedding
             if model.training:
-                embed_init = orig_embed(x)
+                embed_init = orig_embed(input_ids)
                 dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
                 mag_norm = noise_alpha / torch.sqrt(dims)
                 return embed_init + torch.zeros_like(embed_init).uniform_(
                     -mag_norm, mag_norm
                 )
-            else:
-                return orig_embed(x)
+            return orig_embed(input_ids)
 
         return new_func
 
From 60489738009db1520d7e9d42af663e0d4ede0509 Mon Sep 17 00:00:00 2001
From: Maxime
Date: Fri, 13 Oct 2023 12:26:48 +0000
Subject: [PATCH 6/6] caseus fixes

---
 src/axolotl/monkeypatch/llama_embeddings_hijack.py   | 5 +++++
 src/axolotl/monkeypatch/mistral_embeddings_hijack.py | 5 +++++
 2 files changed, 10 insertions(+)

diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
index 7647dc6743..654ca3ba82 100644
--- a/src/axolotl/monkeypatch/llama_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
@@ -1,3 +1,7 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
 import torch
 import transformers.models.llama.modeling_llama
 from transformers.utils import logging
@@ -6,6 +10,7 @@
 
 
 def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
     def noised_embed(orig_embed, noise_alpha, model):
         def new_func(input_ids):
             # during training, we add noise to the embedding
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
index 9cfb0c6484..ed5f259650 100644
--- a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
+++ b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
@@ -1,3 +1,7 @@
+"""
+patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+
 import torch
 import transformers.models.mistral.modeling_mistral
 from transformers.utils import logging
@@ -6,6 +10,7 @@
 
 
 def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
+    # pylint: disable=duplicate-code
     def noised_embed(orig_embed, noise_alpha, model):
         def new_func(input_ids):
             # during training, we add noise to the embedding
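A quick way to sanity-check the behaviour described in the comments (noise during training, none during generation) is to apply the patch to a tiny randomly initialised model and compare two embedding passes. This sketch is illustrative only and not part of the patch series; the config sizes are arbitrary:

import torch
from transformers import LlamaConfig, LlamaModel

from axolotl.monkeypatch.llama_embeddings_hijack import (
    replace_llama_embeddings_with_uniform_distribution,
)

replace_llama_embeddings_with_uniform_distribution(noise_alpha=5)

config = LlamaConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4,
)
model = LlamaModel(config)  # post_init runs here and wraps embed_tokens.forward
input_ids = torch.randint(0, 128, (1, 8))

model.train()
assert not torch.equal(model.embed_tokens(input_ids), model.embed_tokens(input_ids))  # fresh noise each call

model.eval()
assert torch.equal(model.embed_tokens(input_ids), model.embed_tokens(input_ids))  # deterministic, no noise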