From d87df2c776c03d2486402656327648ad7132ecf2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 3 Dec 2024 15:06:09 -0500
Subject: [PATCH] prepare plugins needs to happen so registration can occur
 to build the plugin args (#2119)

* prepare plugins needs to happen so registration can occur to build the
  plugin args

* use yaml.dump

* include dataset and more assertions

* attempt to manually register plugins rather than use fn

* fix fixture

* remove fixture

* move cli test to patched dir

* fix cce validation
---
 scripts/cutcrossentropy_install.py          |  6 +--
 src/axolotl/cli/__init__.py                 |  4 +-
 .../integrations/cut_cross_entropy/args.py  |  2 +-
 tests/e2e/patched/test_cli_integrations.py  | 47 +++++++++++++++++++
 4 files changed, 53 insertions(+), 6 deletions(-)
 create mode 100644 tests/e2e/patched/test_cli_integrations.py

diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index 3816e58143..d51e2dd99c 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -16,11 +16,11 @@
     sys.exit(0)
 
 cce_spec = importlib.util.find_spec("cut_cross_entropy")
-cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
 
 UNINSTALL_PREFIX = ""
-if cce_spec and not cce_spec_transformers:
-    UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
+if cce_spec:
+    if not importlib.util.find_spec("cut_cross_entropy.transformers"):
+        UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
 
 print(
     UNINSTALL_PREFIX
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 4572cfa5c8..1e61b220b9 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
     except:  # pylint: disable=bare-except  # noqa: E722
         gpu_version = None
 
+    prepare_plugins(cfg)
+
     cfg = validate_config(
         cfg,
         capabilities={
@@ -444,8 +446,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
         },
     )
 
-    prepare_plugins(cfg)
-
     prepare_optim_env(cfg)
 
     prepare_opinionated_env(cfg)
diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py
index 9a364e2d3e..c16d91ede2 100644
--- a/src/axolotl/integrations/cut_cross_entropy/args.py
+++ b/src/axolotl/integrations/cut_cross_entropy/args.py
@@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_dtype_is_half(cls, data):
-        if not (data.get("bf16") or data.get("fp16")):
+        if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
             raise ValueError(
                 "Cut Cross Entropy requires fp16/bf16 training for backward pass. "
                 "Please set `bf16` or `fp16` to `True`."
diff --git a/tests/e2e/patched/test_cli_integrations.py b/tests/e2e/patched/test_cli_integrations.py
new file mode 100644
index 0000000000..6ca7c52aea
--- /dev/null
+++ b/tests/e2e/patched/test_cli_integrations.py
@@ -0,0 +1,47 @@
+"""
+test cases to make sure the plugin args are loaded from the config file
+"""
+from pathlib import Path
+
+import yaml
+
+from axolotl.cli import load_cfg
+from axolotl.utils.dict import DictDefault
+
+
+# pylint: disable=duplicate-code
+class TestPluginArgs:
+    """
+    test class for plugin args loaded from the config file
+    """
+
+    def test_liger_plugin_args(self, temp_dir):
+        test_cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "learning_rate": 0.000001,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "plugins": ["axolotl.integrations.liger.LigerPlugin"],
+                "liger_layer_norm": True,
+                "liger_rope": True,
+                "liger_rms_norm": False,
+                "liger_glu_activation": True,
+                "liger_fused_linear_cross_entropy": True,
+            }
+        )
+
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(test_cfg.to_dict()))
+        cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
+        assert cfg.liger_layer_norm is True
+        assert cfg.liger_rope is True
+        assert cfg.liger_rms_norm is False
+        assert cfg.liger_glu_activation is True
+        assert cfg.liger_fused_linear_cross_entropy is True
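
Why the reordering in src/axolotl/cli/__init__.py matters: prepare_plugins()
is what registers each plugin's args model, and validate_config() builds its
schema from whatever is registered at call time, so plugin keys such as
`liger_rope` only survive validation if registration runs first. A minimal
sketch of that dependency follows; BaseArgs, LigerArgs, and `registered` are
hypothetical stand-ins for illustration, not axolotl's actual internals:

    # Sketch: plugin registration must precede config validation.
    from pydantic import BaseModel, ConfigDict, create_model

    class BaseArgs(BaseModel):
        # extra="forbid" means unknown keys fail validation
        model_config = ConfigDict(extra="forbid")
        base_model: str
        plugins: list = []

    class LigerArgs(BaseModel):
        liger_rope: bool = False

    registered = [BaseArgs]

    def prepare_plugins(cfg):
        # each plugin named in the config registers its args model
        if any("liger" in p for p in cfg.get("plugins", [])):
            registered.append(LigerArgs)

    def validate_config(cfg):
        # the merged schema contains only the models registered so far
        merged = create_model("MergedArgs", __base__=tuple(registered))
        return merged(**cfg)

    cfg = {"base_model": "m", "plugins": ["liger"], "liger_rope": True}
    prepare_plugins(cfg)  # run first: registers LigerArgs
    validate_config(cfg)  # passes; with the calls swapped, `liger_rope`
                          # would be rejected as an extra field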