From d66b10141efba1421f055b5fc12574aa4c961d60 Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 13 Jan 2024 10:13:35 +0100 Subject: [PATCH] Disable caching on `--disable_caching` in CLI (#1110) * Disable caching on `--disable_caching` in CLI * chore: lint --------- Co-authored-by: Wing Lian --- src/axolotl/cli/preprocess.py | 9 ++++++++- src/axolotl/cli/train.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e0eeea6b34..21436bf413 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -7,6 +7,7 @@ import fire import transformers from colorama import Fore +from datasets import disable_caching from axolotl.cli import ( check_accelerate_default_config, @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): check_accelerate_default_config() check_user_token() parser = transformers.HfArgumentParser((PreprocessCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( + parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses( return_remaining_strings=True ) + + if ( + remaining_args.get("disable_caching") is not None + and remaining_args["disable_caching"] + ): + disable_caching() if not parsed_cfg.dataset_prepared_path: msg = ( Fore.RED diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 2248784dff..6b5f496862 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -6,6 +6,7 @@ import fire import transformers +from datasets import disable_caching from axolotl.cli import ( check_accelerate_default_config, @@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs): check_accelerate_default_config() check_user_token() parser = transformers.HfArgumentParser((TrainerCliArgs)) - parsed_cli_args, _ = parser.parse_args_into_dataclasses( + parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses( return_remaining_strings=True ) + + if ( + remaining_args.get("disable_caching") is not None + and remaining_args["disable_caching"] + ): + disable_caching() if parsed_cfg.rl: dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) else: