Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD: warning hub model #1301

Merged
15 changes: 7 additions & 8 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ def legacy_validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)

if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
)

if cfg.gptq and cfg.revision_of_model:
Expand Down Expand Up @@ -448,10 +448,14 @@ def legacy_validate_config(cfg):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
Expand All @@ -464,11 +468,6 @@ def legacy_validate_config(cfg):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)

if (
cfg.evaluation_strategy
and cfg.eval_steps
Expand Down
6 changes: 3 additions & 3 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,11 +780,11 @@ def check_saves(cls, data):
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
if data.get("hub_model_id") and (
data.get("save_strategy") not in ["steps", "epoch", None]
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
"hub_model_id is set without any models being saved. To save a model, set save_strategy."
)
return data

Expand Down
48 changes: 41 additions & 7 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,17 +1067,51 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
):
validate_config(cfg)

def test_hub_model_id_save_value_warns(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)
assert len(self._caplog.records) == 1

def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 1

def test_hub_model_id_save_value_steps(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "steps"})
| minimal_cfg
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_epochs(self, minimal_cfg):
cfg = (
DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})
| minimal_cfg
)

def test_hub_model_id_save_value(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
Expand Down
Loading