Skip to content

Commit

Permalink
FEAT: add tagging support to axolotl for DPOTrainer (#1209)
Browse files Browse the repository at this point in the history
* Add AxolotlDPOTrainer

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
filippo82 and winglian authored Jan 27, 2024
1 parent 0a76c49 commit fe294e5
Showing 1 changed file with 36 additions and 19 deletions.
55 changes: 36 additions & 19 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@
LOG = logging.getLogger("axolotl.core.trainer_builder")


def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs


@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Expand Down Expand Up @@ -349,30 +365,13 @@ def compute_loss(self, model, inputs, return_outputs=False):
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names

return kwargs

@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = self._sanitize_kwargs_for_tagging(
tag_names=self.tag_names, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

return super().push_to_hub(*args, **kwargs)

Expand Down Expand Up @@ -471,6 +470,24 @@ def create_scheduler(
return self.lr_scheduler


class AxolotlDPOTrainer(DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""

tag_names = ["axolotl", "dpo"]

@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

return super().push_to_hub(*args, **kwargs)


class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
Expand Down Expand Up @@ -1076,7 +1093,7 @@ def build(self, total_num_steps):
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer = DPOTrainer(
dpo_trainer = AxolotlDPOTrainer(
self.model,
self.model_ref,
args=training_args,
Expand Down

0 comments on commit fe294e5

Please sign in to comment.