Make FinetuningStreamingDataset parameters more flexible (#1580)
XiaohanZhangCMU authored Oct 14, 2024
1 parent 6a748e9 commit 4a47b5d
Showing 4 changed files with 178 additions and 18 deletions.
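
In short: dataset config keys that the finetuning wrapper does not explicitly allow are no longer rejected up front. They are logged as extraneous, returned from _validate_config, and forwarded to the underlying StreamingDataset constructor, which either accepts them (if they are valid base-class parameters) or raises its own TypeError. A minimal sketch of the new behaviour, assuming the usual llm-foundry builders; 'some_new_streaming_kwarg' is a hypothetical stand-in for a parameter accepted by the base StreamingDataset but not listed by the finetuning wrapper:

    from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
    from llmfoundry.utils.builders import build_tokenizer

    tokenizer = build_tokenizer(
        tokenizer_name='gpt2',
        tokenizer_kwargs={'model_max_length': 2048},
    )

    dataset_cfg = {
        'remote': 's3://my-bucket/finetune-data',  # illustrative paths
        'local': '/tmp/finetune-data',
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
        # Hypothetical extra key: previously _validate_config raised a ValueError here;
        # now it is only warned about and passed through to StreamingDataset.__init__,
        # which accepts it in this sketch because it stands in for a real base-class parameter.
        'some_new_streaming_kwarg': 123,
    }

    loader = build_finetuning_dataloader(
        dataset=dataset_cfg,
        tokenizer=tokenizer,
        device_batch_size=2,
        num_workers=0,
        drop_last=False,
    ).dataloader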
20 changes: 14 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
@@ -198,7 +198,8 @@ def build_finetuning_dataloader(
allowed_dataset_config_keys = set(
dataset_constructor_keys,
).union(_ALLOWED_DATASET_KEYS)
_validate_config(

extraneous_keys = _validate_config(
**dataset_cfg,
allowed_dataset_keys=allowed_dataset_config_keys,
)
@@ -253,13 +254,13 @@ def build_finetuning_dataloader(
streams_cfg,
) if streams_cfg is not None else None

# Take the constructor args from above, minus args that have been created separately
dataset_constructor_args = {
k: v
for k, v in dataset_cfg.items()
if k in dataset_constructor_keys and
if k in set(dataset_constructor_keys).union(extraneous_keys) and
k not in {'streams', 'packing_ratio'}
}

streaming_dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
streams=streams,
@@ -378,7 +379,7 @@ def _validate_config(
target_responses: Optional[str] = None,
allowed_dataset_keys: set[str] = _ALLOWED_DATASET_KEYS,
**kwargs: dict[str, Any],
) -> None:
) -> set[str]:
"""Validates the dataset configuration.
Makes sure that the dataset is properly configured for either
@@ -434,11 +435,16 @@ def _validate_config(
Raises:
ValueError: If the dataset configuration does not meet the requirements.
Returns:
set[str]: Return the extraneous keys.
"""
extraneous_keys = set()
if not set(kwargs.keys()).issubset(allowed_dataset_keys):
raise ValueError(
extraneous_keys = set(kwargs.keys()) - allowed_dataset_keys
log.warning(
'The dataset config contains the following extraneous keys: ' +\
', '.join(set(kwargs.keys()) - allowed_dataset_keys),
', '.join(extraneous_keys),
)

if hf_name is not None:
@@ -533,6 +539,8 @@ def _validate_config(
decoder_only_format,
)

return extraneous_keys


def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
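Pulled out of the hunks above, the revised flow in build_finetuning_dataloader is roughly the following (a paraphrase, not verbatim source): validation now reports extraneous keys instead of failing, and the constructor-argument filter keeps those keys so they reach the streaming dataset via **kwargs.

    extraneous_keys = _validate_config(  # warns on unknown keys and returns them
        **dataset_cfg,
        allowed_dataset_keys=allowed_dataset_config_keys,
    )

    dataset_constructor_args = {
        k: v
        for k, v in dataset_cfg.items()
        # keep the explicitly supported keys *and* the extraneous ones, so the latter
        # are forwarded to the StreamingFinetuningDataset constructor further down
        if k in set(dataset_constructor_keys).union(extraneous_keys) and
        k not in {'streams', 'packing_ratio'}
    }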
6 changes: 1 addition & 5 deletions llmfoundry/data/finetuning/tasks.py
@@ -613,11 +613,6 @@ def __init__(
**kwargs: Any,
):

if len(kwargs) > 0:
raise ValueError(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
@@ -658,6 +653,7 @@ def __init__(
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
**kwargs,
)

self.tokenizer = tokenizer
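The tasks.py change deletes the explicit 'unexpected keyword argument' guard and instead forwards **kwargs to the parent StreamingDataset, letting the base class decide which keywords are valid. A minimal, generic illustration of that pattern (not llm-foundry code):

    class Base:
        def __init__(self, known_option: int = 0):
            self.known_option = known_option

    class Wrapper(Base):
        def __init__(self, wrapper_option: str = 'x', **kwargs):
            # Unknown keywords are no longer rejected here; the parent validates them.
            super().__init__(**kwargs)
            self.wrapper_option = wrapper_option

    Wrapper(known_option=3)    # fine: a base-class parameter passes straight through
    # Wrapper(bogus_option=1)  # TypeError raised by Base.__init__, not by Wrapper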
11 changes: 5 additions & 6 deletions llmfoundry/data/text_data.py
@@ -138,11 +138,6 @@ def __init__(
**kwargs: Any,
):

if len(kwargs) > 0:
raise ValueError(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
@@ -188,6 +183,7 @@ def __init__(
batching_method=batching_method,
allow_unsafe_types=allow_unsafe_types,
replication=replication,
**kwargs,
)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
@@ -332,10 +328,13 @@ def build_text_dataloader(
StreamingTextDataset,
).parameters

valid_base_dataset_params = inspect.signature(StreamingDataset,).parameters

dataset_config_subset_for_streaming_text_dataset = {
k: v
for k, v in dataset_cfg.items()
if k in valid_streaming_text_dataset_parameters
if k in valid_streaming_text_dataset_parameters or
k in valid_base_dataset_params
}

# build dataset potentially with streams
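In text_data.py the same flexibility comes from widening the signature-based filter: a config key survives if it appears in either the StreamingTextDataset signature or the base StreamingDataset signature. A small sketch of that filter with an illustrative config (keys and paths are made up):

    import inspect

    from streaming import StreamingDataset
    from llmfoundry.data.text_data import StreamingTextDataset

    subclass_params = inspect.signature(StreamingTextDataset).parameters
    base_params = inspect.signature(StreamingDataset).parameters

    dataset_cfg = {'local': '/tmp/text-data', 'shuffle': True, 'extra_key_1': 1}
    filtered_cfg = {
        k: v
        for k, v in dataset_cfg.items()
        if k in subclass_params or k in base_params
    }
    # 'extra_key_1' matches neither signature and is silently dropped here, which is
    # consistent with the xfail marker on the text-dataloader test added below.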
159 changes: 158 additions & 1 deletion tests/data/test_dataloader.py
@@ -8,7 +8,7 @@
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from typing import Any, Callable, ContextManager, Literal, Optional, Union
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, mock_open, patch

import catalogue
import numpy as np
@@ -1423,3 +1423,160 @@ def test_sharegpt_format(
device_batch_size=device_batch_size,
**cfg,
).dataloader

def test_ft_dataloader_with_extra_keys():
max_seq_len = 2
cfg = {
'dataset': {
'remote': '/remote',
'local': '/local',
'split': 'train',
'max_seq_len': 2048,
'decoder_only_format': True,
'shuffle': True,
'num_canonical_nodes': 472,
'target_responses': 'last',
'target_prompts': 'none',
'extra_key_1': 'extra_key_1',
'extra_key_2': 'extra_key_2',
'extra_key_3': 'extra_key_3',
},
'drop_last': False,
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

device_batch_size = 2

mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems

#with patch('streaming.base.stream.get_shards', return_value=None):
with patch('os.makedirs'), \
patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \
patch('json.load') as mock_json_load, \
patch('os.stat', return_value=mock_stat), \
patch('torch.distributed.is_available', return_value=True), \
patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.broadcast_object_list'), \
patch('torch.distributed.init_process_group'), \
patch('torch.distributed.destroy_process_group'), \
patch('torch.distributed.barrier'), \
patch('streaming.base.dataset.StreamingDataset.get_item'):

mock_json_load.return_value = {
'version':
2,
'shards': [{
'column_names': ['column1', 'column2'],
'column_encodings': ['int', 'float'],
'column_sizes': [4, 8],
'compression': None,
'format': 'mds',
'hashes': [],
'raw_data': {
'basename': 'shard.00000.mds',
'bytes': 1024,
'hashes': {},
},
'samples': 1000,
'size_limit': 67108864,
'version': 2,
'zip_data': None,
}],
}

with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'):
_ = build_finetuning_dataloader(
**cfg,
tokenizer=tokenizer,
device_batch_size=device_batch_size,
).dataloader

@pytest.mark.xfail
def test_text_dataloader_with_extra_keys():
max_seq_len = 1024
cfg = {
'dataset': {
'remote': '/remote',
'local': '/local',
'split': 'train',
'max_seq_len': max_seq_len,
'shuffle': True,
'num_canonical_nodes': 472,
'extra_key_1': 'extra_key_1',
'extra_key_2': 'extra_key_2',
'extra_key_3': 'extra_key_3',
},
'drop_last': False,
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
}

cfg = om.create(cfg)

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

device_batch_size = 2

mock_stat = MagicMock()
mock_stat.st_size = 1024 # Mock st_size with a desired value
mock_stat.st_mode = 33188 # Regular file mode for Unix-based systems

#with patch('streaming.base.stream.get_shards', return_value=None):
with patch('os.makedirs'), \
patch('builtins.open', new_callable=mock_open, read_data='{"version": 2, "shards": []}'), \
patch('json.load') as mock_json_load, \
patch('os.stat', return_value=mock_stat), \
patch('torch.distributed.is_available', return_value=True), \
patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.broadcast_object_list'), \
patch('torch.distributed.init_process_group'), \
patch('torch.distributed.destroy_process_group'), \
patch('torch.distributed.barrier'), \
patch('streaming.base.dataset.StreamingDataset.get_item'):

mock_json_load.return_value = {
'version':
2,
'shards': [{
'column_names': ['column1', 'column2'],
'column_encodings': ['int', 'float'],
'column_sizes': [4, 8],
'compression': None,
'format': 'mds',
'hashes': [],
'raw_data': {
'basename': 'shard.00000.mds',
'bytes': 1024,
'hashes': {},
},
'samples': 1000,
'size_limit': 67108864,
'version': 2,
'zip_data': None,
}],
}
with pytest.raises(TypeError, match=f'.*got an unexpected keyword argument.*'):
_ = build_text_dataloader(
**cfg,
tokenizer=tokenizer,
device_batch_size=device_batch_size,
).dataloader
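
A note on what these tests assert: unknown keys such as 'extra_key_1' are no longer rejected by _validate_config. In the finetuning path they are forwarded to StreamingDataset.__init__, which does not accept them, so Python raises a TypeError there; that is the exception the first test matches on. The text-dataloader test is marked xfail, presumably because the signature-based filter shown earlier drops unknown keys before construction, so no TypeError surfaces. A condensed illustration of the error the finetuning test expects (path and key are illustrative):

    from streaming import StreamingDataset

    try:
        # The unexpected keyword fails at call time, before any shard I/O happens.
        StreamingDataset(local='/tmp/data', extra_key_1='extra_key_1')
    except TypeError as err:
        print(err)  # ... got an unexpected keyword argument 'extra_key_1'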
