diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 9816cde168..8ba9ae988d 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -3,6 +3,8 @@
 
 try:
     from llmfoundry.callbacks.async_eval_callback import AsyncEval
+    from llmfoundry.callbacks.curriculum_learning_callback import \
+        CurriculumLearning
     from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
     from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
     from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
@@ -26,4 +28,5 @@
     'EvalGauntlet',
     'HuggingFaceCheckpointer',
     'AsyncEval',
+    'CurriculumLearning',
 ]
diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
new file mode 100644
index 0000000000..02f0e80309
--- /dev/null
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -0,0 +1,105 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Enable curriculum learning by resuming with a different dataset.
+
+This callback is currently experimental. The API may change without warning in
+the future.
+"""
+
+import logging
+from typing import Any, Dict
+
+from composer.core import Callback, State
+from composer.loggers import Logger
+from streaming import StreamingDataset
+from torch.utils.data import DataLoader
+
+log = logging.getLogger(__name__)
+
+
+class CurriculumLearning(Callback):
+    """Starts an epoch with a different dataset when resuming from a checkpoint.
+
+    This callback is currently experimental. The API may change without warning in the future.
+
+    Args:
+        dataset_index (int): The index of the dataset currently being used.
+        current_dataset_config (Dict): The configuration of the dataset currently
+            being used.
+    """
+
+    def __init__(self, dataset_index: int, current_dataset_config: Dict):
+        self.dataset_index = dataset_index
+        self.saved_dataset_index = 0
+        self.all_dataset_configs = []
+        self.current_dataset_state = {}
+        # The current dataset config is resolved and passed in train.py
+        self.current_dataset_config = current_dataset_config
+
+    def before_load(self, state: State, logger: Logger):
+        del logger
+
+        # Save the current dataset state so we can restore it correctly
+        # if we are resuming with a new dataset.
+        train_loader = state.train_dataloader
+        # Check if we are using a DataLoader and StreamingDataset
+        if not isinstance(train_loader, DataLoader):
+            raise ValueError(
+                f'CurriculumLearning callback can only be used with a train '
+                f'dataloader of type DataLoader, but got {type(train_loader)}.')
+        dataset = train_loader.dataset
+        if not isinstance(dataset, StreamingDataset):
+            raise ValueError(
+                f'CurriculumLearning callback only supports StreamingDataset '
+                f'because it requires loading and saving dataset state. '
+                f'Instead, got a dataset of type {type(dataset)}')
+        assert isinstance(dataset, StreamingDataset)
+        # Save the current dataset state so we can restore it if needed.
+        self.current_dataset_state = dataset.state_dict(  # type: ignore
+            num_samples=0, from_beginning=False)
+
+    def after_load(self, state: State, logger: Logger):
+        del logger
+
+        # As saved_dataset_index is loaded from state_dict, this only runs when
+        # a user explicitly increments the dataset_index and not on any other
+        # resumption, including autoresume.
+        train_loader = state._train_dataloader
+        assert isinstance(
+            train_loader,
+            DataLoader), 'CurriculumLearning callback requires a DataLoader.'
+ dataset = train_loader.dataset + assert isinstance( + dataset, StreamingDataset + ), 'CurriculumLearning callback requires a StreamingDataset.' + if self.saved_dataset_index < self.dataset_index: + # Ignore the dataset state that was read in from the checkpoint, and + # replace with the new dataset state. This preserves resumption info. + if self.current_dataset_state['epoch'] < 0: + # Make sure the epoch in the loaded state dict is not negative. + # Since `__iter__` has not yet been called on the dataset, the + # epoch index in the dataset will still be -1. We need to ensure + # that we set the epoch correctly to 0 in this case. + self.current_dataset_state['epoch'] = 0 + dataset.load_state_dict( # type: ignore + self.current_dataset_state) + # Start a new epoch since we are using a new dataset. + # This will also reset the sample_in_epoch written to checkpoint, + # making sure that subsequent resumptions proceed correctly. + state.timestamp = state.timestamp.to_next_epoch() + # Append the new dataset config to the list of all dataset configs. + self.all_dataset_configs.append(self.current_dataset_config) + elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: + # Make sure to track our current dataset config if we are just starting training. + self.all_dataset_configs.append(self.current_dataset_config) + + def state_dict(self): + return { + 'dataset_index': self.dataset_index, + 'all_dataset_configs': self.all_dataset_configs + } + + def load_state_dict(self, state: Dict[str, Any]): + self.saved_dataset_index = state.get('dataset_index', 0) + self.all_dataset_configs = state.get('all_dataset_configs', []) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index f9618c5fa2..082c767288 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -394,8 +394,8 @@ def _save_checkpoint(self, state: State, logger: Logger): os.path.join(local_save_path, license_filename), ) - mlflow_logger.register_model( + mlflow_logger.register_model_with_run_id( model_uri=local_save_path, name=self.mlflow_registered_model_name, - await_registration_for=3600, + await_creation_for=3600, ) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index c6881fd276..7071cbdda3 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -16,7 +16,7 @@ SUPPORTED_EXTENSIONS, dataset_constructor) from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio -from llmfoundry.data.text_data import get_tokens_per_batch_func +from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func log = logging.getLogger(__name__) @@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig, dataset = None # for pyright sampler = None - if cfg.dataset.get('remote') is not None: + if cfg.dataset.get('remote') is not None or cfg.dataset.get( + 'streams') is not None: # Build streaming dataloader + streams = build_streams(cfg.dataset) dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, - local=cfg.dataset.local, + streams=streams, + local=cfg.dataset.get('local', None), remote=cfg.dataset.get('remote', None), split=cfg.dataset.get('split', None), download_retry=cfg.dataset.get('download_retry', 2), @@ -279,11 +282,38 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'Using a streaming dataset requires setting both `remote` and `local`, ' +\ 'but dataset.local is None.' 
) + elif dataset_cfg.get('streams') is not None: + # Using the streaming dataset codepath + illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] + discovered_illegal_keys = [] + for key in illegal_keys: + if dataset_cfg.get(key) is not None: + discovered_illegal_keys.append('`' + key + '`') + if discovered_illegal_keys: + raise ValueError( + 'The dataset config sets a value for `streams` as well as the ' +\ + f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ + 'Those keys are used when building from a HuggingFace dataset, but ' +\ + 'setting `streams` instructs the dataset to build from a streaming dataset.' + ) + illegal_keys = ['remote', 'local'] + discovered_illegal_keys = [] + for key in illegal_keys: + if dataset_cfg.get(key) is not None: + discovered_illegal_keys.append('`' + key + '`') + if discovered_illegal_keys: + raise ValueError( + 'The dataset config sets a value for `streams` as well as the ' +\ + f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\ + 'Please either use single stream (set remote/local only) ' +\ + 'or put remote/local under streams' + ) + else: raise ValueError( - 'In the dataset config, you must set either `hf_name` to use a ' +\ - 'HuggingFace dataset or set `remote` to use a streaming ' +\ - 'dataset, but both were None.' + 'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\ + 'dataset, or set `remote` to use a streaming dataset, or set ' +\ + '`streams` to use multiple streaming datasets, but all were None.' ) if dataset_cfg.get('max_seq_len') is None: raise ValueError( diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 2b397eae96..7f2a5417b4 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -37,14 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import warnings from functools import partial from pathlib import Path -from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union, - cast) +from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence, + Tuple, Union, cast) import datasets as hf_datasets import huggingface_hub as hf_hub import numpy as np from composer.utils import dist -from streaming import StreamingDataset +from streaming import Stream, StreamingDataset from transformers import PreTrainedTokenizerBase from llmfoundry.utils.logging_utils import SpecificWarningFilter @@ -257,12 +257,25 @@ def is_valid_ift_example(pad_token_id: int, max_seq_len: int, non_padding_response) +def _stream_remote_local_validate(remote: Optional[str], local: Optional[str], + split: Optional[str]): + if remote is None or (local == remote): + if local is not None and os.path.isdir(local): + contents = set(os.listdir(local)) + if split is not None and split not in contents: + raise ValueError( + f'local directory {local} does not contain split {split}') + + class StreamingFinetuningDataset(StreamingDataset): """Finetuning dataset with flexible tokenization using StreamingDataset. Args: tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to tokenize samples. + streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from, + which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or + ``remote``/``local``. Defaults to ``None``. local (str): Local dataset directory where shards are cached by split. remote (str, optional): Remote path or directory to download the dataset from. 
If ``None``, its data must exist locally. StreamingDataset uses either ``streams`` or @@ -313,7 +326,8 @@ class StreamingFinetuningDataset(StreamingDataset): def __init__(self, tokenizer: PreTrainedTokenizerBase, - local: str, + streams: Optional[Sequence[Stream]] = None, + local: Optional[str] = None, remote: Optional[str] = None, split: Optional[str] = None, download_retry: int = 2, @@ -341,15 +355,15 @@ def __init__(self, f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}' ) - if remote is None or (local == remote): - if os.path.isdir(local): - contents = set(os.listdir(local)) - if split not in contents: - raise ValueError( - f'local directory {local} does not contain split {split}' - ) + if streams is None: + _stream_remote_local_validate(remote, local, split) + else: + for stream in streams: + _stream_remote_local_validate(stream.remote, stream.local, + split) super().__init__( + streams=streams, local=local, remote=remote, split=split, diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 3301d455e5..8d7ff5849d 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -232,6 +232,19 @@ def get_sequence_id_from_batch( return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1) +def build_streams(dataset_cfg: DictConfig): + streams_dict = dataset_cfg.pop('streams', None) + # build streams + streams = None + if streams_dict is not None: + streams = [] + for _, stream in streams_dict.items(): + # stream is the streams kwargs + # fwd all kwargs with **stream allows streaming to check args + streams.append(Stream(**stream)) + return streams + + def build_text_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, @@ -240,19 +253,11 @@ def build_text_dataloader( assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' # get kwargs - streams_dict = cfg.dataset.pop('streams', None) mlm_probability = cfg.dataset.pop('mlm_probability', None) eos_token_id = cfg.dataset.pop('eos_token_id', None) bos_token_id = cfg.dataset.pop('bos_token_id', None) - # build streams - streams = None - if streams_dict is not None: - streams = [] - for _, stream in streams_dict.items(): - # stream is the streams kwargs - # fwd all kwargs with **stream allows streaming to check args - streams.append(Stream(**stream)) + streams = build_streams(cfg.dataset) # build dataset potentially with streams dataset = StreamingTextDataset( diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 7b96a261ca..9f1136e597 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -70,6 +70,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss): def __init__(self, om_model_config: DictConfig, tokenizer: PreTrainedTokenizerBase): pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path + pretrained_lora_id_or_path = om_model_config.get( + 'pretrained_lora_id_or_path', None) if not om_model_config.get( 'trust_remote_code', True @@ -249,6 +251,15 @@ def _autoset_attn_implementation_monkeypatch( if peft_config_dict is not None: peft_config = self._get_peft_config(peft_config_dict) + if pretrained_lora_id_or_path is not None: + if not peft_installed: + raise ValueError( + 'PEFT is not installed, but lora_id_or_path was passed. Please install LLM Foundry with the peft extra to use lora_id_or_path.' 
+ ) + from peft import PeftModelForCausalLM + model = PeftModelForCausalLM.from_pretrained( + model, pretrained_lora_id_or_path) + super().__init__( model=model, shift_labels=True, diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index c4d16ad8cd..1666188318 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -7,6 +7,7 @@ import functools from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union +from composer.models.huggingface import maybe_get_underlying_model from transformers import PreTrainedModel from transformers.models.opt.modeling_opt import OPTDecoder @@ -142,7 +143,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, # OPT has an extra layer of wrapping, so special case here if isinstance(causal_base_model, OPTDecoder): - model.model._fsdp_wrap = False + underlying_model = maybe_get_underlying_model(model) + underlying_model.model._fsdp_wrap = False model_block = hf_get_hidden_layers(causal_base_model) lm_head = model.get_output_embeddings() # some models (OPT) implement .get_input_embeddings for the causal subclass diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 79dc8c7f25..e9ad8054e2 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -28,6 +28,7 @@ from llmfoundry.models.layers.attention import (is_flash_v1_installed, is_flash_v2_installed) +from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY if is_flash_v2_installed(): try: # This try...except is needed because transformers requires it despite the 'if' statement above @@ -55,17 +56,11 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, - attn_bias_shape, +from llmfoundry.models.layers.attention import (attn_bias_shape, build_attn_bias, gen_slopes) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import \ - FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP from llmfoundry.models.layers.ffn import build_ffn as build_ffn -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY from llmfoundry.models.mpt.configuration_mpt import MPTConfig # NOTE: All utils are imported directly even if unused so that @@ -87,6 +82,10 @@ MODEL_INIT_REGISTRY, ) +from llmfoundry.models.utils.act_ckpt import (pass_on_block_idx, + build_act_ckpt_mod_to_blocks, + check_mapping_blocks_overlap) + try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func except: @@ -352,6 +351,13 @@ def __init__(self, config: MPTConfig): **config.to_dict(), ) for _ in range(config.n_layers) ]) + + # Tag all modules in the transformer blocks with the corresponding block_idx and max_block_idx + for i, block in enumerate(self.blocks): + block.block_idx = i + block.max_block_idx = config.n_layers - 1 + pass_on_block_idx(block) + self.norm_f = norm_class(config.d_model, device=config.init_device) self.rope = config.attn_config['rope'] @@ -908,41 +914,57 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', - None) or 
['MPTBlock'] - if isinstance(act_ckpt_list, str): - act_ckpt_list = [act_ckpt_list] - elif not isinstance(act_ckpt_list, list): - raise ValueError( - f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}' + """The MPT activation checkpointing (act ckpt) function. + + When `activation_checkpointing` in fsdp_config is set to true, this function will be called on all the modules in the FSDP wrapped model and determine whether a given module should be activation checkpointed. It checks the checkpointing target (`activation_checkpointing_target` in `model`) which can be specified as below: + 1. null (or no such field): The whole MPTBlock will be activation checkpointed on all layers + 2. a list of modules to act ckpt on all layers, e.g., + activation_checkpointing_target: + - grouped_query_attention + - mptmlp + 3. a dictionary of module name with target_blocks, e.g., + activation_checkpointing_target: + { + "mptblock": target_blocks_1, + "grouped_query_attention": target_blocks_2 + } + target_blocks (target_blocks_1, target_blocks_2 above) can be: + - a single integer n: the first n transformer block will be activation checkpointed + - a string of first-n, middle-m, last-k, range-i-j: the first n, the middle m, the last k, or the range [i, j) layers will be activation checkpointed. E.g, 'first-2, last-2' means the first 2 and last 2 transformer blocks will be activation checkpointed + middle-m is range [start, end) where ``start = max(max_block_idx // 2 - m // 2, 0), end = min(start + m, max_block_idx + 1)`` + - a list of integers corresponds to the list of transformer block ids, e.g., [2] means the second transformer block will be activation checkpointed. [2, 3] means the second and third transformer blocks will be activation checkpointed + - a list of mixed integers and strings of first-n, middle-m, last-k, range-i-j + + An example in yaml config file: + fsdp_config: + activation_checkpointing: true + model: + activation_checkpointing_target: + { + "mptblock": 'first-5', + "grouped_query_attention": 'last-35' + } + """ + if not hasattr(module, 'block_idx'): + log.debug( + f'{module.__class__.__name__} cannot be activation checkpointed. Only transformer block or its submodules are eligible for activation checkpointing.' ) + return False - if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: - if len(act_ckpt_list) > 1: - log.info( - 'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).' - ) - return isinstance(module, MPTBlock) - - mod_types = () - for mod_name in act_ckpt_list: - if mod_name.lower() == 'mptblock': - mod_types += (MPTBlock,) - elif mod_name in ATTN_CLASS_REGISTRY: - mod_types += (ATTN_CLASS_REGISTRY[mod_name],) - elif mod_name in FFN_CLASS_REGISTRY: - mod_types += (FFN_CLASS_REGISTRY[mod_name],) - elif mod_name in NORM_CLASS_REGISTRY: - mod_types += (NORM_CLASS_REGISTRY[mod_name],) - else: - msg = ', '.join( - list(ATTN_CLASS_REGISTRY.keys()) + - list(FFN_CLASS_REGISTRY.keys()) + - list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) - raise ValueError( - f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' 
- ) - return isinstance(module, mod_types) + act_ckpt_target = getattr(self.config, + 'activation_checkpointing_target', None) + act_ckpt_mod_to_blocks = build_act_ckpt_mod_to_blocks( + act_ckpt_target, MPTBlock, module.max_block_idx) + + check_mapping_blocks_overlap(act_ckpt_mod_to_blocks, + module.max_block_idx) + + for k in act_ckpt_mod_to_blocks.keys(): + if isinstance(module, k): + blocks = act_ckpt_mod_to_blocks[k] + return True if blocks == -1 else module.block_idx in blocks + + return False def prepare_inputs_for_generation( self, diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py new file mode 100644 index 0000000000..08b718929a --- /dev/null +++ b/llmfoundry/models/utils/act_ckpt.py @@ -0,0 +1,147 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import torch + +from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY +from llmfoundry.models.layers.blocks import MPTBlock +from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY +from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY + + +def pass_on_block_idx(parent: torch.nn.Module): + if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'): + return + for child in parent.children(): + child.block_idx = parent.block_idx + child.max_block_idx = parent.max_block_idx + if child.children(): + pass_on_block_idx(child) + + +def get_act_ckpt_module(mod_name: str) -> Any: + """Get the module type from the module name.""" + if mod_name.lower() == 'mptblock': + mod_type = MPTBlock + elif mod_name in ATTN_CLASS_REGISTRY: + mod_type = ATTN_CLASS_REGISTRY[mod_name] + elif mod_name in FFN_CLASS_REGISTRY: + mod_type = FFN_CLASS_REGISTRY[mod_name] + elif mod_name in NORM_CLASS_REGISTRY: + mod_type = NORM_CLASS_REGISTRY[mod_name] + else: + msg = ', '.join( + list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) + + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + raise ValueError( + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' + ) + return mod_type + + +def parse_ele_str(ele: str, max_block_idx: int) -> list: + """Parse a string in target_blocks and return a list of block ids to add. + + Supported formats are: first-n, middle-m, last-k, range-i-j which correspond + to the first n, the middle m, the last k, and the range [i, j). 
+    """
+    to_add = None
+    if ele.startswith('first-'):
+        assert ele[6:].isdigit(), f'Invalid target_blocks element {ele}'
+        to_add = list(range(min(int(ele[6:]), max_block_idx + 1)))
+    elif ele.startswith('last-'):
+        assert ele[5:].isdigit(), f'Invalid target_blocks element {ele}'
+        to_add = list(
+            range(max(max_block_idx - int(ele[5:]) + 1, 0), max_block_idx + 1))
+    elif ele.startswith('middle-'):
+        assert ele[7:].isdigit(), f'Invalid target_blocks element {ele}'
+        num = int(ele[7:])
+        start = max(max_block_idx // 2 - num // 2, 0)
+        end = min(start + num, max_block_idx + 1)
+        to_add = list(range(start, end))
+    elif ele.startswith('range-'):
+        r = ele[6:].split('-')
+        assert len(r) == 2, f'Invalid target_blocks element {ele}'
+        start, end = int(r[0]), int(r[1])
+        start = max(start, 0)
+        end = min(end, max_block_idx + 1)
+        to_add = list(range(start, end))
+    else:
+        raise ValueError(f'Invalid target_blocks element {ele}')
+    return to_add
+
+
+def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
+    """Parse the user input and return a list of block ids."""
+    candidate_block_ids = []
+    if isinstance(target_blocks, int):
+        candidate_block_ids = list(range(target_blocks))
+    elif isinstance(target_blocks, list):
+        for ele in target_blocks:
+            if isinstance(ele, int):
+                candidate_block_ids.append(ele)
+            elif isinstance(ele, str):
+                to_add = parse_ele_str(ele, max_block_idx)
+                candidate_block_ids.extend(to_add)
+            else:
+                raise ValueError(
+                    f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}'
+                )
+    elif isinstance(target_blocks, str):
+        target_blocks = target_blocks.replace(' ', '')
+        for ele in target_blocks.split(','):
+            to_add = parse_ele_str(ele, max_block_idx)
+            candidate_block_ids.extend(to_add)
+    else:
+        raise ValueError(
+            f'target_blocks must be either a single integer, or a list of integers, or a comma-separated string made of "first-n", "middle-m", "last-k", "range-i-j", or a list of mixed integers and the aforementioned strings, but got {type(target_blocks)}'
+        )
+
+    candidate_block_ids = list(set(candidate_block_ids))
+    return candidate_block_ids
+
+
+def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
+    """Check if the block ids in the mapping overlap with each other."""
+    all_blocks = [None] * (max_block_idx + 1)
+    for k, v in mapping.items():
+        if v == -1:
+            v = list(range(max_block_idx + 1))
+        for vv in v:
+            if vv < 0 or vv > max_block_idx:
+                continue
+            else:
+                if all_blocks[vv] is not None:
+                    raise ValueError(
+                        f'Block {vv} is assigned to both {k} and {all_blocks[vv]}.'
+ ) + else: + all_blocks[vv] = k + + +def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, + max_block_idx: int) -> dict: + act_ckpt_mod_to_blocks = {} + if act_ckpt_target is None or act_ckpt_target == []: + mod = top_module + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, str): + mod = get_act_ckpt_module(act_ckpt_target) + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, list): + for target in act_ckpt_target: + mod = get_act_ckpt_module(target) + act_ckpt_mod_to_blocks[mod] = -1 + elif isinstance(act_ckpt_target, dict): + for k, v in act_ckpt_target.items(): + mod = get_act_ckpt_module(k) + block_ids = get_target_block_list(v, max_block_idx) + act_ckpt_mod_to_blocks[mod] = block_ids + else: + raise ValueError( + f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}' + ) + + return act_ckpt_mod_to_blocks diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 450cdab3ee..bf960603a2 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -12,7 +12,7 @@ import torch from composer import algorithms from composer.callbacks import (EarlyStopper, Generate, LRMonitor, - MemoryMonitor, OptimizerMonitor, + MemoryMonitor, MemorySnapshot, OptimizerMonitor, RuntimeEstimator, SpeedMonitor) from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ @@ -30,9 +30,10 @@ from torch.optim.optimizer import Optimizer from transformers import AutoTokenizer, PreTrainedTokenizerBase -from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics, - GlobalLRScaling, HuggingFaceCheckpointer, - LayerFreezing, MonolithicCheckpointSaver, +from llmfoundry.callbacks import (AsyncEval, CurriculumLearning, EvalGauntlet, + FDiffMetrics, GlobalLRScaling, + HuggingFaceCheckpointer, LayerFreezing, + MonolithicCheckpointSaver, ScheduledGarbageCollector) from llmfoundry.data.dataloader import build_dataloader from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, @@ -168,6 +169,8 @@ def build_callback( return LRMonitor() elif name == 'memory_monitor': return MemoryMonitor() + elif name == 'memory_snapshot': + return MemorySnapshot(**kwargs) elif name == 'speed_monitor': return SpeedMonitor(window_size=kwargs.get('window_size', 1), gpu_flops_available=kwargs.get( @@ -214,8 +217,18 @@ def build_callback( if config is None: raise ValueError( 'Parameters config is required for async eval callback') - return AsyncEval(**kwargs, training_params=config) + elif name == 'curriculum_learning': + if config is None: + raise ValueError( + 'Parameters config is required for curriculum learning callback' + ) + if 'train_loader' not in config: + raise ValueError( + 'Curriculum learning callback requires a train_loader key in the run config.' + ) + return CurriculumLearning(**kwargs, + current_dataset_config=config['train_loader']) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py index 368c5725c3..2584ada601 100644 --- a/llmfoundry/utils/warnings.py +++ b/llmfoundry/utils/warnings.py @@ -14,7 +14,7 @@ class VersionedDeprecationWarning(DeprecationWarning): ... warnings.warn( ... VersionedDeprecationWarning( ... "Function XYZ is deprecated.", - ... after_version="2.0.0" + ... remove_version="2.0.0" ... ) ... ) ... 
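For illustration, a minimal sketch (not part of the patch itself) of how the new `curriculum_learning` callback registered in `build_callback` above can be constructed. The stream name, paths, and `dataset_index` value are hypothetical placeholders; in a real run, `build_callback` passes the resolved `train_loader` section of the run config as `current_dataset_config`.

# Minimal sketch, assuming a resolved finetuning train_loader config.
from llmfoundry.callbacks import CurriculumLearning

train_loader_config = {
    'name': 'finetuning',
    'dataset': {
        'streams': {
            'my_data': {  # hypothetical stream name
                'remote': 's3://my-bucket/my-data',  # placeholder remote path
                'local': '/tmp/my-data',  # placeholder local cache
                'split': 'train',
            },
        },
        'max_seq_len': 2048,
        'decoder_only_format': True,
    },
}

# dataset_index is incremented by the user when resuming with a new dataset.
# On resumption, after_load() compares it against the index saved in the
# checkpoint and, if it was incremented, loads the fresh dataset state and
# advances the timestamp to the next epoch.
curriculum_cb = CurriculumLearning(dataset_index=1,
                                   current_dataset_config=train_loader_config)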
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 99540008aa..a1a32d2a48 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -19,11 +19,9 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from rich.traceback import install -from transformers import (AutoModelForCausalLM, PreTrainedTokenizerBase, - T5ForConditionalGeneration) +from transformers import PreTrainedTokenizerBase install() -from llmfoundry.models import MPTForCausalLM from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, build_evaluators, build_logger, @@ -34,52 +32,6 @@ log = logging.getLogger(__name__) -def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - num_retries: int) -> ComposerModel: - try: - from peft import PeftModel - except ImportError as e: - raise ImportError( - f'Error importing from peft. Run `pip install -e .[gpu,peft]`. \n {e}' - ) - - model_registry = { - 'mpt_causal_lm': MPTForCausalLM, - 'hf_causal_lm': AutoModelForCausalLM, - 'hf_prefix_lm': AutoModelForCausalLM, - 'hf_t5': T5ForConditionalGeneration, - } - - retries = 0 - composer_model_wrapper = None - while retries < num_retries and composer_model_wrapper is None: - try: - trust_remote_code = model_cfg.get('trust_remote_code', True) - use_auth_token = model_cfg.get('use_auth_token', False) - model = model_registry[model_cfg.name].from_pretrained( - model_cfg.pretrained_model_name_or_path, - trust_remote_code=trust_remote_code, - use_auth_token=use_auth_token, - ) - - peft_model = PeftModel.from_pretrained( - model, model_cfg.pretrained_lora_id_or_path) - - composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name]( - peft_model, tokenizer) - except Exception as e: - retries += 1 - if retries >= num_retries: - raise e - else: - log.info( - f'Got exception {str(e)} while loading model {model_cfg.name}. 
{num_retries-retries} retries remaining' - ) - - assert composer_model_wrapper is not None - return composer_model_wrapper - - def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel: init_context = process_init_device(model_cfg, fsdp_config) @@ -175,12 +127,8 @@ def evaluate_model( 'The FSDP config block is not supported when loading ' + 'Hugging Face models in 8bit.') - if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'): - composer_model = load_peft_model(model_cfg.model, tokenizer, - num_retries) - else: - composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, - num_retries) + composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, + num_retries) # Now add the eval metrics if eval_loader_config is not None: @@ -205,6 +153,7 @@ def evaluate_model( assert composer_model is not None log.info(f'Building trainer for {model_cfg.model_name}...') + trainer = Trainer( run_name=run_name, seed=seed, diff --git a/scripts/eval/yamls/hf_lora_eval.yml b/scripts/eval/yamls/hf_lora_eval.yml index 0b07efbe4c..08861b8569 100644 --- a/scripts/eval/yamls/hf_lora_eval.yml +++ b/scripts/eval/yamls/hf_lora_eval.yml @@ -2,13 +2,10 @@ max_seq_len: 2048 seed: 1 precision: amp_fp16 -# If you are using one model, put it here: -model_name_or_path: EleutherAI/gpt-neo-125m +model_name_or_path: facebook/opt-350m # If you are using a seperated lora weight, put it here: # lora weights must be compatible with the specified model -lora_id_or_path: edbeeching/gpt-neo-125M-imdb-lora # Example lora weights for gpt-neo-125m - -# otherwise, write a block for each model you want to test in the `models` section +lora_id_or_path: ybelkada/opt-350m-lora # Example lora weights for opt-350m models: - @@ -23,21 +20,6 @@ models: name: ${model_name_or_path} kwargs: model_max_length: ${max_seq_len} -# # if you are evaluating more than one model, list them all as YAML blocks without variable interpolation -# - -# model_name: mosaicml/mpt-7b -# model: -# name: hf_causal_lm -# pretrained_model_name_or_path: mosaicml/mpt-7b -# init_device: cpu -# pretrained: true -# config_overrides: -# max_seq_len: ${max_seq_len} -# tokenizer: -# name: mosaicml/mpt-7b -# kwargs: -# model_max_length: ${max_seq_len} - device_eval_batch_size: 4 diff --git a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml index d5fe26d3fc..4047256614 100644 --- a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml +++ b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml @@ -24,9 +24,11 @@ train_loader: name: finetuning dataset: ############ - remote: ${data_remote} - local: ${data_local} - split: train + streams: + my_data: + remote: ${data_remote} + local: ${data_local} + split: train ############ shuffle: true max_seq_len: ${max_seq_len} diff --git a/setup.py b/setup.py index 1c42f966a4..c88f566f26 100644 --- a/setup.py +++ b/setup.py @@ -50,11 +50,11 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,oci,gcs]>=0.19,<0.20', + 'mosaicml[libcloud,wandb,oci,gcs]>=0.19.1,<0.20', 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` 'transformers>=4.37,<4.38', - 'mosaicml-streaming>=0.7.2,<0.8', + 'mosaicml-streaming>=0.7.4,<0.8', 'torch>=2.1,<2.2', 'datasets>=2.16,<2.17', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data diff --git 
a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index ab2d569132..bc4214e76a 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -306,7 +306,7 @@ def test_huggingface_conversion_callback_interval( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) - mlflow_logger_mock.register_model = MagicMock() + mlflow_logger_mock.register_model_with_run_id = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' @@ -334,10 +334,10 @@ def test_huggingface_conversion_callback_interval( input_example=ANY, signature=ANY, metadata={'task': 'llm/v1/completions'}) - assert mlflow_logger_mock.register_model.call_count == 1 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0 - assert mlflow_logger_mock.register_model.call_count == 0 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 normal_checkpoints = [ name for name in os.listdir(os.path.join(tmp_path, 'checkpoints')) @@ -564,7 +564,7 @@ def test_huggingface_conversion_callback( mlflow_logger_mock = MagicMock(spec=MLFlowLogger) mlflow_logger_mock.state_dict = lambda *args, **kwargs: {} mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock) - mlflow_logger_mock.register_model = MagicMock() + mlflow_logger_mock.register_model_with_run_id = MagicMock() mlflow_logger_mock.model_registry_prefix = '' mlflow_logger_mock._experiment_id = 'mlflow-experiment-id' mlflow_logger_mock._run_id = 'mlflow-run-id' @@ -628,10 +628,10 @@ def test_huggingface_conversion_callback( } } mlflow_logger_mock.save_model.assert_called_with(**expectation) - assert mlflow_logger_mock.register_model.call_count == 1 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 1 else: assert mlflow_logger_mock.log_model.call_count == 0 - assert mlflow_logger_mock.register_model.call_count == 0 + assert mlflow_logger_mock.register_model_with_run_id.call_count == 0 # summon full params to check equivalence from torch.distributed.fsdp import FullyShardedDataParallel as FSDP diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 0a7edc3d7a..319e6eafdf 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -548,31 +548,38 @@ def test_finetuning_dataloader_custom_split_remote(split: str): @pytest.mark.parametrize('pretokenize', [True, False]) +@pytest.mark.parametrize('use_multiple_streams', [True, False]) @pytest.mark.parametrize('use_bytes', [True, False]) -def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool, +def test_finetuning_dataloader_streaming(pretokenize: bool, + use_multiple_streams: bool, + use_bytes: bool, tmp_path: pathlib.Path): max_seq_len = 2048 - remote_path = os.path.join(tmp_path, 'remote') - local_path = os.path.join(tmp_path, 'local') - tokenizer = build_tokenizer( tokenizer_name='gpt2', tokenizer_kwargs={'model_max_length': max_seq_len}, ) - build_mock_ft_streaming_dataset(remote_path, - 'train', - pretokenize, - use_bytes=use_bytes, - tokenizer=tokenizer) + streams_config = {'streams': {}} + num_streams = 2 + for i in range(num_streams): + remote_path = os.path.join(tmp_path, f'remote_{i}') + local_path = 
os.path.join(tmp_path, f'local_{i}') + build_mock_ft_streaming_dataset(remote_path, + 'train', + pretokenize, + use_bytes=use_bytes, + tokenizer=tokenizer) + streams_config['streams'][f'stream_{i}'] = { + 'remote': remote_path, + 'local': local_path, + 'split': 'train' + } cfg = { 'name': 'finetuning', 'dataset': { - 'remote': remote_path, - 'local': local_path, - 'split': 'train', 'max_seq_len': 2048, 'decoder_only_format': True, 'allow_pad_trimming': False, @@ -586,6 +593,10 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool, 'persistent_workers': False, 'timeout': 0 } + if use_multiple_streams: + cfg['dataset'].update(streams_config) + else: + cfg['dataset'].update(streams_config['streams']['stream_0']) cfg = om.create(cfg) diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py index 987ea5f2a7..97063b25c4 100644 --- a/tests/models/test_fsdp_act_checkpoint.py +++ b/tests/models/test_fsdp_act_checkpoint.py @@ -17,17 +17,20 @@ @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False]) @pytest.mark.parametrize('activation_checkpointing_target', [ - 'grouped_query_attention', [], ['grouped_query_attention'], - ['mptblock', 'grouped_query_attention'] + 'grouped_query_attention', [], ['grouped_query_attention'], { + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1' + } ]) def test_fsdp_act_checkpoint(activation_checkpointing: bool, - activation_checkpointing_target: Union[list, str]): + activation_checkpointing_target: Union[list, str, + dict]): device = get_device('gpu') model_cfg = { 'name': 'mpt_causal_lm', 'd_model': 128, 'n_heads': 4, - 'n_layers': 2, + 'n_layers': 3, 'expansion_ratio': 1, 'max_seq_len': 16, 'vocab_size': 50368, @@ -59,10 +62,7 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, assert not isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0], CheckpointWrapper) - elif (not activation_checkpointing_target - ) or activation_checkpointing_target == [ - 'mptblock', 'grouped_query_attention' - ]: + elif (not activation_checkpointing_target): module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[ 0]._fsdp_wrapped_module assert isinstance(module, CheckpointWrapper) @@ -72,6 +72,19 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + elif activation_checkpointing_target == { + 'mptblock': [1], + 'grouped_query_attention': 'first-1, last-1' + }: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[1]._fsdp_wrapped_module, CheckpointWrapper) + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. 
+ blocks[2]._fsdp_wrapped_module.attn, CheckpointWrapper) else: raise ValueError( f'Unknown activation_checkpointing_target: {activation_checkpointing_target}' diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 64a92f6cc6..a3a7fb3814 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -13,6 +13,7 @@ import torch.nn as nn from accelerate import init_empty_weights from composer.core.precision import Precision, get_precision_context +from composer.models.huggingface import maybe_get_underlying_model from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module from composer.utils import dist, get_device, reproducibility @@ -499,8 +500,18 @@ def test_loss_fn(): atol=1e-4), f'differed at step {i}' -def test_opt_wrapping(): - conf = { +@pytest.mark.parametrize('peft_config', [ + None, + { + 'peft_type': 'LORA', + 'task_type': 'CAUSAL_LM' + }, +]) +def test_opt_wrapping(peft_config: Optional[dict[str, str]]): + if peft_config is not None: + _ = pytest.importorskip('peft') + + conf: dict[str, dict[str, Union[str, dict]]] = { 'model': { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'facebook/opt-125m', @@ -510,6 +521,9 @@ def test_opt_wrapping(): 'name': 'facebook/opt-125m' } } + if peft_config is not None: + conf['model']['peft_config'] = peft_config + config = DictConfig(conf) tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) @@ -519,10 +533,37 @@ def test_opt_wrapping(): model = ComposerHFCausalLM(config.model, tokenizer) # check that all the modules we except are blocked from FSDP wrapping - assert not model.model.model._fsdp_wrap - assert not model.model.model.decoder._fsdp_wrap - assert not model.model.model.decoder.embed_tokens._fsdp_wrap - assert not model.model.lm_head._fsdp_wrap + underlying_model = maybe_get_underlying_model(model.model) + assert not underlying_model.model._fsdp_wrap + assert not underlying_model.model.decoder._fsdp_wrap + assert not underlying_model.model.decoder.embed_tokens._fsdp_wrap + assert not underlying_model.lm_head._fsdp_wrap + + +def test_lora_id(): + peft = pytest.importorskip('peft') + + conf: dict[str, dict[str, Union[str, dict]]] = { + 'model': { + 'name': 'hf_causal_lm', + 'pretrained_model_name_or_path': 'facebook/opt-350m', + 'pretrained': 'false', + 'pretrained_lora_id_or_path': 'ybelkada/opt-350m-lora', + }, + 'tokenizer': { + 'name': 'facebook/opt-350m' + } + } + + config = DictConfig(conf) + + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) + tokenizer = build_tokenizer(config.tokenizer.name, + tokenizer_cfg.get('kwargs', {})) + + model = ComposerHFCausalLM(config.model, tokenizer) + + assert isinstance(model.model, peft.PeftModelForCausalLM) @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
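For illustration, a minimal sketch (not part of the patch itself) of how the new target_blocks parsing in llmfoundry/models/utils/act_ckpt.py resolves an activation_checkpointing_target specification into block ids. The 8-block model size (max_block_idx = 7) is an assumed example.

# Sketch of get_target_block_list, assuming a model with 8 transformer blocks
# (max_block_idx == 7). The function returns de-duplicated, unordered block ids,
# so results are sorted here for display.
from llmfoundry.models.utils.act_ckpt import get_target_block_list

print(sorted(get_target_block_list('first-2, last-2', max_block_idx=7)))
# [0, 1, 6, 7]

print(sorted(get_target_block_list('middle-2', max_block_idx=7)))
# [2, 3]: start = max(7 // 2 - 2 // 2, 0) = 2, end = min(start + 2, 8) = 4

print(sorted(get_target_block_list('range-2-5', max_block_idx=7)))
# [2, 3, 4]: range-i-j covers the half-open interval [2, 5)

print(sorted(get_target_block_list([0, 'last-1'], max_block_idx=7)))
# [0, 7]: explicit block ids and strings can be mixed in a list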