Commit
Merge branch 'hf_parsing_with_icl_refactor' of github.com:maxisawesome/llm-foundry into hf_parsing_with_icl_refactor
maxisawesome committed Feb 12, 2024
2 parents 7849528 + 6ce8cc6 commit 0ffab21
Showing 20 changed files with 537 additions and 187 deletions.
3 changes: 3 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -3,6 +3,8 @@

try:
from llmfoundry.callbacks.async_eval_callback import AsyncEval
from llmfoundry.callbacks.curriculum_learning_callback import \
CurriculumLearning
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
@@ -26,4 +28,5 @@
'EvalGauntlet',
'HuggingFaceCheckpointer',
'AsyncEval',
'CurriculumLearning',
]
105 changes: 105 additions & 0 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -0,0 +1,105 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Enable curriculum learning by resuming with a different dataset.
This callback is currently experimental. The API may change without warning in
the future.
"""

import logging
from typing import Any, Dict

from composer.core import Callback, State
from composer.loggers import Logger
from streaming import StreamingDataset
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)


class CurriculumLearning(Callback):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
This callback is currently experimental. The API may change without warning in the future.
Args:
dataset_index (int): The index of the dataset currently being used.
current_dataset_config (Dict): The configuration of the dataset currently
being used.
"""

def __init__(self, dataset_index: int, current_dataset_config: Dict):
self.dataset_index = dataset_index
self.saved_dataset_index = 0
self.all_dataset_configs = []
self.current_dataset_state = {}
# The current dataset config is resolved and passed in train.py
self.current_dataset_config = current_dataset_config

def before_load(self, state: State, logger: Logger):
del logger

# Save the current dataset state so we can restore it correctly
# if we are resuming with a new dataset.
train_loader = state.train_dataloader
# Check if we are using a DataLoader and StreamingDataset
if not isinstance(train_loader, DataLoader):
raise ValueError(
'CurriculumLearning callback can only be used with a train '
f'dataloader of type DataLoader, but got {type(train_loader)}.')
dataset = train_loader.dataset
if not isinstance(dataset, StreamingDataset):
raise ValueError(
'CurriculumLearning callback only supports StreamingDataset '
'because it requires loading and saving dataset state. '
f'Instead, got a dataset of type {type(dataset)}.')
assert isinstance(dataset, StreamingDataset)
# Save the current dataset state so we can restore it if needed.
self.current_dataset_state = dataset.state_dict( # type: ignore
num_samples=0, from_beginning=False)

def after_load(self, state: State, logger: Logger):
del logger

# As saved_dataset_index is loaded from state_dict, this only runs when
# a user explicitly increments the dataset_index and not on any other
# resumption, including autoresume.
train_loader = state._train_dataloader
assert isinstance(
train_loader,
DataLoader), 'CurriculumLearning callback requires a DataLoader.'
dataset = train_loader.dataset
assert isinstance(
dataset, StreamingDataset
), 'CurriculumLearning callback requires a StreamingDataset.'
if self.saved_dataset_index < self.dataset_index:
# Ignore the dataset state that was read in from the checkpoint, and
# replace with the new dataset state. This preserves resumption info.
if self.current_dataset_state['epoch'] < 0:
# Make sure the epoch in the loaded state dict is not negative.
# Since `__iter__` has not yet been called on the dataset, the
# epoch index in the dataset will still be -1. We need to ensure
# that we set the epoch correctly to 0 in this case.
self.current_dataset_state['epoch'] = 0
dataset.load_state_dict( # type: ignore
self.current_dataset_state)
# Start a new epoch since we are using a new dataset.
# This will also reset the sample_in_epoch written to checkpoint,
# making sure that subsequent resumptions proceed correctly.
state.timestamp = state.timestamp.to_next_epoch()
# Append the new dataset config to the list of all dataset configs.
self.all_dataset_configs.append(self.current_dataset_config)
elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0:
# Make sure to track our current dataset config if we are just starting training.
self.all_dataset_configs.append(self.current_dataset_config)

def state_dict(self):
return {
'dataset_index': self.dataset_index,
'all_dataset_configs': self.all_dataset_configs
}

def load_state_dict(self, state: Dict[str, Any]):
self.saved_dataset_index = state.get('dataset_index', 0)
self.all_dataset_configs = state.get('all_dataset_configs', [])
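For illustration, here is a minimal, hedged sketch of how the callback's state round-trips across a resumption, based only on the __init__, state_dict, and load_state_dict methods above; the dataset config dicts are made-up placeholders, and in a real run the callback is constructed in train.py and passed to the Composer Trainer.

from llmfoundry.callbacks import CurriculumLearning

# First run: train on dataset 0. The config dict is a placeholder.
cb = CurriculumLearning(dataset_index=0,
                        current_dataset_config={'remote': 's3://bucket/dataset_a'})
saved = cb.state_dict()  # {'dataset_index': 0, 'all_dataset_configs': []}

# Resumed run: the user bumps dataset_index to 1 and points at a new dataset.
cb_resumed = CurriculumLearning(dataset_index=1,
                                current_dataset_config={'remote': 's3://bucket/dataset_b'})
cb_resumed.load_state_dict(saved)

# Because the saved index (0) is behind the requested index (1), after_load()
# will discard the checkpointed dataset state and start a fresh epoch on the
# new dataset.
assert cb_resumed.saved_dataset_index < cb_resumed.dataset_index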
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -394,8 +394,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
os.path.join(local_save_path, license_filename),
)

mlflow_logger.register_model(
mlflow_logger.register_model_with_run_id(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
await_creation_for=3600,
)
42 changes: 36 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
@@ -16,7 +16,7 @@
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig,

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
if cfg.dataset.get('remote') is not None or cfg.dataset.get(
'streams') is not None:
# Build streaming dataloader
streams = build_streams(cfg.dataset)
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
streams=streams,
local=cfg.dataset.get('local', None),
remote=cfg.dataset.get('remote', None),
split=cfg.dataset.get('split', None),
download_retry=cfg.dataset.get('download_retry', 2),
@@ -279,11 +282,38 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
'Using a streaming dataset requires setting both `remote` and `local`, ' +\
'but dataset.local is None.'
)
elif dataset_cfg.get('streams') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Those keys are used when building from a HuggingFace dataset, but ' +\
'setting `streams` instructs the dataset to build from a streaming dataset.'
)
illegal_keys = ['remote', 'local']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Please either use a single stream (set `remote`/`local` only) ' +\
'or put `remote`/`local` under `streams`.'
)

else:
raise ValueError(
'In the dataset config, you must set either `hf_name` to use a ' +\
'HuggingFace dataset or set `remote` to use a streaming ' +\
'dataset, but both were None.'
'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\
'dataset, or set `remote` to use a streaming dataset, or set ' +\
'`streams` to use multiple streaming datasets, but all were None.'
)
if dataset_cfg.get('max_seq_len') is None:
raise ValueError(
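To make the new validation branch concrete, here is a hedged sketch of the `streams` form of a finetuning dataset config; the stream names, bucket paths, and proportions are placeholders, and a real config carries additional required fields (e.g. max_seq_len) omitted here. Per the checks above, `hf_name`, `hf_kwargs`, `preprocessing_fn`, `safe_load`, and top-level `remote`/`local` must not be set alongside `streams`.

from omegaconf import OmegaConf

# Hypothetical multi-stream finetuning dataset config; remote/local live under
# each stream rather than at the top level.
dataset_cfg = OmegaConf.create({
    'streams': {
        'stream_a': {
            'remote': 's3://my-bucket/finetune/stream_a',  # placeholder path
            'local': '/tmp/finetune/stream_a',
            'proportion': 0.7,
        },
        'stream_b': {
            'remote': 's3://my-bucket/finetune/stream_b',  # placeholder path
            'local': '/tmp/finetune/stream_b',
            'proportion': 0.3,
        },
    },
    'split': 'train',
})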
36 changes: 25 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -37,14 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
Tuple, Union, cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
import numpy as np
from composer.utils import dist
from streaming import StreamingDataset
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter
@@ -257,12 +257,25 @@ def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
non_padding_response)


def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
split: Optional[str]):
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}')


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (str): Local dataset directory where shards are cached by split.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
@@ -313,7 +326,8 @@ class StreamingFinetuningDataset(StreamingDataset):

def __init__(self,
tokenizer: PreTrainedTokenizerBase,
local: str,
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
@@ -341,15 +355,15 @@ def __init__(self,
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
)

if remote is None or (local == remote):
if os.path.isdir(local):
contents = set(os.listdir(local))
if split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}'
)
if streams is None:
_stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
_stream_remote_local_validate(stream.remote, stream.local,
split)

super().__init__(
streams=streams,
local=local,
remote=remote,
split=split,
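As a rough sketch of the new `streams` path through StreamingFinetuningDataset (not runnable as-is: the remote/local paths and tokenizer below are placeholders that would need to point at real MDS shards), per-stream remote/local pairs are now validated individually instead of a single top-level pair:

from streaming import Stream
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder tokenizer

streams = [
    Stream(remote='s3://my-bucket/ift/stream_a', local='/tmp/ift/stream_a'),  # placeholder
    Stream(remote='s3://my-bucket/ift/stream_b', local='/tmp/ift/stream_b'),  # placeholder
]

# With `streams` given, top-level local/remote stay None and each stream's
# remote/local pair is checked by _stream_remote_local_validate.
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    streams=streams,
    split='train',
)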
23 changes: 14 additions & 9 deletions llmfoundry/data/text_data.py
@@ -232,6 +232,19 @@ def get_sequence_id_from_batch(
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)


def build_streams(dataset_cfg: DictConfig):
streams_dict = dataset_cfg.pop('streams', None)
# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# each value under `streams` is a dict of kwargs for a single Stream;
# forwarding them with **stream lets the streaming library validate the args
streams.append(Stream(**stream))
return streams


def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
@@ -240,19 +253,11 @@
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'

# get kwargs
streams_dict = cfg.dataset.pop('streams', None)
mlm_probability = cfg.dataset.pop('mlm_probability', None)
eos_token_id = cfg.dataset.pop('eos_token_id', None)
bos_token_id = cfg.dataset.pop('bos_token_id', None)

# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
streams = build_streams(cfg.dataset)

# build dataset potentially with streams
dataset = StreamingTextDataset(
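A brief, hedged usage sketch of the extracted build_streams helper (the stream names and paths are placeholders): each entry under `streams` is forwarded as keyword arguments to streaming.Stream, and the helper returns None when the config has no `streams` key, matching the inline logic it replaces.

from omegaconf import OmegaConf

from llmfoundry.data.text_data import build_streams

dataset_cfg = OmegaConf.create({
    'streams': {
        'c4': {'remote': 's3://my-bucket/c4', 'local': '/tmp/c4'},        # placeholder
        'wiki': {'remote': 's3://my-bucket/wiki', 'local': '/tmp/wiki'},  # placeholder
    },
})

streams = build_streams(dataset_cfg)  # list of two streaming.Stream objects
assert build_streams(OmegaConf.create({})) is None  # no 'streams' key -> None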
11 changes: 11 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -70,6 +70,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):
pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path
pretrained_lora_id_or_path = om_model_config.get(
'pretrained_lora_id_or_path', None)

if not om_model_config.get(
'trust_remote_code', True
@@ -249,6 +251,15 @@ def _autoset_attn_implementation_monkeypatch(
if peft_config_dict is not None:
peft_config = self._get_peft_config(peft_config_dict)

if pretrained_lora_id_or_path is not None:
if not peft_installed:
raise ValueError(
'PEFT is not installed, but pretrained_lora_id_or_path was passed. Please install LLM Foundry with the peft extra to use pretrained_lora_id_or_path.'
)
from peft import PeftModelForCausalLM
model = PeftModelForCausalLM.from_pretrained(
model, pretrained_lora_id_or_path)

super().__init__(
model=model,
shift_labels=True,
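The new pretrained_lora_id_or_path option loads a LoRA adapter onto the freshly constructed model with PEFT. A standalone, hedged sketch of the same pattern follows; the base model and adapter identifiers are placeholders, not values from this commit, and the `peft` extra must be installed.

from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained('gpt2')  # placeholder base model
model = PeftModelForCausalLM.from_pretrained(
    base_model,
    'someuser/my-lora-adapter',  # placeholder adapter repo id or local path
)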
4 changes: 3 additions & 1 deletion llmfoundry/models/hf/hf_fsdp.py
@@ -7,6 +7,7 @@
import functools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

from composer.models.huggingface import maybe_get_underlying_model
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder

@@ -142,7 +143,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel,

# OPT has an extra layer of wrapping, so special case here
if isinstance(causal_base_model, OPTDecoder):
model.model._fsdp_wrap = False
underlying_model = maybe_get_underlying_model(model)
underlying_model.model._fsdp_wrap = False
model_block = hf_get_hidden_layers(causal_base_model)
lm_head = model.get_output_embeddings()
# some models (OPT) implement .get_input_embeddings for the causal subclass