From 3553172e3cd98beca1fc10406648a2c9af5ac63b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 14 Oct 2023 09:27:07 -0400 Subject: [PATCH] fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728) --- src/axolotl/prompt_strategies/alpaca_chat.py | 12 ++++++++---- src/axolotl/utils/trainer.py | 4 +++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py index 17fe69be7b..975fee889e 100644 --- a/src/axolotl/prompt_strategies/alpaca_chat.py +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -1,6 +1,6 @@ -"""Module containing the AlpacaQAPromptTokenizingStrategy class""" +"""Module for Alpaca prompt strategy classes""" -from typing import Tuple +from typing import Any, Dict, Optional, Tuple from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, @@ -9,9 +9,13 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter -def load(tokenizer, cfg): +def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): + prompt_style = PromptStyle.CHAT.value + if ds_cfg and "conversation" in ds_cfg: + prompt_style = ds_cfg["conversation"] + return AlpacaPromptTokenizingStrategy( - AlpacaPrompter(PromptStyle.CHAT.value), + AlpacaPrompter(prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ee8c634966..820202b80b 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer): ) # Phi doesn't want the attention_mask feature when training - if "CodeGenTokenizer" in tokenizer.__class__.__name__: + if "CodeGenTokenizer" in tokenizer.__class__.__name__ or ( + cfg.is_mistral_derived_model and cfg.flash_attention + ): train_dataset = train_dataset.remove_columns("attention_mask") if eval_dataset: eval_dataset = eval_dataset.remove_columns("attention_mask")