diff --git a/src/axolotl/monkeypatch/fused_modules.py b/src/axolotl/monkeypatch/fused_modules.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 386f4bfac4..f0fa807fa6 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -152,6 +152,7 @@ def _post_training(self, model, name): new_attn.q_proj.weight.data = q_proj new_attn.k_proj.weight.data = k_proj new_attn.v_proj.weight.data = v_proj + new_attn.o_proj.weight.data = self.o_proj.weight.data set_module_name(model, name, new_attn)