From 1c412c7e9dd228209ad63afdcd5dc430a1ef82ab Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 13 Oct 2023 07:46:07 -0400 Subject: [PATCH] improve handling of the prepared ds path and other cfg defaults (#701) --- src/axolotl/cli/inference.py | 1 + src/axolotl/cli/train.py | 13 +++++++++++++ src/axolotl/common/const.py | 5 +++++ src/axolotl/utils/data.py | 4 ++-- 4 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 src/axolotl/common/const.py diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index f3daac83dd..91405d8c66 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, **kwargs) + parsed_cfg.sample_packing = False parser = transformers.HfArgumentParser((TrainerCliArgs)) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index c64755872b..b49cbc6b60 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,10 +1,12 @@ """ CLI to run training on a model """ +import logging from pathlib import Path import fire import transformers +from colorama import Fore from axolotl.cli import ( check_accelerate_default_config, @@ -14,8 +16,11 @@ print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.train import train +LOG = logging.getLogger("axolotl.cli.train") + def do_cli(config: Path = Path("examples/"), **kwargs): # pylint: disable=duplicate-code @@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs): parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True ) + if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path: + msg = ( + Fore.RED + + "--prepare_ds_only called without dataset_prepared_path set." + + Fore.RESET + ) + LOG.warning(msg) + parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) if parsed_cli_args.prepare_ds_only: diff --git a/src/axolotl/common/const.py b/src/axolotl/common/const.py new file mode 100644 index 0000000000..fd34ad4694 --- /dev/null +++ b/src/axolotl/common/const.py @@ -0,0 +1,5 @@ +""" +Various shared constants +""" + +DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index bac7d96c9e..c944dd27ba 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -16,6 +16,7 @@ from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase +from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.prompt_strategies import load from axolotl.prompt_tokenizers import ( @@ -44,7 +45,6 @@ ) LOG = logging.getLogger("axolotl") -DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" def md5(to_hash: str, encoding: str = "utf-8") -> str: @@ -357,7 +357,7 @@ def for_d_in_datasets(dataset_configs): if len(datasets) > 1: LOG.info("shuffle merged datasets") dataset = dataset.shuffle(seed=seed) - if cfg.local_rank == 0 and cfg.dataset_prepared_path: + if cfg.local_rank == 0: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(prepared_ds_path) if cfg.push_dataset_to_hub: