From 71de1c8fe2e1aec5e51813d377df5426ec1679cf Mon Sep 17 00:00:00 2001 From: Jan Philipp Harries <2862336+jphme@users.noreply.github.com> Date: Wed, 13 Sep 2023 06:16:40 +0200 Subject: [PATCH] Fix pretraining with iterable/streaming Dataset (#556) * return without packing prep/len * fix remove columns * fix encode arguments * add error when max steps not set * fix test --------- Co-authored-by: Jan Philipp Harries --- src/axolotl/utils/config.py | 4 ++++ src/axolotl/utils/data.py | 19 ++++++++++++++----- tests/test_data.py | 2 +- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 6de807eab9..e3febfe31c 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -191,6 +191,10 @@ def validate_config(cfg): LOG.warning( "You probably want to disable group_by_length as it will force a streamed dataset to download completely." ) + if cfg.pretraining_dataset and not cfg.max_steps: + raise ValueError( + "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!" + ) if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( not cfg.optimizer or "adamw" not in cfg.optimizer diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index f322b800b5..f024d19c47 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -3,7 +3,7 @@ import hashlib import logging from pathlib import Path -from typing import Tuple, Union +from typing import Dict, List, Tuple, Union import torch from datasets import ( @@ -74,6 +74,7 @@ def prepare_dataset(cfg, tokenizer): # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") eval_dataset = None + return train_dataset, eval_dataset, cfg.max_steps with zero_first(is_main_process()): train_dataset, eval_dataset = process_datasets_for_packing( @@ -527,9 +528,11 @@ def load_prepare_datasets( return train_dataset, eval_dataset -def encode_pretraining(tokenizer, max_tokens, examples): +def encode_pretraining( + tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str] +) -> Dict[str, List]: res = tokenizer( - examples["text"], + examples, truncation=True, max_length=max_tokens - 2, add_special_tokens=True, @@ -637,6 +640,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): encode = functools.partial(encode_pretraining, tokenizer, max_tokens) dataset = load_dataset(path, streaming=True, split="train") dataset = dataset.shuffle(seed=seed, buffer_size=10_000) - # TODO dynamically figure out which columns/features to remove - dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) + dataset = dataset.map( + encode, + batched=True, + input_columns="text", + remove_columns=[ + "text", + ], + ) return dataset diff --git a/tests/test_data.py b/tests/test_data.py index 9d7f5a0412..16af089a06 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -35,7 +35,7 @@ def test_encode_pretraining(self): "hello, hello", ] } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples) + result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"]) self.assertEqual(len(result["input_ids"]), 3)