diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index cc162d210a..c74114a176 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -9,7 +9,7 @@ import sys from abc import abstractmethod from dataclasses import dataclass, field -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Optional @@ -120,6 +120,7 @@ class AxolotlTrainer(Trainer): """ args = None # type: AxolotlTrainingArguments + tag_names = ["axolotl"] def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs): self.num_epochs = num_epochs @@ -290,12 +291,41 @@ def compute_loss(self, model, inputs, return_outputs=False): # return (loss, outputs) if return_outputs else loss return super().compute_loss(model, inputs, return_outputs=return_outputs) + def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = self._sanitize_kwargs_for_tagging( + tag_names=self.tag_names, kwargs=kwargs + ) + + return super().push_to_hub(*args, **kwargs) + class AxolotlMambaTrainer(AxolotlTrainer): """ Mamba specific trainer to handle loss calculation """ + tag_names = ["axolotl", "mamba"] + def compute_loss( self, model, @@ -322,6 +352,8 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "onecycle"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None @@ -351,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer): Trainer subclass that uses the OneCycleLR scheduler """ + tag_names = ["axolotl", "relora"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.lr_scheduler = None