Skip to content

Commit

Permalink
Merge branch 'main' into composer_lora
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Jan 19, 2024
2 parents 935bce8 + 35bb339 commit beaa86c
Show file tree
Hide file tree
Showing 43 changed files with 2,019 additions and 531 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,4 @@ notebooks/
**/*.pt
**/mlruns/*
**/tokenizer-save-dir-*/**
**/.downloaded_finetuning/
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ Tutorial videos from the community:
Something missing? Contribute with a PR!

# Latest News
* [Blog: LLM Training and Inference with Intel Gaudi2 AI Accelerators](https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators)
* [Blog: Training LLMs at Scale with AMD MI250 GPUs](https://www.databricks.com/blog/training-llms-scale-amd-mi250-gpus)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: Announcing MPT-7B-8K: 8K Context Length for Document Understanding](https://www.mosaicml.com/blog/long-context-mpt-7b-8k)
* [Blog: Training LLMs with AMD MI250 GPUs and MosaicML](https://www.mosaicml.com/blog/amd-mi250)
* [Blog: MPT-30B: Raising the bar for open-source foundation models](https://www.mosaicml.com/blog/mpt-30b)
Expand Down Expand Up @@ -186,6 +189,12 @@ Notes:
1. `attn_impl: triton` does not work.
1. We don't yet have a Docker image where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.

### Intel Gaudi
Support for LLM Foundry on Intel Gaudi devices is experimental, please use the branch `habana_alpha` and see the [README on that branch](https://github.com/mosaicml/llm-foundry/blob/habana_alpha) which has [install instructions and known issues.](https://github.com/mosaicml/llm-foundry/tree/habana_alpha?tab=readme-ov-file#intel-gaudi)

For training and inference performance results on Intel Gaudi2 accelerators, see our blog: https://www.databricks.com/blog/llm-training-and-inference-intel-gaudi2-ai-accelerators


# Quickstart

> **Note**
Expand Down
26 changes: 8 additions & 18 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,31 +158,21 @@ def get_eval_parameters(

def validate_interval(interval: Union[str, int, Time],
save_interval: Union[str, int, Time]) -> Time:
if isinstance(save_interval, str):
new_save_interval: Time = Time.from_timestring(save_interval)
elif isinstance(save_interval, int):
new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH)
else:
new_save_interval: Time = save_interval

if isinstance(interval, str):
result: Time = Time.from_timestring(interval)
elif isinstance(interval, int):
result: Time = Time(interval, TimeUnit.EPOCH)
else:
result: Time = interval

if new_save_interval.unit != result.unit:

new_save_interval = Time.from_input(save_interval, TimeUnit.EPOCH)
async_interval = Time.from_input(interval, TimeUnit.EPOCH)

if new_save_interval.unit != async_interval.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
if result < new_save_interval:
if async_interval < new_save_interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than save interval'
)
if result.value % new_save_interval.value != 0:
if async_interval.value % new_save_interval.value != 0:
raise ValueError(
'Async eval interval must be a multiple of save interval')
return result
return async_interval


class AsyncEval(Callback):
Expand Down
19 changes: 10 additions & 9 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,10 @@ def __init__(
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)

if isinstance(save_interval, str):
save_interval = Time.from_timestring(save_interval)
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval: Time = save_interval
self.save_interval: Time = Time.from_input(save_interval,
TimeUnit.EPOCH)
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.save_interval, include_end_of_training=True)
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
Expand Down Expand Up @@ -254,11 +250,16 @@ def _save_checkpoint(self, state: State, logger: Logger):
)

if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
)
self.remote_ud.upload_file(
state=state,
remote_file_name=os.path.join(save_dir, filename),
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir,
filename)),
overwrite=self.overwrite,
Expand Down
163 changes: 77 additions & 86 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from typing import Tuple, Union

import datasets as hf_datasets
import torch
from composer.core.data_spec import DataSpec
from composer.utils import dist, get_file, parse_uri
Expand All @@ -13,7 +12,9 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func

Expand Down Expand Up @@ -122,8 +123,13 @@ def build_finetuning_dataloader(cfg: DictConfig,
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
# Build streaming dataloader
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
Expand All @@ -148,48 +154,53 @@ def build_finetuning_dataloader(cfg: DictConfig,
batching_method=cfg.dataset.get('batching_method', 'random'),
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

else:
backend, _, _ = parse_uri(cfg.dataset.hf_name)
# Build HF dataloader
dataset_name_or_path = cfg.dataset.hf_name
split = cfg.dataset.get('split')

# If dataset is a remote path, download it first.
backend, _, _ = parse_uri(dataset_name_or_path)
if backend not in ['', None]:
if cfg.dataset.get('split') is None:
if split is None:
raise ValueError(
'When using a HuggingFace dataset from a URL, you must set the ' + \
'`split` key in the dataset config.'
)
dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores.
split = split.replace('-', '_')
dataset_name_or_path = _download_remote_hf_dataset(
remote_path=dataset_name_or_path, split=split)

# Get the preprocessing function.
proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn')
if isinstance(proto_preprocessing_fn, (dict, DictConfig)):
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict(
dict(proto_preprocessing_fn))
else:
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name_or_path)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg, tokenizer, device_batch_size)
# Build dataset from HF.
dataset = dataset_constructor.build_from_hf(
dataset_name=dataset_name_or_path,
split=split,
safe_load=cfg.dataset.get('safe_load', False),
max_seq_len=cfg.dataset.max_seq_len,
preprocessing_fn=preprocessing_fn,
tokenizer=tokenizer,
hf_kwargs=cfg.dataset.get('hf_kwargs', {}))

# Ensure dataset is large enough.
if cfg.drop_last:
world_size = dist.get_world_size()
minimum_dataset_size = world_size * dataloader_batch_size
if hasattr(dataset, '__len__'):
full_dataset_size = len(dataset)
if full_dataset_size < minimum_dataset_size:
raise ValueError(
f'Your dataset (name={cfg.dataset.hf_name}, split={cfg.dataset.split}) '
f'Your dataset (name={cfg.dataset.hf_name}, split={split}) '
+
f'has {full_dataset_size} samples, but your minimum batch size '
+
Expand All @@ -199,22 +210,24 @@ def build_finetuning_dataloader(cfg: DictConfig,
+
f'of samples in your dataset to at least {minimum_dataset_size}.'
)

assert dataset is not None
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle),
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)
# Initialize sampler.
sampler = dist.get_sampler(dataset,
drop_last=cfg.drop_last,
shuffle=cfg.dataset.shuffle)

assert dataset is not None # for pyright
dl = DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=dataloader_batch_size,
drop_last=cfg.drop_last,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.get('pin_memory', True),
prefetch_factor=cfg.get('prefetch_factor', 2),
persistent_workers=cfg.get('persistent_workers', True),
timeout=cfg.get('timeout', 0),
)

token_counting_func = get_tokens_per_batch_func()

Expand Down Expand Up @@ -250,7 +263,7 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)
elif dataset_cfg.get('remote') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn']
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
Expand All @@ -275,11 +288,8 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
)


def _build_hf_dataset_from_remote(
cfg: DictConfig, tokenizer: PreTrainedTokenizerBase
) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
"""Builds a dataset from a remote object store.
def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
"""Downloads a dataset from a remote object store.
This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
Expand All @@ -290,38 +300,26 @@ def _build_hf_dataset_from_remote(
completed, the function removes the signal file.
Args:
cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
This includes:
- dataset.hf_name: The path of the HuggingFace dataset to download.
- dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
- dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
hf_name (str): The path of the HuggingFace dataset to download.
split (str): The dataset split to download (e.g., 'train', 'validation', 'test').
Returns:
Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
A local directory path where the dataset files are stored.
Raises:
FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
"""
supported_extensions = ['jsonl', 'csv', 'parquet']
# HF datasets does not support a split with dashes, so we replace dashes
# with underscores in the destination split.
destination_split = cfg.dataset.split.replace('-', '_')
finetune_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))),
'downloaded_finetuning',
destination_split if destination_split != 'data' else 'data_not',
DOWNLOADED_FT_DATASETS_DIRPATH,
split if split != 'data' else 'data_not',
)
os.makedirs(finetune_dir, exist_ok=True)
for extension in supported_extensions:
name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
for extension in SUPPORTED_EXTENSIONS:
name = f'{remote_path.strip("/")}/{split}{extension}'
destination = str(
os.path.abspath(
os.path.join(
finetune_dir, 'data',
f'{destination_split}-00000-of-00001.{extension}')))
os.path.join(finetune_dir, 'data',
f'{split}-00000-of-00001{extension}')))

# Since we don't know exactly what the extension will be, since it is one of a list
# use a signal file to wait for instead of the desired file
Expand All @@ -331,14 +329,14 @@ def _build_hf_dataset_from_remote(
try:
get_file(path=name, destination=destination, overwrite=True)
except FileNotFoundError as e:
if extension == supported_extensions[-1]:
if extension == SUPPORTED_EXTENSIONS[-1]:
files_searched = [
f'{cfg.dataset.hf_name}/{cfg.dataset.split}.{ext}'
for ext in supported_extensions
f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}'
for ext in SUPPORTED_EXTENSIONS
]
raise FileNotFoundError(
f'Could not find a file with any of ' + \
f'the supported extensions: {supported_extensions}\n' + \
f'the supported extensions: {SUPPORTED_EXTENSIONS}\n' + \
f'at {files_searched}'
) from e
else:
Expand All @@ -350,25 +348,18 @@ def _build_hf_dataset_from_remote(
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_download')

# Avoid the collective call until the local rank zero has finished trying to download the checkpoint
# Avoid the collective call until the local rank zero has finished trying to download the dataset
# so that we don't timeout for large downloads. This syncs all processes on the node
with dist.local_rank_zero_download_and_wait(signal_file_path):
# Then, wait to ensure every node has finished downloading the checkpoint
# Then, wait to ensure every node has finished trying to download the dataset
dist.barrier()

# clean up signal file
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
dist.barrier()

cfg.dataset.hf_name = finetune_dir
log.info(cfg.dataset)
dataset = dataset_constructor.build_from_hf(
cfg.dataset,
max_seq_len=cfg.dataset.max_seq_len,
tokenizer=tokenizer,
)
return dataset
break
return finetune_dir


def _build_collate_fn(
Expand Down
Loading

0 comments on commit beaa86c

Please sign in to comment.