From a10a2ed48f9cdc4f76fe6a8580fe4dce1e67a036 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 Jan 2024 10:37:33 -0500 Subject: [PATCH] make sure to register the base chatml template even if no system message is provided --- .github/workflows/tests.yml | 4 ++++ src/axolotl/cli/preprocess.py | 2 ++ src/axolotl/cli/train.py | 3 +++ src/axolotl/utils/data.py | 3 ++- 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f3e1414a0f..d0cde0cd62 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -106,3 +106,7 @@ jobs: - name: GPU Unit Tests monkeypatched w docker image run: | docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/ + - name: Prune image from docker + if: github.ref != 'refs/heads/main' + run: | + docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 8ea68575db..e7bc612b7e 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -40,6 +40,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs): f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}" ) register_chatml_template(parsed_cfg.default_system_message) + else: + register_chatml_template() if not parsed_cfg.dataset_prepared_path: msg = ( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index e18f45c338..7ab06422fa 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -43,7 +43,10 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: f"ChatML set. Adding default system message: {cfg.default_system_message}" ) register_chatml_template(cfg.default_system_message) + else: + register_chatml_template() + if cfg.rl: dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args) else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 6dd1ec5602..d9a590bc3c 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -16,6 +16,7 @@ load_from_disk, ) from huggingface_hub import hf_hub_download +from huggingface_hub.utils import HFValidationError from torch.utils.data import RandomSampler from transformers import PreTrainedTokenizerBase @@ -213,7 +214,7 @@ def for_d_in_datasets(dataset_configs): token=use_auth_token, ) ds_from_hub = True - except (FileNotFoundError, ConnectionError): + except (FileNotFoundError, ConnectionError, HFValidationError): pass ds_from_cloud = False