Skip to content

Commit

Permalink
improve handling of the prepared ds path and other cfg defaults (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Oct 13, 2023
1 parent 490923f commit 1c412c7
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/axolotl/cli/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
Expand Down
13 changes: 13 additions & 0 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path

import fire
import transformers
from colorama import Fore

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -14,8 +16,11 @@
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train

LOG = logging.getLogger("axolotl.cli.train")


def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
Expand All @@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/common/const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Various shared constants
"""

# Fallback value for the `dataset_prepared_path` config option; the train CLI
# assigns it when `--prepare_ds_only` is passed without that option being set.
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
4 changes: 2 additions & 2 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import (
Expand Down Expand Up @@ -44,7 +45,6 @@
)

LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


def md5(to_hash: str, encoding: str = "utf-8") -> str:
Expand Down Expand Up @@ -357,7 +357,7 @@ def for_d_in_datasets(dataset_configs):
if len(datasets) > 1:
LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub:
Expand Down

0 comments on commit 1c412c7

Please sign in to comment.