-
-
Notifications
You must be signed in to change notification settings - Fork 899
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
prepare plugins needs to happen so registration can occur to build th…
…e plugin args (#2119) * prepare plugins needs to happen so registration can occur to build the plugin args use yaml.dump include dataset and more assertions * attempt to manually register plugins rather than use fn * fix fixture * remove fixture * move cli test to patched dir * fix cce validation
- Loading branch information
Showing
4 changed files
with
53 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
test cases to make sure the plugin args are loaded from the config file | ||
""" | ||
from pathlib import Path | ||
|
||
import yaml | ||
|
||
from axolotl.cli import load_cfg | ||
from axolotl.utils.dict import DictDefault | ||
|
||
|
||
# pylint: disable=duplicate-code | ||
class TestPluginArgs: | ||
""" | ||
test class for plugin args loaded from the config file | ||
""" | ||
|
||
def test_liger_plugin_args(self, temp_dir): | ||
test_cfg = DictDefault( | ||
{ | ||
"base_model": "HuggingFaceTB/SmolLM2-135M", | ||
"learning_rate": 0.000001, | ||
"micro_batch_size": 1, | ||
"gradient_accumulation_steps": 1, | ||
"datasets": [ | ||
{ | ||
"path": "mhenrichsen/alpaca_2k_test", | ||
"type": "alpaca", | ||
}, | ||
], | ||
"plugins": ["axolotl.integrations.liger.LigerPlugin"], | ||
"liger_layer_norm": True, | ||
"liger_rope": True, | ||
"liger_rms_norm": False, | ||
"liger_glu_activation": True, | ||
"liger_fused_linear_cross_entropy": True, | ||
} | ||
) | ||
|
||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: | ||
fout.write(yaml.dump(test_cfg.to_dict())) | ||
cfg = load_cfg(str(Path(temp_dir) / "config.yaml")) | ||
assert cfg.liger_layer_norm is True | ||
assert cfg.liger_rope is True | ||
assert cfg.liger_rms_norm is False | ||
assert cfg.liger_glu_activation is True | ||
assert cfg.liger_fused_linear_cross_entropy is True |