Skip to content

Commit

Permalink
Add safe_load option to restrict HF dataset downloads to allowed file…
Browse files Browse the repository at this point in the history
… types (#798)
  • Loading branch information
irenedea authored Jan 5, 2024
1 parent 083b4b2 commit f0fd749
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 112 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@ notebooks/
**/*.pt
**/mlruns/*
**/tokenizer-save-dir-*/**
**/.downloaded_finetuning/
163 changes: 77 additions & 86 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from typing import Tuple, Union

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
Expand All @@ -13,7 +12,9 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func

Expand Down Expand Up @@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
# Build streaming dataloader
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
Expand All @@ -148,48 +154,53 @@ def build_finetuning_dataloader(cfg: DictConfig,
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')

# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
if cfg.dataset.get('split') is None:
if split is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores.
split = split.replace('-', '_')
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)

# Get the preprocessing function.
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(
dict(proto_preprocessing_fn))
else:
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name_or_path)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)
# Build dataset from HF.
dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=cfg.dataset.get('safe_load', False),
max_seq_len=cfg.dataset.max_seq_len,
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
hf_kwargs=cfg.dataset.get('kwargs', {}))

# Ensure dataset is large enough.
if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
Expand All @@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig,
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle)

assert dataset is not None # for pyright
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func()

Expand Down Expand Up @@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)
elif dataset_cfg.get('remote') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
Expand All @@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)


def _build_hf_dataset_from_remote(
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
Expand All @@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote(
completed, the function removes the signal file.
Args:
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
This includes:
- dataset.hf_name: The path of the HuggingFace dataset to download.
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
hf_name (str): The path of the HuggingFace dataset to download.
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
Returns:
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
A local directory path where the dataset files are stored.
Raises:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
destination_split if destination_split != 'data' else 'data_not',
DOWNLOADED_FT_DATASETS_DIRPATH,
split if split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
for extension in SUPPORTED_EXTENSIONS:
name = f'{remote_path.strip("/")}/{split}{extension}'
destination = str(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{destination_split}-00000-of-00001.{extension}')))
os.path.join(finetune_dir, 'data',
f'{split}-00000-of-00001{extension}')))

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
Expand All @@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote(
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == supported_extensions[-1]:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
for ext in supported_extensions
f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}'
for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {supported_extensions}\n' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}'
) from e
else:
Expand All @@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote(
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
# Then, wait to ensure every node has finished trying to download the dataset
dist.barrier()

# clean up signal file
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()

cfg.dataset.hf_name = finetune_dir
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
return dataset
break
return finetune_dir


def _build_collate_fn(
Expand Down
Loading

0 comments on commit f0fd749

Please sign in to comment.