Skip to content

Commit

Permalink
caseus fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxime committed Oct 13, 2023
1 parent f9a831b commit 6048973
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/axolotl/monkeypatch/llama_embeddings_hijack.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""

import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging
Expand All @@ -6,6 +10,7 @@


def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/monkeypatch/mistral_embeddings_hijack.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""

import torch
import transformers.models.mistral.modeling_mistral
from transformers.utils import logging
Expand All @@ -6,6 +10,7 @@


def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
Expand Down

0 comments on commit 6048973

Please sign in to comment.