Skip to content

Commit

Permalink
Add AxolotlDPOTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
filippo82 committed Jan 25, 2024
1 parent ba944e6 commit 3a1cc2b
Showing 1 changed file with 41 additions and 17 deletions.
58 changes: 41 additions & 17 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@
LOG = logging.getLogger("axolotl.core.trainer_builder")


def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs


@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Expand Down Expand Up @@ -349,28 +365,13 @@ def compute_loss(self, model, inputs, return_outputs=False):
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs

@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
kwargs = _sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)

Expand Down Expand Up @@ -471,6 +472,29 @@ def create_scheduler(
return self.lr_scheduler


class AxolotlDPOTrainer(DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "dpo"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)

return super().push_to_hub(*args, **kwargs)


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
Expand Down Expand Up @@ -1076,7 +1100,7 @@ def build(self, total_num_steps):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer = DPOTrainer(
dpo_trainer = AxolotlDPOTrainer(
self.model,
self.model_ref,
args=training_args,
Expand Down

0 comments on commit 3a1cc2b

Please sign in to comment.