Add safe_load option to restrict HF dataset downloads to allowed file types #798

Merged · 23 commits · Jan 5, 2024
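
For context, here is a minimal sketch of how the new flag might be enabled in a finetuning dataloader config. The key names mirror the `cfg.dataset.*` and top-level `cfg.*` lookups in the diff below; the dataset name and all values are hypothetical and not taken from this PR.

```python
# Hypothetical finetuning dataloader config showing the new `safe_load` flag.
# Key names follow the cfg.dataset.* / cfg.* lookups in the diff below; the
# dataset name and values are illustrative only.
from omegaconf import OmegaConf as om

cfg = om.create({
    'dataset': {
        'hf_name': 'my-org/my-finetuning-data',  # hypothetical HF dataset
        'split': 'train',
        'safe_load': True,   # restrict the download to supported file types
        'max_seq_len': 2048,
        'shuffle': True,
    },
    'drop_last': True,
    'num_workers': 8,
    'pin_memory': True,
    'timeout': 0,
})
# dl = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
```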
1 change: 1 addition & 0 deletions .gitignore
@@ -155,3 +155,4 @@ notebooks/
**/*.pt
**/mlruns/*
**/tokenizer-save-dir-*/**
**/.downloaded_finetuning/
163 changes: 77 additions & 86 deletions llmfoundry/data/finetuning/dataloader.py
@@ -4,7 +4,6 @@
import os
from typing import Tuple, Union

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
@@ -13,7 +12,9 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
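
The two new imports, `DOWNLOADED_FT_DATASETS_DIRPATH` and `SUPPORTED_EXTENSIONS`, come from `llmfoundry/data/finetuning/tasks.py`, which is not part of this diff. A plausible sketch of their definitions, inferred from the new `.gitignore` entry, the directory logic that previously lived inline in this file, and the docstring's 'jsonl'/'csv'/'parquet' list (the download loop below formats `f'{split}{extension}'` with no extra dot, so the extensions presumably carry a leading dot):

```python
# Plausible sketch only -- the real definitions live in tasks.py and are not
# part of this diff.
import os

# Repo-local cache dir for downloaded finetuning data; the leading dot matches
# the new `.gitignore` entry (**/.downloaded_finetuning/).
DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
    os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
                 '.downloaded_finetuning'))

# Allowed dataset file types; leading dots, since the download loop formats
# f'{split}{extension}' without inserting one.
SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
```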

@@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
# Build streaming dataloader
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
@@ -148,48 +154,53 @@
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')

# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
if cfg.dataset.get('split') is None:
if split is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores.
split = split.replace('-', '_')
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)

# Get the preprocessing function.
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(
dict(proto_preprocessing_fn))
else:
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name_or_path)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)
# Build dataset from HF.
dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=cfg.dataset.get('safe_load', False),
max_seq_len=cfg.dataset.max_seq_len,
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
hf_kwargs=cfg.dataset.get('kwargs', {}))

# Ensure dataset is large enough.
if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
@@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig,
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle)

assert dataset is not None # for pyright
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func()
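
The `safe_load` value is simply threaded through to `dataset_constructor.build_from_hf` above; the enforcement itself lives in `tasks.py` and is not shown in this diff. Based on the PR title, it presumably restricts a HuggingFace Hub download to the supported file types, roughly along the lines of this hypothetical sketch:

```python
# Rough illustration only -- not the actual code in tasks.py. Assumes the
# check is an extension filter over the files listed in the HF dataset repo.
from typing import List

from huggingface_hub import list_repo_files


def _illustrative_safe_load_check(dataset_name: str,
                                  allowed_extensions: List[str]) -> None:
    """Raise if the HF dataset repo contains files of unsupported types."""
    files = list_repo_files(dataset_name, repo_type='dataset')
    disallowed = [
        f for f in files
        if not any(f.endswith(ext) for ext in allowed_extensions)
    ]
    if disallowed:
        raise ValueError(f'Dataset {dataset_name} contains file types other '
                         f'than {allowed_extensions}: {disallowed}')
```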

@@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)
elif dataset_cfg.get('remote') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
@@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)
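
The `_validate_config` change above adds `safe_load` to the keys that are illegal when the streaming (`remote`) codepath is selected. A quick illustration with hypothetical config values:

```python
# Hypothetical dataset config: mixing the streaming (`remote`) codepath with
# the HF-only `safe_load` key should now fail validation.
from omegaconf import OmegaConf as om

streaming_dataset_cfg = om.create({
    'remote': 's3://my-bucket/streaming-ft-data',  # hypothetical remote
    'local': '/tmp/streaming-ft-data',
    'safe_load': True,  # illegal alongside `remote`
    'max_seq_len': 2048,
    'shuffle': True,
})
_validate_config(streaming_dataset_cfg)  # raises ValueError naming 'safe_load'
```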


def _build_hf_dataset_from_remote(
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.

This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
@@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote(
completed, the function removes the signal file.

Args:
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
This includes:
- dataset.hf_name: The path of the HuggingFace dataset to download.
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.

tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
hf_name (str): The path of the HuggingFace dataset to download.
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').

Returns:
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
A local directory path where the dataset files are stored.

Raises:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
destination_split if destination_split != 'data' else 'data_not',
DOWNLOADED_FT_DATASETS_DIRPATH,
split if split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
for extension in SUPPORTED_EXTENSIONS:
name = f'{remote_path.strip("/")}/{split}{extension}'
destination = str(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{destination_split}-00000-of-00001.{extension}')))
os.path.join(finetune_dir, 'data',
f'{split}-00000-of-00001{extension}')))

# Since we don't know exactly which extension the file will have (it is one
# of a list), use a signal file to wait on instead of the desired file
@@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote(
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == supported_extensions[-1]:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
for ext in supported_extensions
f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}'
for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {supported_extensions}\n' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}'
) from e
else:
@@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote(
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
# Then, wait to ensure every node has finished trying to download the dataset
dist.barrier()

# clean up signal file
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()

cfg.dataset.hf_name = finetune_dir
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
return dataset
break
return finetune_dir
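
Finally, a sketch of how the renamed helper is now used, mirroring the call added in `build_finetuning_dataloader` above; the object-store URI is made up.

```python
# Hypothetical call, following the pattern added in build_finetuning_dataloader:
# local rank 0 downloads <remote>/<split> with each supported extension into
# DOWNLOADED_FT_DATASETS_DIRPATH, other ranks wait on the signal file, and the
# local directory containing the data is returned.
local_dir = _download_remote_hf_dataset(
    remote_path='s3://my-bucket/my-finetuning-data',  # hypothetical URI
    split='train')
# dataset_constructor.build_from_hf is then pointed at local_dir rather than
# at the remote path.
```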


def _build_collate_fn(