Skip to content

Commit

Permalink
Merge branch 'main' into anna/async-evalloader
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl authored Jan 23, 2024
2 parents 6ce9d83 + 4961436 commit b269d2c
Show file tree
Hide file tree
Showing 47 changed files with 2,156 additions and 572 deletions.
2 changes: 1 addition & 1 deletion .ci/FILE_HEADER
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Copyright 2022 MosaicML LLM Foundry authors
Copyright 2024 MosaicML LLM Foundry authors
SPDX-License-Identifier: Apache-2.0
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@ notebooks/
**/*.pt
**/mlruns/*
**/tokenizer-save-dir-*/**
**/.downloaded_finetuning/
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.3.1
rev: v1.5.4
hooks:
- id: insert-license
args:
- --license-filepath
- .ci/FILE_HEADER
- --comment-style
- '#'
- --allow-past-years
types: [python]
- repo: https://github.com/PyCQA/docformatter
rev: v1.5.0
Expand Down
22 changes: 6 additions & 16 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,31 +166,21 @@ def get_eval_parameters(

def validate_interval(interval: Union[str, int, Time],
save_interval: Union[str, int, Time]) -> Time:
if isinstance(save_interval, str):
new_save_interval: Time = Time.from_timestring(save_interval)
elif isinstance(save_interval, int):
new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH)
else:
new_save_interval: Time = save_interval

if isinstance(interval, str):
result: Time = Time.from_timestring(interval)
elif isinstance(interval, int):
result: Time = Time(interval, TimeUnit.EPOCH)
else:
result: Time = interval
new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
async_interval = Time.from_input(interval, TimeUnit.EPOCH)

if new_save_interval.unit != result.unit:
if new_save_interval.unit != async_interval.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
if result < new_save_interval:
if async_interval < new_save_interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than save interval'
)
if result.value % new_save_interval.value != 0:
if async_interval.value % new_save_interval.value != 0:
raise ValueError(
'Async eval interval must be a multiple of save interval')
return result
return async_interval


def validate_eval_run_config(
Expand Down
19 changes: 10 additions & 9 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,10 @@ def __init__(
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)

if isinstance(save_interval, str):
save_interval = Time.from_timestring(save_interval)
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval: Time = save_interval
self.save_interval: Time = Time.from_input(save_interval,
TimeUnit.EPOCH)
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.save_interval, include_end_of_training=True)
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
Expand Down Expand Up @@ -245,11 +241,16 @@ def _save_checkpoint(self, state: State, logger: Logger):
)

if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
)
self.remote_ud.upload_file(
state=state,
remote_file_name=os.path.join(save_dir, filename),
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir,
filename)),
overwrite=self.overwrite,
Expand Down
163 changes: 77 additions & 86 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from typing import Tuple, Union

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
Expand All @@ -13,7 +12,9 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func

Expand Down Expand Up @@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
# Build streaming dataloader
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
Expand All @@ -148,48 +154,53 @@ def build_finetuning_dataloader(cfg: DictConfig,
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')

# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
if cfg.dataset.get('split') is None:
if split is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores.
split = split.replace('-', '_')
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)

# Get the preprocessing function.
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(
dict(proto_preprocessing_fn))
else:
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name_or_path)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)
# Build dataset from HF.
dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=cfg.dataset.get('safe_load', False),
max_seq_len=cfg.dataset.max_seq_len,
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
hf_kwargs=cfg.dataset.get('hf_kwargs', {}))

# Ensure dataset is large enough.
if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
Expand All @@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig,
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle)

assert dataset is not None # for pyright
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func()

Expand Down Expand Up @@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)
elif dataset_cfg.get('remote') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
Expand All @@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)


def _build_hf_dataset_from_remote(
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
Expand All @@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote(
completed, the function removes the signal file.
Args:
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
This includes:
- dataset.hf_name: The path of the HuggingFace dataset to download.
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
hf_name (str): The path of the HuggingFace dataset to download.
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
Returns:
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
A local directory path where the dataset files are stored.
Raises:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
destination_split if destination_split != 'data' else 'data_not',
DOWNLOADED_FT_DATASETS_DIRPATH,
split if split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
for extension in SUPPORTED_EXTENSIONS:
name = f'{remote_path.strip("/")}/{split}{extension}'
destination = str(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{destination_split}-00000-of-00001.{extension}')))
os.path.join(finetune_dir, 'data',
f'{split}-00000-of-00001{extension}')))

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
Expand All @@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote(
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == supported_extensions[-1]:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
for ext in supported_extensions
f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}'
for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {supported_extensions}\n' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}'
) from e
else:
Expand All @@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote(
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
# Then, wait to ensure every node has finished trying to download the dataset
dist.barrier()

# clean up signal file
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()

cfg.dataset.hf_name = finetune_dir
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
return dataset
break
return finetune_dir


def _build_collate_fn(
Expand Down
Loading

0 comments on commit b269d2c

Please sign in to comment.