From f82d21c442d607dad1c1f6c275062c73a0ff316c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 11 Sep 2023 09:42:00 -0400
Subject: [PATCH] refactor scripts/finetune.py into new cli modules

---
 .../axolotl/cli/__init__.py   | 40 +-----------------
 src/axolotl/cli/inference.py  | 26 ++++++++++++
 src/axolotl/cli/merge_lora.py | 26 ++++++++++++
 src/axolotl/cli/shard.py      | 41 +++++++++++++++++++
 src/axolotl/cli/train.py      | 30 ++++++++++++++
 5 files changed, 124 insertions(+), 39 deletions(-)
 rename scripts/finetune.py => src/axolotl/cli/__init__.py (85%)
 create mode 100644 src/axolotl/cli/inference.py
 create mode 100644 src/axolotl/cli/merge_lora.py
 create mode 100644 src/axolotl/cli/shard.py
 create mode 100644 src/axolotl/cli/train.py

diff --git a/scripts/finetune.py b/src/axolotl/cli/__init__.py
similarity index 85%
rename from scripts/finetune.py
rename to src/axolotl/cli/__init__.py
index c149ad073b..ff8eb3b910 100644
--- a/scripts/finetune.py
+++ b/src/axolotl/cli/__init__.py
@@ -8,9 +8,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
-import fire
 import torch
-import transformers
 import yaml
 
 # add src to the pythonpath so we don't need to pip install this
@@ -20,7 +18,7 @@
 
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
-from axolotl.train import TrainDatasetMeta, train
+from axolotl.train import TrainDatasetMeta
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
@@ -80,17 +78,6 @@ def do_merge_lora(
     tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
 
 
-def shard(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-):
-    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-    safe_serialization = cfg.save_safetensors is True
-    LOG.debug("Re-saving model w/ sharding")
-    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
-
 def do_inference(
     *,
     cfg: DictDefault,
@@ -260,28 +247,3 @@ def check_accelerate_default_config():
         LOG.warning(
             f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
         )
-
-
-def do_cli(config: Path = Path("examples/"), **kwargs):
-    print_axolotl_text_art()
-    parsed_cfg = load_cfg(config, **kwargs)
-    check_accelerate_default_config()
-    parser = transformers.HfArgumentParser((TrainerCliArgs))
-    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
-        return_remaining_strings=True
-    )
-    if parsed_cli_args.inference:
-        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    elif parsed_cli_args.merge_lora:
-        do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    elif parsed_cli_args.shard:
-        shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    else:
-        dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-        if parsed_cli_args.prepare_ds_only:
-            return
-        train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
-
-
-if __name__ == "__main__":
-    fire.Fire(do_cli)
diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
new file mode 100644
index 0000000000..1a5a1a2be4
--- /dev/null
+++ b/src/axolotl/cli/inference.py
@@ -0,0 +1,26 @@
+"""
+CLI to run inference on a trained model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.inference = True
+
+    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
new file mode 100644
index 0000000000..473aa8260c
--- /dev/null
+++ b/src/axolotl/cli/merge_lora.py
@@ -0,0 +1,26 @@
+"""
+CLI to merge a trained LoRA into a base model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.merge_lora = True
+
+    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py
new file mode 100644
index 0000000000..ad7d9a1368
--- /dev/null
+++ b/src/axolotl/cli/shard.py
@@ -0,0 +1,41 @@
+"""
+CLI to shard a trained model into 10GiB chunks
+"""
+import logging
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.scripts")
+
+
+def shard(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    safe_serialization = cfg.save_safetensors is True
+    LOG.debug("Re-saving model w/ sharding")
+    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    parsed_cli_args.shard = True
+
+    shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+fire.Fire(do_cli)
diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py
new file mode 100644
index 0000000000..3b43c963be
--- /dev/null
+++ b/src/axolotl/cli/train.py
@@ -0,0 +1,30 @@
+"""
+CLI to run training on a model
+"""
+from pathlib import Path
+
+import fire
+import transformers
+
+from axolotl.cli import check_accelerate_default_config, load_cfg, load_datasets, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+
+
+def do_cli(config: Path = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    print_axolotl_text_art()
+    check_accelerate_default_config()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+
+    dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if parsed_cli_args.prepare_ds_only:
+        return
+    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+
+
+fire.Fire(do_cli)
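
A usage sketch, not part of the patch above: each task now has its own module
entry point, and the boolean flags of the old scripts/finetune.py map to
dedicated commands. The config path below is a hypothetical placeholder, and
the `accelerate launch -m` form assumes an accelerate version that supports
launching modules:

    # train (replaces `python scripts/finetune.py <config>` with no flags)
    accelerate launch -m axolotl.cli.train path/to/config.yml

    # inference (replaces the `--inference` flag)
    python -m axolotl.cli.inference path/to/config.yml

    # merge a trained LoRA into the base model (replaces `--merge_lora`)
    python -m axolotl.cli.merge_lora path/to/config.yml

    # re-save the model with sharding (replaces `--shard`)
    python -m axolotl.cli.shard path/to/config.yml

Note that fire.Fire(do_cli) runs at module level rather than under an
`if __name__ == "__main__":` guard, so these modules execute their CLI as soon
as they are imported; they are intended to be run via `python -m` (or
`accelerate launch -m`), not imported from other code.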