Address PR comments
irenedea committed Dec 19, 2023
1 parent ed4b366 commit 0dc2049
Showing 4 changed files with 40 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -155,3 +155,4 @@ notebooks/
**/*.pt
**/mlruns/*
**/tokenizer-save-dir-*/**
+**/.downloaded_finetuning/
26 changes: 15 additions & 11 deletions llmfoundry/data/finetuning/dataloader.py
@@ -12,8 +12,8 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
-from llmfoundry.data.finetuning.tasks import (_DOWNLOADED_FT_DATASETS_DIRPATH,
-                                              _SUPPORTED_EXTENSIONS,
+from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
+                                              SUPPORTED_EXTENSIONS,
                                               dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
@@ -175,10 +175,12 @@ def build_finetuning_dataloader(cfg: DictConfig,

# Get the preprocessing function.
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')

-    preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(dict(proto_preprocessing_fn)) \
-        if isinstance(proto_preprocessing_fn,(dict, DictConfig)) \
-        else dataset_constructor.get_preprocessing_fn_from_str(proto_preprocessing_fn, dataset_name_or_path)
+    if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
+        preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(
+            dict(proto_preprocessing_fn))
+    else:
+        preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
+            proto_preprocessing_fn, dataset_name_or_path)

# Build dataset from HF.
dataset = dataset_constructor.build_from_hf(
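The refactor above swaps a line-wrapped conditional expression for an explicit if/else; the dispatch itself is unchanged: a dict-like value goes to get_preprocessing_fn_from_dict, a string to get_preprocessing_fn_from_str. A minimal runnable sketch of the same type-based dispatch (the config values below are hypothetical, not the exact foundry schema):

from omegaconf import DictConfig, OmegaConf

# Hypothetical config values; the real keys are documented in llm-foundry.
as_str = OmegaConf.create({'dataset': {'preprocessing_fn': 'pkg.mod:fn'}})
as_dict = OmegaConf.create({'dataset': {'preprocessing_fn': {'name': 'fn'}}})

for cfg in (as_str, as_dict):
    spec = cfg.dataset.get('preprocessing_fn')
    if isinstance(spec, (dict, DictConfig)):
        print('dict form ->', dict(spec))  # get_preprocessing_fn_from_dict
    else:
        print('string form ->', spec)  # get_preprocessing_fn_from_str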
@@ -308,11 +310,11 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
finetune_dir = os.path.join(
-_DOWNLOADED_FT_DATASETS_DIRPATH,
+DOWNLOADED_FT_DATASETS_DIRPATH,
split if split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
-for extension in _SUPPORTED_EXTENSIONS:
+for extension in SUPPORTED_EXTENSIONS:
name = f'{remote_path.strip("/")}/{split}{extension}'
destination = str(
os.path.abspath(
@@ -327,14 +329,14 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
-if extension == _SUPPORTED_EXTENSIONS[-1]:
+if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{remote_path}/{split}{ext}'
-for ext in _SUPPORTED_EXTENSIONS
+for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
-f'the supported extensions: {_SUPPORTED_EXTENSIONS}\n' + \
+f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}'
) from e
else:
@@ -356,6 +358,8 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
os.remove(signal_file_path)
dist.barrier()

+        break
+
return finetune_dir
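The added break is the behavioral change in this hunk: the loop now stops at the first extension that downloads successfully instead of continuing through the remaining ones. A self-contained sketch of the pattern, with a hypothetical fake_get_file standing in for composer's get_file:

import os
import tempfile

SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

def fake_get_file(path: str, destination: str) -> None:
    # Stand-in for composer.utils.get_file: only the .jsonl split "exists".
    if not path.endswith('.jsonl'):
        raise FileNotFoundError(path)
    with open(destination, 'w') as f:
        f.write('{"prompt": "p", "response": "r"}\n')

def download_first_match(remote_path: str, split: str, dest_dir: str) -> str:
    os.makedirs(dest_dir, exist_ok=True)
    for extension in SUPPORTED_EXTENSIONS:
        name = f'{remote_path.strip("/")}/{split}{extension}'
        destination = os.path.join(dest_dir, f'{split}{extension}')
        try:
            fake_get_file(name, destination)
        except FileNotFoundError:
            if extension == SUPPORTED_EXTENSIONS[-1]:
                raise  # nothing matched any supported extension
            continue  # this extension is missing; try the next one
        break  # success: skip the remaining extensions
    return dest_dir

print(download_first_match('s3://bucket/data', 'train',
                           os.path.join(tempfile.gettempdir(), 'ft')))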


26 changes: 17 additions & 9 deletions llmfoundry/data/finetuning/tasks.py
@@ -52,13 +52,21 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}
-_DOWNLOADED_FT_DATASETS_DIRPATH = os.path.join(
-    os.path.dirname(os.path.dirname(os.path.dirname(
-        os.path.realpath(__file__)))), 'downloaded_finetuning')
-_SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']
+DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
+    os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
+                 '.downloaded_finetuning'))
+SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']


-def _is_empty_or_nonexistent(dirpath: str):
+def _is_empty_or_nonexistent(dirpath: str) -> bool:
+    """Check if a directory is empty or non-existent.
+
+    Args:
+        dirpath (str): Directory path to check.
+
+    Returns:
+        True if directory is empty or non-existent. False otherwise.
+    """
return not os.path.isdir(dirpath) or len(os.listdir(dirpath)) == 0
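The renamed constant resolves to the same place as before — three parent hops up from tasks.py, i.e. the llmfoundry package directory — but the directory now carries a leading dot so the new .gitignore entry matches it. A quick sketch with a hypothetical checkout path shows the old and new spellings side by side:

import os

# Hypothetical location of tasks.py inside a checkout.
fake_file = '/repo/llmfoundry/data/finetuning/tasks.py'

# New spelling: each os.pardir ('..') strips one trailing path component.
new_style = os.path.abspath(
    os.path.join(fake_file, os.pardir, os.pardir, os.pardir,
                 '.downloaded_finetuning'))

# Old spelling: three os.path.dirname calls perform the same climb.
old_style = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(fake_file))),
    'downloaded_finetuning')

print(new_style)  # /repo/llmfoundry/.downloaded_finetuning
print(old_style)  # /repo/llmfoundry/downloaded_finetuning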


@@ -381,22 +389,22 @@ def build_from_hf(
if not os.path.isdir(dataset_name):
# dataset_name is not a local dir path, download if needed.
local_dataset_dir = os.path.join(
-_DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name)
+DOWNLOADED_FT_DATASETS_DIRPATH, dataset_name)

if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
# Safely load a dataset from HF Hub with restricted file types.
hf_hub.snapshot_download(
dataset_name,
repo_type='dataset',
allow_patterns=[
-'*' + ext for ext in _SUPPORTED_EXTENSIONS
+'*' + ext for ext in SUPPORTED_EXTENSIONS
],
token=hf_kwargs.get('token', None),
local_dir_use_symlinks=False,
local_dir=local_dataset_dir)
if _is_empty_or_nonexistent(dirpath=local_dataset_dir):
raise FileNotFoundError(
-f'safe_load is set to True. No data files with safe extensions {_SUPPORTED_EXTENSIONS} '
+f'safe_load is set to True. No data files with safe extensions {SUPPORTED_EXTENSIONS} '
+ f'found for dataset {dataset_name}. ')
# Set dataset_name to the downloaded location.
dataset_name = local_dataset_dir
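For reference, the safe_load branch above relies on huggingface_hub's glob filtering so that only files with safe extensions are ever written locally. A hedged usage sketch (the repo id and local directory are hypothetical):

from huggingface_hub import snapshot_download

SUPPORTED_EXTENSIONS = ['.csv', '.jsonl', '.parquet']

local_dir = snapshot_download(
    'my-org/my-finetuning-dataset',  # hypothetical dataset repo id
    repo_type='dataset',
    allow_patterns=['*' + ext for ext in SUPPORTED_EXTENSIONS],
    local_dir='./.downloaded_finetuning/my-org/my-finetuning-dataset',
)
print(local_dir)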
@@ -409,7 +417,7 @@ def build_from_hf(
f for _, _, files in os.walk(dataset_name) for f in files
]
if not all(
-Path(f).suffix in _SUPPORTED_EXTENSIONS
+Path(f).suffix in SUPPORTED_EXTENSIONS
for f in dataset_files):
raise ValueError(
f'Dataset at local path {dataset_name} contains invalid file types. '
11 changes: 7 additions & 4 deletions tests/data/test_dataloader.py
@@ -25,8 +25,8 @@
from llmfoundry.data import build_dataloader
from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
                                               _ALLOWED_RESPONSE_KEYS,
-                                               _DOWNLOADED_FT_DATASETS_DIRPATH,
-                                               _SUPPORTED_EXTENSIONS,
+                                               DOWNLOADED_FT_DATASETS_DIRPATH,
+                                               SUPPORTED_EXTENSIONS,
                                               _tokenize_formatted_example)
from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
build_text_dataloader,
@@ -341,13 +341,13 @@ def test_finetuning_dataloader_safe_load(hf_name: str,

# If no errors are raised, we should expect downloaded files with only safe file types.
if expectation == does_not_raise():
-download_dir = os.path.join(_DOWNLOADED_FT_DATASETS_DIRPATH, hf_name)
+download_dir = os.path.join(DOWNLOADED_FT_DATASETS_DIRPATH, hf_name)
downloaded_files = [
file for _, _, files in os.walk(download_dir) for file in files
]
assert len(downloaded_files) > 0
assert all(
-Path(file).suffix in _SUPPORTED_EXTENSIONS
+Path(file).suffix in SUPPORTED_EXTENSIONS
for file in downloaded_files)


@@ -488,6 +488,9 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str):
def mock_get_file(path: str, destination: str, overwrite: bool = False):
if Path(destination).suffix == '.jsonl':
make_tiny_ft_dataset(path=destination, size=16)
+    else:
+        raise FileNotFoundError(
+            f'Test error in mock_get_file. {path} does not exist.')


@pytest.mark.parametrize('split', ['train', 'custom', 'custom-dash', 'data'])
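The new else branch matters because of the break added to _download_remote_hf_dataset: if the mock silently returned for the .csv candidate, the loop would stop there without ever writing a file. A small sketch of the interaction (illustrative driver; the real test goes through the dataloader, and make_tiny_ft_dataset is replaced here by a print):

from pathlib import Path

def mock_get_file(path: str, destination: str, overwrite: bool = False):
    if Path(destination).suffix == '.jsonl':
        print(f'would write a tiny dataset to {destination}')
    else:
        raise FileNotFoundError(
            f'Test error in mock_get_file. {path} does not exist.')

# Mirror of the extension loop in _download_remote_hf_dataset.
for ext in ['.csv', '.jsonl', '.parquet']:
    try:
        mock_get_file(f'remote/train{ext}', f'train{ext}')
    except FileNotFoundError:
        continue  # .csv raises, so the loop moves on
    break  # stops at .jsonl; .parquet is never attempted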

