ADD: warning if hub_model_id is set but not any save strategy #1202

Merged
merged 6 commits into from
Jan 26, 2024
5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
@@ -339,6 +339,11 @@ def validate_config(cfg):
            "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
        )

    if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
Collaborator

@NanoCode012 NanoCode012 Feb 17, 2024

Should this also check for cfg.save_strategy and update the log message?

Contributor Author

Yeah, you're right @NanoCode012. The test should check for save_strategy, and the log message should be something like:

"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps or leave empty."

Does that sound right? If so, I can create a PR and update it.

Collaborator

Yep. Please feel free to tag me for review.

Contributor Author

I am a bit confused: both save_on_epochs and save_steps need save_strategy to be steps, see:
https://github.com/OpenAccess-AI-Collective/axolotl/blob/5a5d47458d9aaf7ead798d15291ba3d9bef785c5/src/axolotl/utils/config.py#L426
https://github.com/OpenAccess-AI-Collective/axolotl/blob/5a5d47458d9aaf7ead798d15291ba3d9bef785c5/src/axolotl/utils/config.py#L442

However, https://github.com/OpenAccess-AI-Collective/axolotl/blob/5a5d47458d9aaf7ead798d15291ba3d9bef785c5/README.md?plain=1#L777 says nothing about that; it only says to set it to no to skip. Should we maybe also update the comment to something like "set to 'steps' to save checkpoints"?

Collaborator

save_strategy and save_steps are default HF Trainer arguments. save_on_epochs is just a handy utility by axolotl that auto-computes save_steps based on total steps, hence needing steps.

This is the doc for save_strategy: https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/trainer#transformers.TrainingArguments.save_strategy
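
As a rough illustration of the comment above: HF TrainingArguments interprets a save_steps value below 1 as a fraction of the total training steps, so a saves-per-epoch setting can be expressed as such a fraction. The helper name saves_per_epoch_to_save_steps and its num_epochs parameter are assumptions made for this sketch; it is not taken from axolotl's source.

```python
def saves_per_epoch_to_save_steps(saves_per_epoch: int, num_epochs: int) -> float:
    """Hypothetical helper: turn "N saves per epoch" into a save_steps ratio.

    The result is meant to be used together with save_strategy="steps",
    which is why the per-epoch option depends on the steps strategy.
    """
    if saves_per_epoch <= 0 or num_epochs <= 0:
        raise ValueError("saves_per_epoch and num_epochs must be positive")
    # Total number of checkpoints over the whole run.
    total_saves = saves_per_epoch * num_epochs
    # Fraction of total training steps between two checkpoints.
    return 1.0 / total_saves


if __name__ == "__main__":
    # 4 saves per epoch over 1 epoch: checkpoint every quarter of the run.
    print(saves_per_epoch_to_save_steps(4, 1))
```

Because the value is a ratio, it only takes effect when checkpoints are triggered by steps, which matches the explanation that the option needs save_strategy to be steps.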

        LOG.warning(
            "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
        )

    if cfg.gptq and cfg.model_revision:
        raise ValueError(
            "model_revision is not supported for GPTQ models. "
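
The review thread on this hunk proposes also taking save_strategy into account. The following is a minimal sketch of one possible reading of that proposal, not the merged code and not necessarily what any follow-up PR did. It uses a plain dict instead of axolotl's config object, and the function name warn_if_hub_model_never_saved and the logger name are hypothetical. Unlike the merged check, it treats an unset save_strategy as "saving enabled", following the wording suggested in the thread.

```python
import logging

LOG = logging.getLogger("hub_model_id_check_sketch")  # hypothetical logger name

# save_strategy values under which the HF Trainer writes checkpoints.
SAVING_STRATEGIES = {"steps", "epoch"}


def warn_if_hub_model_never_saved(cfg: dict) -> bool:
    """Hypothetical variant of the check that looks at save_strategy.

    Returns True if a warning was logged, False otherwise.
    """
    if not cfg.get("hub_model_id"):
        return False

    save_strategy = cfg.get("save_strategy")
    # Assumption from the thread: leaving save_strategy empty still saves.
    if save_strategy is None or save_strategy in SAVING_STRATEGIES:
        return False

    LOG.warning(
        "hub_model_id is set without any models being saved. "
        "To save a model, set save_strategy to steps or leave empty."
    )
    return True


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    warn_if_hub_model_never_saved({"hub_model_id": "test", "save_strategy": "no"})
```

Returning a bool keeps the sketch easy to assert on without a log-capturing fixture; the real check in this PR only logs.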
17 changes: 17 additions & 0 deletions tests/test_validation.py
@@ -26,6 +26,7 @@ def inject_fixtures(self, caplog):
        self._caplog = caplog


# pylint: disable=too-many-public-methods
class ValidationTest(BaseValidation):
    """
    Test the validation module
@@ -698,6 +699,22 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self):
        ):
            validate_config(cfg)

    def test_hub_model_id_save_value_warns(self):
        cfg = DictDefault({"hub_model_id": "test"})

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert (
                "set without any models being saved" in self._caplog.records[0].message
            )

    def test_hub_model_id_save_value(self):
        cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)
            assert len(self._caplog.records) == 0


class ValidationCheckModelConfig(BaseValidation):
    """