Skip to content

Commit

Permalink
refactor scripts/finetune.py into new cli modules
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 13, 2023
1 parent fdb777b commit f82d21c
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 39 deletions.
40 changes: 1 addition & 39 deletions scripts/finetune.py → src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import fire
import torch
import transformers
import yaml

# add src to the pythonpath so we don't need to pip install this
Expand All @@ -20,7 +18,7 @@

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
Expand Down Expand Up @@ -80,17 +78,6 @@ def do_merge_lora(
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)


def do_inference(
*,
cfg: DictDefault,
Expand Down Expand Up @@ -260,28 +247,3 @@ def check_accelerate_default_config():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)


def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
fire.Fire(do_cli)
26 changes: 26 additions & 0 deletions src/axolotl/cli/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
CLI to run inference on a trained model
"""
from pathlib import Path

import fire
import transformers

from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True

do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


fire.Fire(do_cli)
26 changes: 26 additions & 0 deletions src/axolotl/cli/merge_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path

import fire
import transformers

from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True

do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)


fire.Fire(do_cli)
41 changes: 41 additions & 0 deletions src/axolotl/cli/shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path

import fire
import transformers

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.scripts")


def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True

shard(cfg=parsed_cfg, cli_args=parsed_cli_args)


fire.Fire(do_cli)
30 changes: 30 additions & 0 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
CLI to run training on a model
"""
from pathlib import Path

import fire
import transformers

from axolotl.cli import load_cfg, load_datasets, print_axolotl_text_art, check_accelerate_default_config
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)


fire.Fire(do_cli)

0 comments on commit f82d21c

Please sign in to comment.