Commit

fix pretraining_ on odd datasets (#1463)
* can configure name of split of pretraining dataset

* streaming data and dataset map

* text column customized

* allow text_column to be set in pretrain

* pretrain type

* load a bit of the dataset

* fix dataset where splits have separate configs

* ok name param here is the config

* whitespace
mapmeld authored Apr 2, 2024
1 parent 86b7d22 commit 586bd8d
Showing 3 changed files with 25 additions and 7 deletions.
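
Taken together, the commit message bullets describe new per-dataset options for streaming pretraining: a dataset config name, a split, and a text column. A minimal sketch of the dict that cfg.pretraining_dataset[0] carries once these changes land (the path, name, and column values below are hypothetical examples, not from the commit):

    # hypothetical pretraining_dataset entry; keys mirror the new
    # PretrainingDataset model fields introduced in this commit
    pretraining_entry = {
        "path": "allenai/c4",      # dataset repo on the Hub (example value)
        "name": "en",              # config name forwarded to load_dataset
        "split": "validation",     # defaults to "train" when omitted
        "text_column": "content",  # defaults to "text" when omitted
        "type": "pretrain",
    }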
6 changes: 4 additions & 2 deletions src/axolotl/prompt_strategies/pretrain.py
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
     def supports_batched(self):
         return True
 
-    def __init__(self, *args, max_length=None, **kwargs):
+    def __init__(self, *args, max_length=None, text_column="text", **kwargs):
         super().__init__(*args, **kwargs)
         if max_length:
             self.max_length = max_length
+        self.text_column = text_column
 
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
@@ -44,7 +45,7 @@ def _tokenize(
         return res
 
     def tokenize_prompt(self, prompt):
-        return self._tokenize(prompt["text"])
+        return self._tokenize(prompt[self.text_column])
 
 
 def load(tokenizer, cfg):
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
+        text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
         max_length=cfg.sequence_len * 64,
     )
     return strat
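
One detail in load() worth calling out: the 'or "text"' fallback means an entry whose text_column key is present but set to None still tokenizes the default column. A standalone sketch of that behavior (the entry below is hypothetical):

    # a config entry that leaves text_column unset (None) falls back to "text"
    entry = {"path": "some/dataset", "text_column": None}
    text_column = entry["text_column"] or "text"
    assert text_column == "text"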
6 changes: 5 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
 class PretrainingDataset(BaseModel):
     """pretraining dataset configuration subset"""
 
+    name: Optional[str] = None
     path: Optional[str] = None
+    split: Optional[str] = "train"
+    text_column: Optional[str] = "text"
+    type: Optional[str] = "pretrain"
 
 
 class UserDefinedPrompterType(BaseModel):
@@ -448,7 +452,7 @@ class Config:
     dataset_shard_idx: Optional[int] = None
 
     pretraining_dataset: Optional[  # type: ignore
-        conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
+        conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
     ] = Field(
         default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
     )
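
The Union reorder in the second hunk is load-bearing: when pydantic resolves a union against a plain dict, a leftmost match can win, and with SFTDataset first a pretraining entry could be validated as an SFTDataset, silently dropping keys like split and text_column. Putting PretrainingDataset first keeps those fields. A simplified sketch of the idea, assuming pydantic v2 and stand-in models rather than the repo's actual classes:

    from typing import List, Optional, Union

    from pydantic import BaseModel

    class PretrainingDataset(BaseModel):
        path: Optional[str] = None
        split: Optional[str] = "train"
        text_column: Optional[str] = "text"

    class SFTDataset(BaseModel):
        path: Optional[str] = None

    class Cfg(BaseModel):
        # broader model listed first, so its extra fields survive validation
        pretraining_dataset: List[Union[PretrainingDataset, SFTDataset]]

    cfg = Cfg(pretraining_dataset=[{"path": "p", "text_column": "content"}])
    print(type(cfg.pretraining_dataset[0]).__name__)  # PretrainingDataset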
20 changes: 16 additions & 4 deletions src/axolotl/utils/data.py
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
         )
     else:
         path = cfg.pretraining_dataset
+        split = "train"
         name = None
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
             cfg.pretraining_dataset[0], dict
         ):
             path = cfg.pretraining_dataset[0]["path"]
             name = cfg.pretraining_dataset[0]["name"]
+            if "split" in cfg.pretraining_dataset[0]:
+                split = cfg.pretraining_dataset[0]["split"]
 
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
         )
 
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            load_dataset(path, streaming=True, split=split, name=name),
             tokenizer,
             cfg,
             ds_wrapper_partial,
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
     else:
         LOG.debug("NOT shuffling merged pretraining datasets")
 
+    # remove all the existing columns after mapping since they end up having
+    # a different length than the encoded/tokenized column
+    # this is empty during streaming/pretraining
+    remove_columns = []
+    if dataset.features is None:
+        for first_row in dataset:
+            remove_columns = first_row.keys()
+            break
+    else:
+        remove_columns = dataset.features.keys()
+
     dataset = dataset.map(
         encode,
         batched=True,
         batch_size=buffer_size,
         # input_columns="text",
-        # remove all the existing columns after mapping since they end up having
-        # a different length than the encoded/tokenized column
-        remove_columns=dataset.features.keys(),
+        remove_columns=remove_columns,
     )
     return dataset
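
The new column-discovery block exists because a streaming dataset often reports features as None until a row has been read, so the old dataset.features.keys() call fails on exactly the datasets this code path serves. A standalone sketch of the same fallback, assuming a local data.jsonl file (hypothetical path):

    from datasets import load_dataset

    # streaming datasets may expose no schema up front; peek at the first row
    # to learn which columns must be dropped after tokenization
    ds = load_dataset("json", data_files="data.jsonl", split="train", streaming=True)

    remove_columns = []
    if ds.features is None:
        for first_row in ds:
            remove_columns = list(first_row.keys())
            break
    else:
        remove_columns = list(ds.features.keys())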
