Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD: warning hub model #1301

Merged
15 changes: 7 additions & 8 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,9 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)

if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check was meant to ensure that at least one of save_steps/saves_per_epoch/save_strategy was set. With this change, aren't you missing those conditions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the only thing that matters that a save strategy is set? See https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy. As I understand it, to disable saving you have to set it to "no"; if you don't set it (None in axolotl), it defaults to "steps".

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be changed to check for "no" then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be, but currently the tests check both "no" and an arbitrary other value. I wanted the check to be broader, but I can tighten it up if you think that is better, @NanoCode012. Also, sorry for the slow progress :(

LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
)

if cfg.gptq and cfg.model_revision:
Expand Down Expand Up @@ -423,10 +423,14 @@ def validate_config(cfg):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
Expand All @@ -439,11 +443,6 @@ def validate_config(cfg):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)

if (
cfg.evaluation_strategy
and cfg.eval_steps
Expand Down
38 changes: 34 additions & 4 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,17 +689,47 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self):
):
validate_config(cfg)

def test_hub_model_id_save_value_warns(self):
cfg = DictDefault({"hub_model_id": "test"})
def test_hub_model_id_save_value_warns_save_stragey_no(self):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)

def test_hub_model_id_save_value_warns_random_value(self):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "test"})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert (
"set without any models being saved" in self._caplog.records[0].message
)

def test_hub_model_id_save_value(self):
cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
def test_hub_model_id_save_value_steps(self):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "steps"})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_epochs(self):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "epoch"})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_none(self):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None})

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0

def test_hub_model_id_save_value_no_set_save_strategy(self):
cfg = DictDefault({"hub_model_id": "test"})
NanoCode012 marked this conversation as resolved.
Show resolved Hide resolved

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
Expand Down