Skip to content

Commit

Permalink
ADD: warning hub model (#1301)
Browse files Browse the repository at this point in the history
* update warning for save_strategy

* update

* clean up

* update

* Update test_validation.py

* fix validation step

* update

* test_validation

* update

* fix

* fix

---------

Co-authored-by: NanoCode012 <[email protected]>
  • Loading branch information
JohanWork and NanoCode012 authored Apr 30, 2024
1 parent a56e062 commit b7ecc6a
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 18 deletions.
15 changes: 7 additions & 8 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)

if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
)

if cfg.gptq and cfg.revision_of_model:
Expand Down Expand Up @@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
Expand All @@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)

if (
cfg.evaluation_strategy
and cfg.eval_steps
Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,11 +780,11 @@ def check_saves(cls, data):
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
if data.get("hub_model_id") and (
data.get("save_strategy") not in ["steps", "epoch", None]
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
)
return data

Expand Down
48 changes: 41 additions & 7 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,17 +1067,51 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
):
validate_config(cfg)

def test_hub_model_id_save_value_warns(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)
assert len(self._caplog.records) == 1

def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1

def test_hub_model_id_save_value_steps(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
| minimal_cfg
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_epochs(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
| minimal_cfg
)

def test_hub_model_id_save_value(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
Expand Down

0 comments on commit b7ecc6a

Please sign in to comment.