diff --git a/docker/Dockerfile-cloud b/docker/Dockerfile-cloud
index d7e3277d2f..c8249cb79c 100644
--- a/docker/Dockerfile-cloud
+++ b/docker/Dockerfile-cloud
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
 FROM axolotlai/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/docker/Dockerfile-cloud-no-tmux b/docker/Dockerfile-cloud-no-tmux
index 6dfea46779..1650631050 100644
--- a/docker/Dockerfile-cloud-no-tmux
+++ b/docker/Dockerfile-cloud-no-tmux
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
 FROM axolotlai/axolotl:$BASE_TAG
 
 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
-ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
+ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"
 
diff --git a/requirements.txt b/requirements.txt
index d100139ca4..864beb9b13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.14.0
-transformers==4.46.3
+transformers==4.47.0
 tokenizers>=0.20.1
 bitsandbytes==0.45.0
-accelerate==1.1.0
+accelerate==1.2.0
 datasets==3.1.0
 deepspeed==0.15.4
 pydantic==2.6.3
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
 
-trl==0.12.0
+trl==0.12.1
 zstandard==0.22.0
 fastcore
 
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index e8ef862854..d07b10ce3d 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -442,7 +442,7 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
             "compute_capability": gpu_version,
         },
         env_capabilities={
-            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
+            "torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
         },
     )
 
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 93384189e9..baac94da80 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -957,13 +957,15 @@ def create_accelerator_and_postprocess(self):
 
         return res
 
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
         """
         Log `logs` on the various objects watching training, including stored metrics.
 
         Args:
             logs (`Dict[str, float]`):
                 The values to log.
+            start_time (`Optional[float]`):
+                The start of training.
""" # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" @@ -971,7 +973,7 @@ def log(self, logs: Dict[str, float]) -> None: for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs) + return super().log(logs, start_time) def store_metrics( self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" @@ -1155,6 +1157,18 @@ def training_step( torch.cuda.empty_cache() return loss + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + # TODO remove once trl supports the updated to the Trainer.log method + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super(DPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): """ @@ -1163,6 +1177,18 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): tag_names = ["axolotl", "orpo"] + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + # TODO remove once trl supports the updated to the Trainer.log method + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): """ @@ -1171,6 +1197,45 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): tag_names = ["axolotl", "kto"] + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + # TODO remove once trl supports the updated to the Trainer.log method + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # train metrics should have no prefix, eval should have 'eval_' + prefix = "eval_" if train_eval == "eval" else "" + # accumulate average metrics from sums and lengths + for split in ["chosen", "rejected"]: + if f"count/{split}" in self._stored_metrics[train_eval]: + count_sum = ( + torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]) + .sum() + .item() + ) + for metric in ["rewards", "logps", "logits"]: + logs[f"{prefix}{metric}/{split}"] = ( + torch.Tensor( + self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + ) + .sum() + .item() + / count_sum + ) + # delete obsolete metric + del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + del self._stored_metrics[train_eval][f"count/{split}"] + # calculate reward margin + if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: + logs[f"{prefix}rewards/margins"] = ( + logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] + ) + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super(KTOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): """ @@ -1179,6 +1244,18 @@ 
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): tag_names = ["axolotl", "cpo"] + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + # TODO remove once trl supports the updated to the Trainer.log method + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super(CPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): """ @@ -1187,6 +1264,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): tag_names = ["axolotl", "reward"] + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + # TODO remove once trl supports the updated to the Trainer.log method + return super(RewardTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + class TrainerBuilderBase(abc.ABC): """ diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py new file mode 100644 index 0000000000..5ee90f91ae --- /dev/null +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -0,0 +1,207 @@ +""" +fix for FSDP gradient accumulation +see https://github.com/huggingface/transformers/pull/35128 +""" +import inspect + +from accelerate.logging import get_logger +from transformers import LlamaForCausalLM +from transformers.trainer import Trainer + +from axolotl.monkeypatch.unsloth_ import detab_code + +LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum") + +ORIGINAL_CONTEXT_CODE = """ + with self.compute_loss_context_manager(): + if self.model_accepts_loss_kwargs: + loss = self.compute_loss(model, inputs) + else: + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) +""" + +PATCHED_CONTEXT_CODE = """ + with self.compute_loss_context_manager(): + if self.model_accepts_loss_kwargs: + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + else: + loss = self.compute_loss(model, inputs) +""" + +ORIGINAL_LLAMA_FCLM_CODE = """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) +""" + +PATCHED_LLAMA_FCLM_CODE = """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention + num_items_in_batch = kwargs.pop("num_items_in_batch", None) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs) +""" + + +def get_training_step_code() -> str: + training_step = inspect.getsource( + Trainer.training_step # pylint: disable=protected-access + ) + return training_step + + +def check_training_step_is_patchable() -> bool: + training_step = get_training_step_code() + training_step, _ = detab_code(training_step) + return ORIGINAL_CONTEXT_CODE in training_step + + +def patch_training_step_for_ga(): + """ + monkeypatch for fixing the training loop for gradient accumulation + """ + + try: + training_step = get_training_step_code() + except OSError: + return + Trainer._original_training_step = training_step # pylint: disable=protected-access + training_step, _ = detab_code(training_step) + if ORIGINAL_CONTEXT_CODE not in training_step: + return + # assert ( + # ORIGINAL_CONTEXT_CODE in training_step + # ), "Original training_step code not found" + + training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE) + training_step = training_step.replace( + "def training_step(", + "def _fixed_training_step(", + 1, + ) + + # load imports necessary + import transformers.trainer + + items_to_import = [] + for item in dir(transformers.trainer): + if item in training_step: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(training_step, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching training_step", main_process_only=True) + Trainer.training_step = ( # pylint: disable=protected-access + _fixed_training_step # pylint: disable=undefined-variable # noqa: F821 + ) + + +def get_model_forward_code() -> str: + forward = inspect.getsource( + LlamaForCausalLM.forward # pylint: disable=protected-access + ) + return forward + + +def check_forward_is_patchable() -> bool: + forward = get_model_forward_code() + forward, _ = detab_code(forward) + return ORIGINAL_LLAMA_FCLM_CODE in forward + + +def patch_forward_for_ga(): + """ + monkeypatch for fixing the training loop for gradient accumulation + """ + + try: + forward = get_model_forward_code() + except OSError: + return + LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access + forward, _ = detab_code(forward) + if ORIGINAL_LLAMA_FCLM_CODE not in forward: + return + # assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not 
found" + + forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE) + forward = forward.replace( + "def forward(", + "def _fixed_forward(", + 1, + ) + + # load imports necessary + import transformers.models.llama.modeling_llama + + items_to_import = [] + for item in dir(transformers.models.llama.modeling_llama): + if item in forward: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(forward, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching forward", main_process_only=True) + LlamaForCausalLM.forward = ( # pylint: disable=protected-access + _fixed_forward # pylint: disable=undefined-variable # noqa: F821 + ) diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 38bbdc88fb..7358803ba1 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -9,10 +9,7 @@ from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn -from transformers.models.llama.modeling_llama import ( - LlamaFlashAttention2, - LlamaForCausalLM, -) +from transformers.models.llama.modeling_llama import LlamaFlashAttention2 LOG = get_logger("axolotl.monkeypatch.unsloth") @@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states): return attn_output -def get_forward_code() -> str: - forward = inspect.getsource(LlamaForCausalLM.forward) - return forward - - def get_self_attn_code() -> str: forward = inspect.getsource(LlamaFlashAttention2.forward) return forward @@ -102,8 +94,11 @@ def UnslothForCausalLMLoss( # pylint: disable=invalid-name def detab_code(code: str) -> Tuple[str, str]: - spaces = re.match(r"([\s\t]{1,})", code).group(0) - code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) + try: + spaces = re.match(r"([\s\t]{1,})", code).group(0) + code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE) + except AttributeError: + return code, "" return code, spaces diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3e120cca60..f2ee93c3c7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -386,6 +386,15 @@ def apply_patches(self) -> None: if self.cfg.flash_attention: self.patch_attention() + if self.cfg.model_config_type == "llama": + from axolotl.monkeypatch.trainer_grad_accum import ( + patch_forward_for_ga, + patch_training_step_for_ga, + ) + + patch_forward_for_ga() + patch_training_step_for_ga() + if self.cfg.sample_packing and self.cfg.s2_attention: raise ValueError( "Received `sample_packing=true` and `s2_attention=true`; however, \ diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 8e0d03380f..b58406185a 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -36,6 +36,9 @@ def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing): "sequence_len": 1024, "sample_packing": sample_packing, "flash_attention": True, + "unsloth_lora_mlp": True, + "unsloth_lora_qkv": True, + "unsloth_lora_o": True, "load_in_4bit": True, "adapter": "qlora", "lora_r": 16, @@ -82,6 +85,9 @@ def test_unsloth_llama_qlora_unpacked(self, temp_dir): { "base_model": "HuggingFaceTB/SmolLM2-135M", "sequence_len": 1024, + "unsloth_lora_mlp": True, + "unsloth_lora_qkv": True, + "unsloth_lora_o": True, "sample_packing": False, "load_in_4bit": True, "adapter": "qlora", @@ 
-133,6 +139,9 @@ def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention) { "base_model": "HuggingFaceTB/SmolLM2-135M", "sequence_len": 1024, + "unsloth_lora_mlp": True, + "unsloth_lora_qkv": True, + "unsloth_lora_o": True, "sample_packing": False, "load_in_4bit": True, "adapter": "qlora", diff --git a/tests/patched/test_llama_trainer_ga.py b/tests/patched/test_llama_trainer_ga.py new file mode 100644 index 0000000000..58c229cf34 --- /dev/null +++ b/tests/patched/test_llama_trainer_ga.py @@ -0,0 +1,25 @@ +""""Test module for checking whether the Hugging Face Transformers is working as expected.""" +import unittest + +from axolotl.monkeypatch.trainer_grad_accum import ( + check_forward_is_patchable, + check_training_step_is_patchable, +) + + +class TestTrainerGAIntegration(unittest.TestCase): + """llama monkeypatch integration tests.""" + + def test_train_step_patchable(self): + # ensures the current version of transformers has loss code that matches our patching code + self.assertTrue( + check_training_step_is_patchable(), + "HF transformers Trainer.training_step has changed and isn't patchable", + ) + + def test_model_forward_patchable(self): + # ensures the current version of transformers has loss code that matches our patching code + self.assertTrue( + check_forward_is_patchable(), + "HF transformers LlamaForCausalLM.forward has changed and isn't patchable", + )