add finetuning with streaming dataset example #945

Merged (27 commits) on Feb 8, 2024
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/collator.py
@@ -106,7 +106,7 @@ def __init__(

def __call__(self, examples: List[Dict[str,
Any]]) -> Dict[str, torch.Tensor]:
for check_key in ['input_ids', 'labels', 'attention_mask']:
for check_key in ['input_ids', 'labels']:
if check_key not in examples[0]:
raise KeyError(
f'Examples returned by dataset do not include required key: {check_key}'
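The required-key check above no longer demands a stored `attention_mask`, presumably because a mask can be rebuilt from the padded `input_ids` at collate time. A minimal sketch of that idea (not the repo's actual collator; the pad id here is a placeholder):

```python
from typing import Any, Dict, List

import torch


def pad_and_mask(examples: List[Dict[str, Any]],
                 pad_token_id: int = 0) -> Dict[str, torch.Tensor]:
    """Pad variable-length token lists and derive the attention mask on the fly."""
    max_len = max(len(e['input_ids']) for e in examples)
    input_ids, attention_mask = [], []
    for e in examples:
        ids = list(e['input_ids'])
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = pad
    return {
        'input_ids': torch.tensor(input_ids),
        'attention_mask': torch.tensor(attention_mask),
    }
```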
4 changes: 4 additions & 0 deletions llmfoundry/data/finetuning/dataloader.py
@@ -152,6 +152,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
)

else:
@@ -284,6 +285,9 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
'HuggingFace dataset or set `remote` to use a streaming ' +\
'dataset, but both were None.'
)
if dataset_cfg.get('max_seq_len') is None:
raise ValueError(
'In the dataset config, you must set the `max_seq_len`')


def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
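With this change, a streaming finetuning dataset config must carry `max_seq_len`, which is forwarded to the streaming dataset constructor shown in tasks.py below. A small sketch of the new failure mode, using a made-up config (only the keys shown in this diff are taken from the PR):

```python
from omegaconf import OmegaConf

dataset_cfg = OmegaConf.create({
    'remote': 's3://my-bucket/ift-data',  # hypothetical remote path
    'local': '/tmp/ift-data',
    'split': 'train',
    # 'max_seq_len' omitted on purpose
})

# Mirrors the check added in _validate_config: a missing max_seq_len is now an error.
if dataset_cfg.get('max_seq_len') is None:
    raise ValueError('In the dataset config, you must set the `max_seq_len`')
```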
23 changes: 23 additions & 0 deletions llmfoundry/data/finetuning/tasks.py
@@ -42,6 +42,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

import datasets as hf_datasets
import huggingface_hub as hf_hub
import numpy as np
from composer.utils import dist
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,7 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
**kwargs: Any):

if len(kwargs) > 0:
@@ -371,10 +373,31 @@ def __init__(self,
)

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len

# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
sample = super().__getitem__(idx)
if 'input_ids' in sample:
# Already tokenized data
if isinstance(sample['input_ids'], bytes):
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
dtype=np.int64)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(
sample['labels'],
dtype=np.int64)[:self.max_seq_len].tolist().copy()
elif isinstance(sample['input_ids'], np.ndarray):
sample['input_ids'] = sample[
'input_ids'][:self.max_seq_len].tolist().copy()
sample['labels'] = sample['labels'][:self.max_seq_len].tolist(
).copy()
else:
raise ValueError(
f'Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample["input_ids"])}'
)

return sample
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)


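The `__getitem__` override above accepts pre-tokenized samples stored either as raw bytes or as numpy arrays, and otherwise falls back to tokenizing prompt/response text. A round-trip sketch of the two decode paths (dtypes follow the diff; the token values are made up):

```python
import numpy as np

max_seq_len = 8
token_ids = np.arange(12, dtype=np.int64)  # pretend these are 12 token ids

# 'bytes' shards: decode with np.frombuffer, then truncate to max_seq_len.
from_bytes = np.frombuffer(token_ids.tobytes(), dtype=np.int64)[:max_seq_len].tolist()

# ndarray shards: already an array, so just truncate and convert to a list.
from_array = token_ids[:max_seq_len].tolist()

assert from_bytes == from_array == list(range(8))
```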
5 changes: 3 additions & 2 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -202,7 +202,7 @@ def main(args: Namespace) -> None:
tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
if args.tokenizer:
tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
columns = {'input_ids': 'bytes', 'labels': 'bytes'}
columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}
else:
columns = {'prompt': 'str', 'response': 'str'}

@@ -255,7 +255,8 @@ def main(args: Namespace) -> None:
sample_to_write = {}
# convert to bytes
for key in columns.keys():
sample_to_write[key] = np.asarray(sample[key]).tobytes()
sample_to_write[key] = np.asarray(sample[key],
dtype=np.uint32)
out.write(sample_to_write)
else:
encoded_sample = {
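The converter now stores pre-tokenized `input_ids`/`labels` with streaming's `ndarray:uint32` encoding instead of raw bytes. A minimal writer sketch with made-up paths and token ids (the real script produces the samples via the tokenizer):

```python
import numpy as np
from streaming import MDSWriter

columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}

# Hypothetical pre-tokenized samples.
samples = [
    {'input_ids': [101, 2023, 2003, 102], 'labels': [2023, 2003, 102, 50256]},
    {'input_ids': [101, 2062, 2951, 102], 'labels': [2062, 2951, 102, 50256]},
]

with MDSWriter(columns=columns, out='/tmp/mds_out', compression=None) as out:
    for sample in samples:
        out.write({k: np.asarray(sample[k], dtype=np.uint32) for k in columns})
```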
80 changes: 80 additions & 0 deletions (new file: finetuning YAML config for a streaming dataset)
@@ -0,0 +1,80 @@
max_seq_len: 512
global_seed: 17

data_local: ./my_data
data_remote: # If blank, files must be present in data_local

# Run Name
run_name: # If left blank, will be read from env var $RUN_NAME

# Model
model:
name: hf_causal_lm
pretrained_model_name_or_path: gpt2
pretrained: true # false: only use the architecture; true: initialize with pretrained weights

# Tokenizer
tokenizer:
name: gpt2
kwargs:
model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
name: finetuning
dataset:
############
remote: ${data_remote}
local: ${data_local}
split: train
############
shuffle: true
max_seq_len: ${max_seq_len}
decoder_only_format: true
drop_last: true
num_workers: 8

# Optimization
scheduler:
name: cosine_with_warmup
t_warmup: 100ba
alpha_f: 0.1

optimizer:
name: decoupled_adamw
lr: 6.0e-4
betas:
- 0.9
- 0.95
eps: 1.0e-08
weight_decay: 0.0

algorithms:
gradient_clipping:
clipping_type: norm
clipping_threshold: 1.0

max_duration: 1ep
eval_interval: 1
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 8

# System
seed: ${global_seed}
device_eval_batch_size: 8
device_train_microbatch_size: 8
# device_train_microbatch_size: auto
precision: fp32

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
speed_monitor:
window_size: 10
lr_monitor: {}
memory_monitor: {}
runtime_estimator: {}
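This new YAML points the `finetuning` dataloader at a streaming dataset: `data_local`/`data_remote` should reference the MDS shards produced by the converter above. In llm-foundry, a config like this is typically launched through the `composer` launcher against `scripts/train/train.py`, e.g. `composer scripts/train/train.py <path-to-this-yaml>` (invocation hedged; adjust paths to your checkout).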
66 changes: 55 additions & 11 deletions tests/data/test_dataloader.py
@@ -12,6 +12,7 @@
from typing import ContextManager, Literal, Optional, Union
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch
import transformers
@@ -25,7 +26,8 @@
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
is_valid_ift_example)
is_valid_ift_example,
tokenize_formatted_example)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader,
get_tokens_per_batch_func)
@@ -49,23 +51,51 @@ def get_abs_data_path(data_local: str):
return os.path.join(os.getcwd(), data_local)


def build_mock_ft_streaming_dataset(data_path: str, split: str):
columns = {'prompt': 'str', 'response': 'str'}
def build_mock_ft_streaming_dataset(
data_path: str,
split: str,
pretokenize: bool,
use_bytes: bool,
tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None):
if pretokenize:
if use_bytes:
columns = {'input_ids': 'bytes', 'labels': 'bytes'}
else:
columns = {
'input_ids': 'ndarray:uint32',
'labels': 'ndarray:uint32'
}
else:
columns = {'prompt': 'str', 'response': 'str'}

dataset = [{
'prompt': 'This is just a test1',
'response': 'Hello World1'
}, {
'prompt': 'This is just a test2',
'response': 'Hello world2'
}, {
'prompt': 'This is just a test3',
'response': 'Hello world3'
}]

output_path = os.path.join(data_path, split)

with MDSWriter(columns=columns, out=output_path,
compression=None) as output_writer:
for sample in dataset:
output_writer.write(sample)
if pretokenize:
sample = tokenize_formatted_example(sample, tokenizer=tokenizer)
sample_to_write = {}
for key in columns.keys():
if use_bytes:
sample_to_write[key] = np.asarray(sample[key]).tobytes()
else:
sample_to_write[key] = np.asarray(sample[key],
dtype=np.uint32)
output_writer.write(sample_to_write)
else:
output_writer.write(sample)


@pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@@ -517,13 +547,25 @@ def test_finetuning_dataloader_custom_split_remote(split: str):
assert split in dest_arg, 'split destination should match split name'


def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
@pytest.mark.parametrize('pretokenize', [True, False])
@pytest.mark.parametrize('use_bytes', [True, False])
def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
tmp_path: pathlib.Path):
max_seq_len = 2048

remote_path = os.path.join(tmp_path, 'remote')
local_path = os.path.join(tmp_path, 'local')

build_mock_ft_streaming_dataset(remote_path, 'train')
tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

build_mock_ft_streaming_dataset(remote_path,
'train',
pretokenize,
use_bytes=use_bytes,
tokenizer=tokenizer)

cfg = {
'name': 'finetuning',
@@ -547,12 +589,14 @@ def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)
dataloader = build_finetuning_dataloader(cfg, tokenizer, 2).dataloader

_ = build_finetuning_dataloader(cfg, tokenizer, 4)
expected_keys = ['input_ids', 'labels']
for batch in dataloader:
for key in expected_keys:
assert key in batch
assert batch[key].shape[0] == 2
break


def test_finetuning_dataloader_is_valid_ift_example():