Skip to content

Commit

Permalink
Better error
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Sep 6, 2023
1 parent e0c0a0c commit ed85c06
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/accelerate/commands/config/config_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,11 @@ def from_json_file(cls, json_file=None):
config_dict["use_cpu"] = False
if "debug" not in config_dict:
config_dict["debug"] = False
extra_keys = set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
if len(extra_keys) > 0:
raise ValueError(
f"Unknown keys in the config file: {list(extra_keys)}, please try upgrading your `accelerate` version or remove them."
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
" version or fix (and potentially remove) these keys from your config file."
)

return cls(**config_dict)
Expand Down Expand Up @@ -143,10 +144,11 @@ def from_yaml_file(cls, yaml_file=None):
config_dict["use_cpu"] = False
if "debug" not in config_dict:
config_dict["debug"] = False
extra_keys = set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())
extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
if len(extra_keys) > 0:
raise ValueError(
f"Unknown keys in the config file: {list(extra_keys)}, please try upgrading your `accelerate` version or remove them."
f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`"
" version or fix (and potentially remove) these keys from your config file."
)
return cls(**config_dict)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def test_config_compatibility(self):

def test_invalid_keys(self):
with self.assertRaises(
RuntimeError, msg="Unknown keys in the config file: ['another_invalid_key', 'invalid_key']"
RuntimeError,
msg="The config file at 'invalid_keys.yaml' had unknown keys ('another_invalid_key', 'invalid_key')",
):
execute_subprocess_async(
self.base_cmd
Expand Down

0 comments on commit ed85c06

Please sign in to comment.