Skip to content

Commit

Permalink
Add ds model card, rebased (#2101) [skip ci]
Browse files Browse the repository at this point in the history
* rebased add_ds_model_card

* manual rebasing

* fix redundancy

* lint

* include case when ds_tag is none

* conform to kwargs in create_model_card
  • Loading branch information
bursteratom authored Dec 3, 2024
1 parent 822c904 commit ff4794c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
35 changes: 34 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
return kwargs


def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]

if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags

return kwargs


@dataclass
class AxolotlTrainingMixins:
"""
Expand Down Expand Up @@ -418,10 +434,12 @@ def __init__(
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Expand Down Expand Up @@ -919,6 +937,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

return super().push_to_hub(*args, **kwargs)
Expand Down Expand Up @@ -1042,8 +1063,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):

tag_names = ["axolotl", "dpo"]

def __init__(self, *args, **kwargs):
def __init__(self, *args, dataset_tags=None, **kwargs):
    """Initialize the DPO trainer.

    ``dataset_tags`` (a dataset path/name or list of them) is stored so
    ``push_to_hub`` can forward it for model-card tagging.
    """
    super().__init__(*args, **kwargs)
    self.dataset_tags = dataset_tags
    # Reset after base-class init; rebuilt by create_optimizer().
    self.optimizer = None

def create_optimizer(self):
Expand Down Expand Up @@ -1082,6 +1104,9 @@ def push_to_hub(self, *args, **kwargs) -> str:
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)

return super().push_to_hub(*args, **kwargs)
Expand Down Expand Up @@ -1806,6 +1831,10 @@ def build(self, total_num_steps):
else:
trainer_kwargs["tokenizer"] = self.tokenizer

if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
Expand Down Expand Up @@ -2079,6 +2108,10 @@ def build(self, total_num_steps):
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer

if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
Expand Down
28 changes: 24 additions & 4 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,31 @@ def terminate_handler(_, __, model_weakref):
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

try:
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)

model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
model_card_kwarg["dataset_name"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
else:
model_card_kwarg["dataset_tags"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]

trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
Expand Down

0 comments on commit ff4794c

Please sign in to comment.