From 586bd8d2213a1dd9bee4afb299902d37030d4562 Mon Sep 17 00:00:00 2001
From: Nick Doiron
Date: Mon, 1 Apr 2024 23:48:59 -0400
Subject: [PATCH] fix pretraining_ on odd datasets (#1463)

* can configure name of split of pretraining dataset
* streaming data and dataset map
* text column customized
* allow text_column to be set in pretrain
* pretrain type
* load a bit of the dataset
* fix dataset where splits have separate configs
* ok name param here is the config
* whitespace
---
 src/axolotl/prompt_strategies/pretrain.py      |  6 ++++--
 .../config/models/input/v0_4_1/__init__.py     |  6 +++++-
 src/axolotl/utils/data.py                      | 20 +++++++++++++++----
 3 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/prompt_strategies/pretrain.py b/src/axolotl/prompt_strategies/pretrain.py
index 913da3b34a..8430a7fcab 100644
--- a/src/axolotl/prompt_strategies/pretrain.py
+++ b/src/axolotl/prompt_strategies/pretrain.py
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
     def supports_batched(self):
         return True
 
-    def __init__(self, *args, max_length=None, **kwargs):
+    def __init__(self, *args, max_length=None, text_column="text", **kwargs):
         super().__init__(*args, **kwargs)
         if max_length:
             self.max_length = max_length
+        self.text_column = text_column
 
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
@@ -44,7 +45,7 @@ def _tokenize(
         return res
 
     def tokenize_prompt(self, prompt):
-        return self._tokenize(prompt["text"])
+        return self._tokenize(prompt[self.text_column])
 
 
 def load(tokenizer, cfg):
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
+        text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
         max_length=cfg.sequence_len * 64,
     )
     return strat
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 5a927602ff..2850debd0c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
 class PretrainingDataset(BaseModel):
     """pretraining dataset configuration subset"""
 
+    name: Optional[str] = None
     path: Optional[str] = None
+    split: Optional[str] = "train"
+    text_column: Optional[str] = "text"
+    type: Optional[str] = "pretrain"
 
 
 class UserDefinedPrompterType(BaseModel):
@@ -448,7 +452,7 @@ class Config:
     dataset_shard_idx: Optional[int] = None
 
     pretraining_dataset: Optional[  # type: ignore
-        conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
+        conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
     ] = Field(
         default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
     )
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index e9e0f4fa69..6cc27fbdbd 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
         )
     else:
         path = cfg.pretraining_dataset
+        split = "train"
         name = None
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
             cfg.pretraining_dataset[0], dict
         ):
             path = cfg.pretraining_dataset[0]["path"]
             name = cfg.pretraining_dataset[0]["name"]
+            if "split" in cfg.pretraining_dataset[0]:
+                split = cfg.pretraining_dataset[0]["split"]
 
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
         )
 
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            load_dataset(path, streaming=True, split=split, name=name),
             tokenizer,
             cfg,
             ds_wrapper_partial,
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
     else:
         LOG.debug("NOT shuffling merged pretraining datasets")
 
+    # remove all the existing columns after mapping since they end up having
+    # a different length than the encoded/tokenized column
+    # this is empty during streaming/pretraining
+    remove_columns = []
+    if dataset.features is None:
+        for first_row in dataset:
+            remove_columns = first_row.keys()
+            break
+    else:
+        remove_columns = dataset.features.keys()
+
     dataset = dataset.map(
         encode,
         batched=True,
         batch_size=buffer_size,
         # input_columns="text",
-        # remove all the existing columns after mapping since they end up having
-        # a different length than the encoded/tokenized column
-        remove_columns=dataset.features.keys(),
+        remove_columns=remove_columns,
     )
 
     return dataset
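
Usage sketch, not part of the patch above: assuming a config entry that sets the new PretrainingDataset fields, the snippet below illustrates how those fields would flow into the streaming load_dataset call in prepare_dataset(). The dataset path "allenai/c4" and config name "en" are placeholder values, not taken from this change.

    # Sketch only: maps the new pretraining_dataset fields onto the streaming
    # load_dataset call. Path and config name below are placeholders.
    from datasets import load_dataset

    pretraining_dataset = [
        {
            "path": "allenai/c4",    # HF dataset repo (placeholder)
            "name": "en",            # dataset config; passed as `name` to load_dataset
            "split": "train",        # split to stream; falls back to "train" when omitted
            "text_column": "text",   # column read by PretrainTokenizationStrategy
            "type": "pretrain",
        }
    ]

    entry = pretraining_dataset[0]
    split = entry.get("split", "train")

    dataset = load_dataset(
        entry["path"],
        name=entry["name"],
        split=split,
        streaming=True,
    )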