make sure to register the base chatml template even if no system message is provided (#1207)
winglian authored Jan 25, 2024
1 parent a01b998 commit badda37
Showing 4 changed files with 11 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -106,3 +106,7 @@ jobs:
       - name: GPU Unit Tests monkeypatched w docker image
         run: |
           docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
+      - name: Prune image from docker
+        if: github.ref != 'refs/heads/main'
+        run: |
+          docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
2 changes: 2 additions & 0 deletions src/axolotl/cli/preprocess.py
@@ -40,6 +40,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
             f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
         )
         register_chatml_template(parsed_cfg.default_system_message)
+    else:
+        register_chatml_template()
 
     if not parsed_cfg.dataset_prepared_path:
         msg = (
3 changes: 3 additions & 0 deletions src/axolotl/cli/train.py
@@ -43,7 +43,10 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
             f"ChatML set. Adding default system message: {cfg.default_system_message}"
         )
         register_chatml_template(cfg.default_system_message)
+    else:
+        register_chatml_template()
+
     if cfg.rl:
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
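
Note on the `else` branch added in `preprocess.py` and `train.py`: before this change the template was only registered when a default system message was configured, so a config without one left nothing registered under the chatml name. The following is a minimal, self-contained sketch of that behaviour, not axolotl's implementation. `REGISTRY`, `PLACEHOLDER_SYSTEM_MESSAGE`, `setup_before` and `setup_after` are hypothetical names, and the real `register_chatml_template` body is not shown in this diff.

```python
from typing import Dict, Optional

# Hypothetical in-memory registry standing in for wherever the real
# register_chatml_template stores its template.
REGISTRY: Dict[str, str] = {}

# Placeholder only; the real built-in default is not shown in this diff.
PLACEHOLDER_SYSTEM_MESSAGE = "<built-in default>"


def register_chatml_template(system_message: Optional[str] = None) -> None:
    """Register the chatml template, falling back to a built-in default."""
    REGISTRY["chatml"] = system_message or PLACEHOLDER_SYSTEM_MESSAGE


def setup_before(chat_template: str, default_system_message: Optional[str]) -> None:
    """Old behaviour: registers only when a system message is configured."""
    if chat_template == "chatml" and default_system_message:
        register_chatml_template(default_system_message)


def setup_after(chat_template: str, default_system_message: Optional[str]) -> None:
    """New behaviour: the else branch always registers the base template."""
    if chat_template == "chatml" and default_system_message:
        register_chatml_template(default_system_message)
    else:
        register_chatml_template()


if __name__ == "__main__":
    setup_before("chatml", None)
    print("chatml" in REGISTRY)  # nothing was registered on the old path
    setup_after("chatml", None)
    print("chatml" in REGISTRY)  # base template is now registered
```

As in the diff, the `else` branch also runs when the chat template is not chatml, so the base template ends up registered unconditionally.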
3 changes: 2 additions & 1 deletion src/axolotl/utils/data.py
@@ -16,6 +16,7 @@
     load_from_disk,
 )
 from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase
 
@@ -213,7 +214,7 @@ def for_d_in_datasets(dataset_configs):
                     token=use_auth_token,
                 )
                 ds_from_hub = True
-            except (FileNotFoundError, ConnectionError):
+            except (FileNotFoundError, ConnectionError, HFValidationError):
                 pass
 
             ds_from_cloud = False
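
Note on the widened `except` in `data.py`: `HFValidationError` is the error `huggingface_hub` raises when a string is not a valid repo id, which can happen when a dataset path is a local or cloud path rather than a Hub repository. Catching it lets the Hub probe fail quietly so the other loading paths are still tried. The sketch below illustrates the pattern under stated assumptions: `is_on_hub` and `fake_probe` are hypothetical helpers, and the real Hub call is replaced by a stub so the example runs offline.

```python
from typing import Callable

try:
    from huggingface_hub.utils import HFValidationError
except ImportError:
    # Stand-in so the sketch runs without huggingface_hub installed.
    class HFValidationError(ValueError):
        """Placeholder for huggingface_hub's repo-id validation error."""


def is_on_hub(path: str, probe: Callable[[str], None]) -> bool:
    """Return True if the probe succeeds, False on any 'not on the Hub' error."""
    try:
        probe(path)
        return True
    except (FileNotFoundError, ConnectionError, HFValidationError):
        return False


def fake_probe(path: str) -> None:
    """Hypothetical stub for the Hub lookup; raises like a failed probe might."""
    if path.startswith(("./", "/")):
        raise HFValidationError(f"not a valid repo id: {path}")
    if path == "org/missing":
        raise FileNotFoundError(path)


if __name__ == "__main__":
    for candidate in ("org/dataset", "org/missing", "./data/local.jsonl"):
        print(candidate, is_on_hub(candidate, fake_probe))
```

Without `HFValidationError` in the tuple, the third call would propagate the exception instead of returning `False`.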