Skip to content

Commit

Permalink
prepare plugins needs to happen so registration can occur to build the plugin args (#2119)
Browse files Browse the repository at this point in the history

* prepare plugins needs to happen so registration can occur to build the plugin args

use yaml.dump

include dataset and more assertions

* attempt to manually register plugins rather than use fn

* fix fixture

* remove fixture

* move cli test to patched dir

* fix cce validation
  • Loading branch information
winglian authored Dec 3, 2024
1 parent 1ef7031 commit d87df2c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
6 changes: 3 additions & 3 deletions scripts/cutcrossentropy_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
sys.exit(0)

cce_spec = importlib.util.find_spec("cut_cross_entropy")
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")

UNINSTALL_PREFIX = ""
if cce_spec and not cce_spec_transformers:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
if cce_spec:
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "

print(
UNINSTALL_PREFIX
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None

prepare_plugins(cfg)

cfg = validate_config(
cfg,
capabilities={
Expand All @@ -444,8 +446,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
},
)

prepare_plugins(cfg)

prepare_optim_env(cfg)

prepare_opinionated_env(cfg)
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/integrations/cut_cross_entropy/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CutCrossEntropyArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_dtype_is_half(cls, data):
if not (data.get("bf16") or data.get("fp16")):
if data.get("cut_cross_entropy") and not (data.get("bf16") or data.get("fp16")):
raise ValueError(
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
"Please set `bf16` or `fp16` to `True`."
Expand Down
47 changes: 47 additions & 0 deletions tests/e2e/patched/test_cli_integrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
test cases to make sure the plugin args are loaded from the config file
"""
from pathlib import Path

import yaml

from axolotl.cli import load_cfg
from axolotl.utils.dict import DictDefault


# pylint: disable=duplicate-code
class TestPluginArgs:
    """
    test class for plugin args loaded from the config file
    """

    def test_liger_plugin_args(self, temp_dir):
        """Round-trip a config through YAML and check that the liger plugin
        args registered by the plugin are present on the loaded cfg."""
        # expected plugin-arg values, also used to build the config below
        liger_flags = {
            "liger_layer_norm": True,
            "liger_rope": True,
            "liger_rms_norm": False,
            "liger_glu_activation": True,
            "liger_fused_linear_cross_entropy": True,
        }
        test_cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "learning_rate": 0.000001,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "plugins": ["axolotl.integrations.liger.LigerPlugin"],
                **liger_flags,
            }
        )

        config_path = Path(temp_dir) / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(test_cfg.to_dict()))

        cfg = load_cfg(str(config_path))
        # every liger flag must survive load_cfg with its exact value
        for key, expected in liger_flags.items():
            assert cfg[key] is expected

0 comments on commit d87df2c

Please sign in to comment.