Add finetuning with streaming dataset example #945

Merged · 27 commits · Feb 8, 2024
2 changes: 2 additions & 0 deletions .github/workflows/code-quality.yaml
@@ -4,10 +4,12 @@ on:
branches:
- main
- release/**
- add_finetuning_streaming_dataset_conversion
pull_request:
branches:
- main
- release/**
- add_finetuning_streaming_dataset_conversion
workflow_call:
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/collator.py
@@ -106,7 +106,7 @@ def __init__(

def __call__(self, examples: List[Dict[str,
Any]]) -> Dict[str, torch.Tensor]:
for check_key in ['input_ids', 'labels', 'attention_mask']:
for check_key in ['input_ids', 'labels']:
if check_key not in examples[0]:
raise KeyError(
f'Examples returned by dataset do not include required key: {check_key}'
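
Since the new pre-tokenized conversion path (see `convert_finetuning_dataset.py` below) serializes only `input_ids` and `labels`, the collator no longer requires `attention_mask`. A minimal sketch of the relaxed check on a hypothetical pre-tokenized example:

```python
# Hypothetical pre-tokenized example: no 'attention_mask' field.
examples = [{'input_ids': [101, 2023, 102], 'labels': [102, 2003]}]
for check_key in ['input_ids', 'labels']:  # the relaxed required-key set
    if check_key not in examples[0]:
        raise KeyError(
            f'Examples returned by dataset do not include required key: {check_key}')
```
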
1 change: 1 addition & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -152,6 +152,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
)

else:
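
With this pass-through, a user's `max_seq_len` setting now reaches `StreamingFinetuningDataset`. A sketch of such a dataloader config; every value here is illustrative, not taken from this PR:

```python
from omegaconf import OmegaConf

# Illustrative finetuning-dataloader config (hypothetical values).
cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'remote': 's3://my-bucket/my-ift-mds',  # hypothetical MDS location
        'local': '/tmp/my-ift-mds',
        'split': 'train',
        'shuffle': True,
        'decoder_only_format': True,
        'max_seq_len': 2048,  # now forwarded to StreamingFinetuningDataset
    },
    'num_workers': 8,
})
# build_finetuning_dataloader(cfg, tokenizer, device_batch_size) reads
# cfg.dataset.max_seq_len in the streaming branch shown above.
```
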
56 changes: 43 additions & 13 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,12 +35,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
import numpy as np
from composer.utils import dist
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase
@@ -199,7 +201,7 @@ def _tokenize_prompt_response_formatted_example(
return tokenizer(text=prompt, text_target=response)


def _tokenize_formatted_example(
def tokenize_formatted_example(
example: Example,
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenizes a formatted example using the provided tokenizer.
Expand Down Expand Up @@ -228,6 +230,33 @@ def _tokenize_formatted_example(
raise ValueError(f'Unknown conversation type {example_format=}')
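
With the underscore dropped, the helper is part of the module's public surface and can be called directly. A quick sketch on a prompt/response-formatted example; the tokenizer choice is illustrative:

```python
from transformers import AutoTokenizer

# Sketch: calling the now-public helper on a prompt/response example.
tok = AutoTokenizer.from_pretrained('gpt2')
tokenized = tokenize_formatted_example(
    {'prompt': 'What is 2 + 2?', 'response': ' 4'}, tokenizer=tok)
# Prompt tokens land in 'input_ids'; response tokens land in 'labels'.
print(len(tokenized['input_ids']), len(tokenized['labels']))
```
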


def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
example: Dict) -> bool:
"""Check if it's an valid ift example.

This functions does the following check:
a. Length of input_ids should less than max_seq_len
b. Both input_ids and labels should not be empty
c. Labels should has at least 1 non-padding token.

Args:
pad_token_id (int): The id of the padding token.
max_seq_len (int): Maximum sequence length.
example (Dict): The input example after tokenization, which has
``input_ids`` and ``labels`` fields.

Returns:
bool: Indicator of whether the input example is valid
"""
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and non_empty_labels and
non_padding_response)
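
A few toy checks, one per rule; `pad_token_id=0` and `max_seq_len=8` are arbitrary:

```python
assert is_valid_ift_example(0, 8, {'input_ids': [5, 6, 7], 'labels': [8, 9]})
assert not is_valid_ift_example(0, 8, {'input_ids': [1] * 8, 'labels': [2]})  # a: too long
assert not is_valid_ift_example(0, 8, {'input_ids': [1], 'labels': []})       # b: empty labels
assert not is_valid_ift_example(0, 8, {'input_ids': [1], 'labels': [0, 0]})   # c: all padding
```
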


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

@@ -304,6 +333,7 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
**kwargs: Any):

if len(kwargs) > 0:
@@ -343,11 +373,20 @@ def __init__(self,
)

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len

# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
sample = super().__getitem__(idx)
return _tokenize_formatted_example(sample, tokenizer=self.tokenizer)
if 'input_ids' in sample:
# Already-tokenized data: decode int64 token ids from raw bytes
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
dtype=np.int64)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(sample['labels'],
dtype=np.int64).tolist().copy()
return sample
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
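
The byte decoding above is the read side of a round trip with `convert_finetuning_dataset.py` below, which serializes token-id lists via `np.asarray(...).tobytes()`. A self-contained sketch, with the dtype pinned to `int64` (the script relies on that being the platform default):

```python
import numpy as np

token_ids = [101, 2023, 2003, 102]
raw = np.asarray(token_ids, dtype=np.int64).tobytes()  # write side: MDS 'bytes' column
decoded = np.frombuffer(raw, dtype=np.int64).tolist()  # read side: __getitem__ above
assert decoded == token_ids
```
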


class DatasetConstructor:
@@ -550,7 +589,7 @@ def build_from_hf(
def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)
return tokenize_formatted_example(example, tokenizer)

detected_cpu_count = os.cpu_count() or 1
detected_cpus_with_margin = detected_cpu_count - 8
@@ -567,17 +606,8 @@ def dataset_mapper(example: Dict):

pad_token_id = tokenizer.pad_token_id

def filter_long_or_empty_examples(example: Dict) -> bool:
less_than_max_seq_len = len(example['input_ids']) < max_seq_len
non_empty_input = len(example['input_ids']) > 0
non_empty_labels = len(example['labels']) > 0
non_padding_response = any(
token_id != pad_token_id for token_id in example['labels'])
return (less_than_max_seq_len and non_empty_input and
non_empty_labels and non_padding_response)

filtered_dataset = tokenized_dataset.filter(
filter_long_or_empty_examples,
partial(is_valid_ift_example, pad_token_id, max_seq_len),
num_proc=num_cpus_to_use,
desc='Filtering out long prompts',
)
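
`functools.partial` pre-binds `pad_token_id` and `max_seq_len`, leaving the one-argument `example -> bool` predicate that `datasets.Dataset.filter` expects. A quick sketch with illustrative values:

```python
from functools import partial

filter_fn = partial(is_valid_ift_example, 0, 2048)  # pad_token_id=0, max_seq_len=2048
assert filter_fn({'input_ids': [1, 2, 3], 'labels': [4]})
```
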
80 changes: 70 additions & 10 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -1,18 +1,24 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import os
import platform
import warnings
from argparse import ArgumentParser, Namespace
from typing import Dict, Iterable, List, Optional, Union

import datasets as hf_datasets
import numpy as np
import psutil
from streaming import MDSWriter
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (dataset_constructor,
is_valid_ift_example,
tokenize_formatted_example)
from llmfoundry.utils.builders import build_tokenizer


def parse_args() -> Namespace:
@@ -23,7 +29,7 @@ def parse_args() -> Namespace:
type=str,
required=True,
help=
'Name/path of the dataset (e.g., first argument to `datasets.load_dataset`)'
'Name of the dataset (i.e., the first argument to `datasets.load_dataset`; for the JSONL data format, use `json`)'
)
parser.add_argument('--data_subset',
type=str,
@@ -38,6 +44,13 @@
default=None,
help='Name or import path of function used to preprocess (reformat) the dataset. ' +\
'See README for additional details.')
parser.add_argument(
'--data_files',
nargs='+',
default=[],
help=
'Data file for each split. If set, its length must exactly match len(splits).'
)
parser.add_argument(
'--skip-preprocessing',
action='store_true',
@@ -63,6 +76,9 @@
default=None,
help='(Optional) name of compression algorithm to use.')
parser.add_argument('--num_workers', type=int, required=False, default=None)
parser.add_argument('--tokenizer', type=str, required=False, default=None)
parser.add_argument('--tokenizer_kwargs', type=str, required=False)
parser.add_argument('--max_seq_len', type=int, default=2048)

parsed = parser.parse_args()

@@ -73,6 +89,17 @@
f'--out_root={parsed.out_root} contains {os.listdir(parsed.out_root)} which cannot overlap with the requested splits {parsed.splits}.'
)

if parsed.tokenizer_kwargs is not None:
parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs)
else:
parsed.tokenizer_kwargs = {}

if len(parsed.data_files) > 0 and len(parsed.data_files) != len(
parsed.splits):
raise ValueError(
f'If data_files is set, it must map 1:1 to splits. Got {len(parsed.data_files)=} but {len(parsed.splits)=}.'
)

return parsed
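
Because `--tokenizer_kwargs` arrives as a JSON string, its life cycle is a parse plus an update; the flag value below is hypothetical:

```python
import json

# e.g. --tokenizer_kwargs '{"padding_side": "left"}'
tokenizer_kwargs = json.loads('{"padding_side": "left"}')
tokenizer_kwargs.update({'model_max_length': 2048})  # mirrors main() below
```
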


@@ -170,12 +197,23 @@ def main(args: Namespace) -> None:
'include the "--skip-preprocessing" flag to avoid this error.'
)

columns = ['prompt', 'response']
tokenizer = None
tokenizer_kwargs = args.tokenizer_kwargs
tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
if args.tokenizer:
tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
columns = {'input_ids': 'bytes', 'labels': 'bytes'}
else:
columns = {'prompt': 'str', 'response': 'str'}

for split_name in args.splits:
for i, split_name in enumerate(args.splits):
data_file = None
if len(args.data_files) > 0:
data_file = args.data_files[i]
dataset = hf_datasets.load_dataset(path=args.dataset,
name=args.data_subset,
split=split_name,
data_files=data_file,
streaming=True)
loader = build_dataloader(dataset=dataset,
batch_size=512,
@@ -190,12 +228,14 @@
keep_local = True
else:
keep_local = False
with MDSWriter(columns={key: 'str' for key in columns},
with MDSWriter(columns=columns,
out=out,
compression=args.compression,
keep_local=keep_local) as out:
examples_removed = 0
for sample in tqdm(samples, desc=split_name):
formatted_sample = preprocessing_fn(sample)

if ('prompt'
not in formatted_sample) or ('response'
not in formatted_sample):
@@ -204,11 +244,31 @@
'"prompt" and "response" are required keys but at least one was missing ' +\
f'from {formatted_sample=}.'
)
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in columns
}
out.write(encoded_sample)
if tokenizer is not None:
sample = tokenize_formatted_example(formatted_sample,
tokenizer=tokenizer)
if not is_valid_ift_example(tokenizer.pad_token_id,
args.max_seq_len, sample):
examples_removed += 1
continue

sample_to_write = {}
# Convert token-id lists to raw bytes for the MDS 'bytes' columns
for key in columns.keys():
sample_to_write[key] = np.asarray(sample[key]).tobytes()
out.write(sample_to_write)
else:
encoded_sample = {
key: formatted_sample[key].encode('utf-8')
for key in columns.keys()
}
out.write(encoded_sample)
if tokenizer is not None and examples_removed > 0:
warnings.warn(
f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
+ 'the prompt or response was empty, or the response was all padding tokens.'
)


if __name__ == '__main__':
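
Putting the new flags together: a hypothetical end-to-end run followed by a read-back check. Paths, tokenizer name, and flag values are all illustrative:

```python
# Hypothetical invocation that writes a pre-tokenized MDS dataset:
#
#   python scripts/data_prep/convert_finetuning_dataset.py \
#     --dataset json \
#     --splits train \
#     --data_files /data/train.jsonl \
#     --tokenizer EleutherAI/gpt-neox-20b \
#     --max_seq_len 2048 \
#     --out_root ./my-ift-mds
#
# Read-back sketch: raw bytes come out of the MDS shard, exactly what
# StreamingFinetuningDataset.__getitem__ decodes.
from streaming import StreamingDataset

ds = StreamingDataset(local='./my-ift-mds/train', shuffle=False)
sample = ds[0]
print(type(sample['input_ids']))  # <class 'bytes'> when written with a tokenizer
```
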