
Commit

wip
irenedea committed Dec 5, 2023
1 parent 61cd110 commit 6cadcf5
Showing 2 changed files with 120 additions and 121 deletions.
226 changes: 114 additions & 112 deletions llmfoundry/data/finetuning/dataloader.py
@@ -13,9 +13,10 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset, dataset_constructor
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
import huggingface_hub as hf_hub

log = logging.getLogger(__name__)

@@ -123,105 +124,110 @@ def build_finetuning_dataloader(cfg: DictConfig,
tokenizer.pad_token = tokenizer.eos_token

dataset = None # for pyright
if cfg.dataset.get('remote') is not None:
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
remote=cfg.dataset.get('remote', None),
split=cfg.dataset.get('split', None),
download_retry=cfg.dataset.get('download_retry', 2),
download_timeout=cfg.dataset.get('download_timeout', 60),
validate_hash=cfg.dataset.get('validate_hash', None),
keep_zip=cfg.dataset.get('keep_zip', False),
epoch_size=cfg.dataset.get('epoch_size', None),
predownload=cfg.dataset.get('predownload', None),
cache_limit=cfg.dataset.get('cache_limit', None),
partition_algo=cfg.dataset.get('partition_algo', 'relaxed'),
num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None),
batch_size=device_batch_size,
shuffle=cfg.dataset.get('shuffle', False),
shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'),
shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
shuffle_block_size=cfg.dataset.get('shuffle_block_size', None),
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
sampler = None

# Build streaming or HF dataset.
if cfg.dataset.get('remote') is not None:
# Build Streaming dataset
dataset = _build_streaming_dataset(cfg.dataset, tokenizer, device_batch_size)
else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
hf_name, split = cfg.dataset.hf_name, cfg.dataset.split

backend, _, _ = parse_uri(hf_name)
if backend not in ['', None]:
# Download dataset from remote object store.
if cfg.dataset.get('split') is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
else:
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
+
f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
hf_name = _download_remote_dataset(hf_name, split)
elif cfg.dataset.get('safe_load') is True:
# Download dataset from huggingface hub with restrictions.
hf_kwargs = cfg.dataset.get('hf_kwargs', {})
token = hf_kwargs.get('token', None)
hf_name = _safe_download_hf_dataset(hf_name, token)

# Build HF dataset
dataset = dataset_constructor.build_from_hf(
dataset_name_or_path=hf_name,
split=split,
max_seq_len=cfg.dataset.max_seq_len,
proto_preprocessing_fn=cfg.dataset.get('preprocessing_fn'),
tokenizer=tokenizer,
hf_kwargs=cfg.dataset.get('hf_kwargs', {})
)
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.get("hf_name")}, split={cfg.dataset.get("split")}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
+
f'your per device batch size is {dataloader_batch_size}. Please increase the number '
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func(
pad_token_id=tokenizer.pad_token_id)

return DataSpec(dataloader=dl, get_num_tokens_in_batch=token_counting_func)


def _build_streaming_dataset(
        dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
        device_batch_size: int) -> StreamingFinetuningDataset:
return dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=dataset_cfg.local,
remote=dataset_cfg.get('remote', None),
split=dataset_cfg.get('split', None),
download_retry=dataset_cfg.get('download_retry', 2),
download_timeout=dataset_cfg.get('download_timeout', 60),
validate_hash=dataset_cfg.get('validate_hash', None),
keep_zip=dataset_cfg.get('keep_zip', False),
epoch_size=dataset_cfg.get('epoch_size', None),
predownload=dataset_cfg.get('predownload', None),
cache_limit=dataset_cfg.get('cache_limit', None),
partition_algo=dataset_cfg.get('partition_algo', 'relaxed'),
num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None),
batch_size=device_batch_size,
shuffle=dataset_cfg.get('shuffle', False),
shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'),
shuffle_seed=dataset_cfg.get('shuffle_seed', 9176),
shuffle_block_size=dataset_cfg.get('shuffle_block_size', None),
sampling_method=dataset_cfg.get('sampling_method', 'balanced'),
sampling_granularity=dataset_cfg.get('sampling_granularity', 1),
batching_method=dataset_cfg.get('batching_method', 'random'),
)


def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
@@ -275,12 +281,23 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
'dataset, but both were None.'
)

def _downloaded_datasets_dir() -> str:
return os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning')

def _safe_download_hf_dataset(hf_name: str, token: Union[bool, str, None]) -> str:
local_dataset_dir = _downloaded_datasets_dir()
hf_hub.snapshot_download(
hf_name,
repo_type='dataset',
allow_patterns=["*.csv", "*.jsonl", "*.parquet"],
token=token, local_dir=local_dataset_dir)
return local_dataset_dir

def _build_hf_dataset_from_remote(
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
def _download_remote_dataset(hf_name: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
@@ -291,33 +308,26 @@ def _build_hf_dataset_from_remote(
completed, the function removes the signal file.
Args:
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
This includes:
- dataset.hf_name: The path of the HuggingFace dataset to download.
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
hf_name (str): The path of the HuggingFace dataset to download.
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
Returns:
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
A local directory path where the dataset files are stored.
Raises:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
destination_split = split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
_downloaded_datasets_dir(),
destination_split if destination_split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
name = f'{hf_name.strip("/")}/{split}.{extension}'
destination = str(
os.path.abspath(
os.path.join(
@@ -334,7 +344,7 @@
except FileNotFoundError as e:
if extension == supported_extensions[-1]:
files_searched = [
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
f'{hf_name}/{split}.{ext}'
for ext in supported_extensions
]
raise FileNotFoundError(
@@ -361,15 +371,7 @@
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()

cfg.dataset.hf_name = finetune_dir
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
return dataset
return finetune_dir


def _build_collate_fn(
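For orientation, here is a hedged sketch of how the refactored dataloader above might be driven. The config keys are only the ones this diff reads from `cfg` (plus whatever `_validate_config` and the collator additionally require); the dataset id, tokenizer, and batch size are illustrative placeholders, not values taken from the commit.

# Illustrative usage sketch, not part of the commit. Keys and values are
# assumptions drawn from what this diff reads off `cfg`; _validate_config and
# the collator may require more.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader

cfg = OmegaConf.create({
    'dataset': {
        'hf_name': 'org/some-finetuning-dataset',  # hypothetical HF dataset id
        'split': 'train',
        'max_seq_len': 2048,
        'preprocessing_fn': None,    # optional; dotted path or dict per tasks.py
        'safe_load': True,           # routes through _safe_download_hf_dataset
        'hf_kwargs': {'token': None},
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 8,
    'pin_memory': True,
    'prefetch_factor': 2,
    'persistent_workers': True,
    'timeout': 0,
})

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer
dataspec = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)
dataloader = dataspec.dataloader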
15 changes: 6 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
@@ -325,8 +325,8 @@ def get_preprocessing_fn_from_str(
return preprocessing_fn

def build_from_hf(
self, cfg: DictConfig, max_seq_len: int,
tokenizer: PreTrainedTokenizerBase
self, dataset_name_or_path: str, split: str, max_seq_len: int,
tokenizer: PreTrainedTokenizerBase,
proto_preprocessing_fn: Union[dict, DictConfig, str],
hf_kwargs: dict[str, Any]
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Load a HuggingFace Datasets, preprocess, and tokenize.
@@ -341,19 +341,16 @@ def build_from_hf(
Returns:
Dataset: The tokenized dataset.
"""
dataset_name = cfg.hf_name
# HF datasets does not support a split with dashes, so we replace split
# dashes with underscores.
split = cfg.split.replace('-', '_')
kwargs = cfg.get('hf_kwargs', {})
proto_preprocessing_fn = cfg.get('preprocessing_fn')
split = split.replace('-', '_')
if isinstance(proto_preprocessing_fn, dict) or isinstance(
proto_preprocessing_fn, DictConfig):
preprocessing_fn = self.get_preprocessing_fn_from_dict(
proto_preprocessing_fn)
else:
preprocessing_fn = self.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name)
proto_preprocessing_fn, dataset_name_or_path)

signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

@@ -368,9 +365,9 @@
error: Optional[Exception] = None
filtered_dataset = None
try:
dataset = hf_datasets.load_dataset(dataset_name,
dataset = hf_datasets.load_dataset(dataset_name_or_path,
split=split,
**kwargs)
**hf_kwargs)

def dataset_mapper(example: Dict):
if preprocessing_fn is not None:
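Both `_download_remote_dataset` in the first file and `build_from_hf` above lean on the same signal-file handshake: local rank 0 does the download or tokenization, writes a `.node_<node>_local_rank0_data_prep_completed` marker, the other ranks wait for it, and everyone re-synchronizes with a barrier before rank 0 deletes the marker. A minimal sketch of that pattern follows; the helper name and the polling loop are illustrative assumptions, while the `composer` dist calls and the marker filename mirror what the diff uses.

# Sketch of the rank-0 signal-file pattern; `download_once_per_node` is a
# hypothetical helper, not a function from this commit.
import os
import time
from typing import Callable

from composer.utils import dist


def download_once_per_node(download_fn: Callable[[], None]) -> None:
    """Run `download_fn` on local rank 0 only; other ranks wait on a signal file."""
    signal_file = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
    if dist.get_local_rank() == 0:
        download_fn()  # the expensive work happens exactly once per node
        with open(signal_file, 'wb') as f:
            f.write(b'local_rank0_completed')
    else:
        # Non-zero local ranks wait until local rank 0 signals completion.
        while not os.path.exists(signal_file):
            time.sleep(1)
    dist.barrier()  # every rank has now passed the wait
    if dist.get_local_rank() == 0:
        os.remove(signal_file)  # clean up the marker once all ranks have moved on
    dist.barrier()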

