Skip to content

Commit

Permalink
ADD: warning if hub_model_id ist set but not any save strategy (#1202)
Browse files Browse the repository at this point in the history
* warning if hub model id set but no save

* add warning

* move the warning

* add test

* allow more public methods for tests for now

* fix tests

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
JohanWork and winglian authored Jan 26, 2024
1 parent 1b18003 commit af29d81
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,11 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)

if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)

if cfg.gptq and cfg.model_revision:
raise ValueError(
"model_revision is not supported for GPTQ models. "
Expand Down
17 changes: 17 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def inject_fixtures(self, caplog):
self._caplog = caplog


# pylint: disable=too-many-public-methods
class ValidationTest(BaseValidation):
"""
Test the validation module
Expand Down Expand Up @@ -698,6 +699,22 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self):
):
validate_config(cfg)

def test_hub_model_id_save_value_warns(self):
cfg = DictDefault({"hub_model_id": "test"})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)

def test_hub_model_id_save_value(self):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0


class ValidationCheckModelConfig(BaseValidation):
"""
Expand Down

0 comments on commit af29d81

Please sign in to comment.