diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index abd9fdef83..83a9a7d8ea 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -3,10 +3,9 @@ """Dataloader builder utilities.""" -from typing import Union +from typing import Any, Dict from composer import DataSpec -from omegaconf import DictConfig from transformers import PreTrainedTokenizerBase from llmfoundry import registry @@ -18,9 +17,9 @@ def build_dataloader( - cfg: DictConfig, + cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, - device_batch_size: Union[int, float], + device_batch_size: int, ) -> DataSpec: """Builds a dataloader from a config. @@ -30,14 +29,15 @@ def build_dataloader( device_batch_size (int): The size of the batches (number of examples) that the dataloader will produce. """ - kwargs = { - 'cfg': cfg, + name = cfg.pop('name') + kwargs: Dict[str, Any] = { + **cfg, 'tokenizer': tokenizer, 'device_batch_size': device_batch_size, } return construct_from_registry( - name=cfg.name, + name=name, registry=registry.dataloaders, partial_function=False, pre_validation_function=None, diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 3d2c77506e..af5eccbc77 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import os -from typing import Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from composer.core.data_spec import DataSpec @@ -23,6 +23,7 @@ ) from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio from llmfoundry.data.text_data import build_streams +from llmfoundry.utils.config_utils import to_dict_container from llmfoundry.utils.exceptions import ( MissingHuggingFaceURLSplitError, NotEnoughDatasetSamplesError, @@ -44,9 +45,16 @@ def build_finetuning_dataloader( - cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: Union[int, float], + device_batch_size: int, + dataset: Dict[str, Any], + num_workers: int, + drop_last: bool = False, + pin_memory: bool = True, + prefetch_factor: int = 2, + persistent_workers: bool = True, + name: Optional[str] = None, + timeout: int = 0, ) -> DataSpec: """Builds a finetuning dataloader for training or evaluating. @@ -57,18 +65,17 @@ def build_finetuning_dataloader( on which you intend to use, as explained below. Args: - cfg (DictConfig): An omegaconf dictionary used to configure the loader: - cfg.name (str): The type of dataloader to build. Must = "finetuning". - --- - *** HuggingFace dataset config fields *** - cfg.dataset.hf_name (str, optional): The name of the HuggingFace dataset + name (str): The type of dataloader to build. Must = "finetuning". + --- + *** HuggingFace dataset config fields *** + dataset.hf_name (str, optional): The name of the HuggingFace dataset to use. Can also be a remote http(s) directory or object store bucket containing the file {split}.jsonl in the format (prompt, response), in which case the builder will create a HuggingFace dataset. - cfg.dataset.hf_kwargs (DictConfig, optional): Additional kwargs to + dataset.hf_kwargs (DictConfig, optional): Additional kwargs to pass to `datasets.load_dataset`, which can be used to load a dataset from local files. - cfg.dataset.preprocessing_fn (str, optional): The name/import path of + dataset.preprocessing_fn (str, optional): The name/import path of the preprocessing function to use for formatting the data examples. 
If ``None`` (default), the builder will use the preprocessing function registered under `hf_name` (see `tasks.py`), if one exists, @@ -80,30 +87,30 @@ def build_finetuning_dataloader( `from import.path import function_name` and use the imported function as the preprocessing function. *** Streaming dataset config fields *** - cfg.dataset.remote (str, optional): Location of a MDS-formatted + dataset.remote (str, optional): Location of a MDS-formatted streaming dataset to use. Setting this will tell the builder to create a streaming dataset rather than a HuggingFace dataset. - cfg.dataset.local (str, optional): Local path where remote data + dataset.local (str, optional): Local path where remote data will be streamed to. Only valid if `cfg.dataset.remote` has also been set. *** Shared dataset configs fields *** - cfg.dataset.max_seq_len (int): The maximum length of sequences + dataset.max_seq_len (int): The maximum length of sequences in the batch. See :class:`Seq2SeqFinetuningCollator` docstring for details. - cfg.dataset.decoder_only_format (bool): Whether to format the + dataset.decoder_only_format (bool): Whether to format the examples for a decoder-only model. See :class:`Seq2SeqFinetuningCollator` docstring for details. - cfg.dataset.target_responses (str): Which responses are used as training targets. + dataset.target_responses (str): Which responses are used as training targets. Defaults to "last", meaning only the final response in multi-turn examples will serve as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for details. - cfg.dataset.target_prompts (str): Which prompts are used as training targets. + dataset.target_prompts (str): Which prompts are used as training targets. Defaults to "none", meaning prompts are never used as training targets. See :class:`Seq2SeqFinetuningCollator` docstring for details. - cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow + dataset.allow_pad_trimming (bool, optional): Whether to allow the collator to trim padding. See :class:`Seq2SeqFinetuningCollator` docstring for details. Default: ``False``. - cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes + dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes a collator wrapper that packs device_batch_size*packing_ratio raw examples into device_batch_size packed examples. This helps minimize padding while preserving sequence integrity. @@ -123,19 +130,19 @@ def build_finetuning_dataloader( statistics, max_seq_len, and tolerance for discarding samples! The script `scripts/misc/profile_packing.py` can help you choose the best packing_ratio. - cfg.dataset.shuffle (bool): Whether to shuffle the dataset. + dataset.shuffle (bool): Whether to shuffle the dataset. ___ See :class:`StreamingFinetuningDataset` for info on other standard config - options within `cfg.dataset` that will be passed as kwargs if + options within `dataset` that will be passed as kwargs if using the streaming codepath. --- - See :class:`DataLoader` for standard argument options to the pytorch - dataloader, such as `cfg.drop_last`, `cfg.num_workers`, etc. tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to prepare the data from raw text. Any missing sentinel tokens will be added by the collator. device_batch_size (int, float): The size of the batches (number of examples) that the dataloader will produce. 
+ See :class:`DataLoader` for standard argument options to the pytorch + dataloader, such as `drop_last`, `num_workers`, etc. Returns: A pytorch dataloader @@ -145,18 +152,31 @@ def build_finetuning_dataloader( padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ - _validate_config(cfg.dataset) + dataset_cfg = dataset + _validate_config(**dataset_cfg) # Use EOS as the pad token if none exists - if tokenizer.pad_token is None: + if tokenizer.pad_token is None: # type: ignore (sometimes it's none and that's ok) tokenizer.pad_token = tokenizer.eos_token + # this full config is necessary for properly profiling the packing ratio + dataloader_cfg = { + 'name': name, + 'dataset': dataset_cfg, + 'drop_last': drop_last, + 'num_workers': num_workers, + 'pin_memory': pin_memory, + 'prefetch_factor': prefetch_factor, + 'persistent_workers': persistent_workers, + 'timeout': timeout, + } + replication_factor, dataset_batch_size = construct_from_registry( name='dataset_replication_validator', registry=registry.dataset_replication_validators, partial_function=False, kwargs={ - 'cfg': cfg, + 'dataset_cfg': dataset_cfg, 'tokenizer': tokenizer, 'device_batch_size': device_batch_size, }, @@ -167,51 +187,59 @@ def build_finetuning_dataloader( registry=registry.collators, partial_function=False, kwargs={ - 'cfg': cfg, + 'dataloader_cfg': dataloader_cfg, 'tokenizer': tokenizer, 'dataset_batch_size': dataset_batch_size, }, ) - dataset = None # for pyright + streaming_dataset = None # for pyright sampler = None - if cfg.dataset.get( + if dataset_cfg.get( 'remote', - ) is not None or cfg.dataset.get('streams') is not None: + ) is not None or dataset_cfg.get('streams') is not None: # Build streaming dataloader - streams = build_streams(cfg.dataset) - dataset = dataset_constructor.build_from_streaming( + streams_cfg = dataset_cfg.get('streams', None) + streams_cfg = to_dict_container( + streams_cfg, + ) if streams_cfg is not None else None + streams = build_streams( + streams_cfg, + ) if streams_cfg is not None else None + + # note: we don't need to use ** here because we're setting default values for almost all arguments + streaming_dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, streams=streams, - local=cfg.dataset.get('local', None), - remote=cfg.dataset.get('remote', None), - split=cfg.dataset.get('split', None), - download_retry=cfg.dataset.get('download_retry', 2), - download_timeout=cfg.dataset.get('download_timeout', 60), - validate_hash=cfg.dataset.get('validate_hash', None), - keep_zip=cfg.dataset.get('keep_zip', False), - epoch_size=cfg.dataset.get('epoch_size', None), - predownload=cfg.dataset.get('predownload', None), - cache_limit=cfg.dataset.get('cache_limit', None), - partition_algo=cfg.dataset.get('partition_algo', 'relaxed'), - num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), + local=dataset_cfg.get('local', None), + remote=dataset_cfg.get('remote', None), + split=dataset_cfg.get('split', None), + download_retry=dataset_cfg.get('download_retry', 2), + download_timeout=dataset_cfg.get('download_timeout', 60), + validate_hash=dataset_cfg.get('validate_hash', None), + keep_zip=dataset_cfg.get('keep_zip', False), + epoch_size=dataset_cfg.get('epoch_size', None), + predownload=dataset_cfg.get('predownload', None), + cache_limit=dataset_cfg.get('cache_limit', None), + partition_algo=dataset_cfg.get('partition_algo', 'relaxed'), + num_canonical_nodes=dataset_cfg.get('num_canonical_nodes', None), 
batch_size=dataset_batch_size, - shuffle=cfg.dataset.get('shuffle', False), - shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'), - shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), - shuffle_block_size=cfg.dataset.get('shuffle_block_size', None), - sampling_method=cfg.dataset.get('sampling_method', 'balanced'), - sampling_granularity=cfg.dataset.get('sampling_granularity', 1), - batching_method=cfg.dataset.get('batching_method', 'random'), - max_seq_len=cfg.dataset.max_seq_len, - allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False), + shuffle=dataset_cfg.get('shuffle', False), + shuffle_algo=dataset_cfg.get('shuffle_algo', 'py1e'), + shuffle_seed=dataset_cfg.get('shuffle_seed', 9176), + shuffle_block_size=dataset_cfg.get('shuffle_block_size', None), + sampling_method=dataset_cfg.get('sampling_method', 'balanced'), + sampling_granularity=dataset_cfg.get('sampling_granularity', 1), + batching_method=dataset_cfg.get('batching_method', 'random'), + max_seq_len=dataset_cfg['max_seq_len'], + allow_unsafe_types=dataset_cfg.get('allow_unsafe_types', False), replication=replication_factor, ) else: # Build HF dataloader - dataset_name_or_path = cfg.dataset.hf_name - split = cfg.dataset.get('split') + dataset_name_or_path = dataset_cfg['hf_name'] + split = dataset_cfg.get('split') if split is None: raise MissingHuggingFaceURLSplitError() @@ -225,7 +253,7 @@ def build_finetuning_dataloader( split = split.replace('-', '_') # Get the preprocessing function. - proto_preprocessing_fn = cfg.dataset.get('preprocessing_fn') + proto_preprocessing_fn = dataset_cfg.get('preprocessing_fn') if isinstance(proto_preprocessing_fn, (dict, DictConfig)): preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_dict( dict(proto_preprocessing_fn), @@ -237,63 +265,64 @@ def build_finetuning_dataloader( ) # Build dataset from HF. - dataset = dataset_constructor.build_from_hf( + streaming_dataset = dataset_constructor.build_from_hf( dataset_name=dataset_name_or_path, split=split, - safe_load=cfg.dataset.get('safe_load', False), - max_seq_len=cfg.dataset.max_seq_len, + safe_load=dataset_cfg.get('safe_load', False), + max_seq_len=dataset_cfg['max_seq_len'], preprocessing_fn=preprocessing_fn, tokenizer=tokenizer, - target_prompts=cfg.dataset.get( + target_prompts=dataset_cfg.get( 'target_prompts', _DEFAULT_TARGET_PROMPTS, ), - target_responses=cfg.dataset.get( + target_responses=dataset_cfg.get( 'target_responses', _DEFAULT_TARGET_RESPONSES, ), - decoder_only_format=cfg.dataset.decoder_only_format, - hf_kwargs=cfg.dataset.get('hf_kwargs', {}), + decoder_only_format=dataset_cfg['decoder_only_format'], + hf_kwargs=dataset_cfg.get('hf_kwargs', {}), ) # Ensure dataset is large enough. - if cfg.drop_last: + if drop_last: world_size = dist.get_world_size() // replication_factor minimum_dataset_size = world_size * dataloader_batch_size - if hasattr(dataset, '__len__'): - full_dataset_size = len(dataset) + if hasattr(streaming_dataset, '__len__'): + full_dataset_size = len(streaming_dataset) if full_dataset_size < minimum_dataset_size: raise NotEnoughDatasetSamplesError( - dataset_name=cfg.dataset.hf_name, + dataset_name=dataset_cfg['hf_name'], split=split, dataloader_batch_size=dataloader_batch_size, world_size=world_size, full_dataset_size=full_dataset_size, minimum_dataset_size=minimum_dataset_size, ) + # Initialize sampler. 
sampler = dist.get_sampler( - dataset, - drop_last=cfg.drop_last, - shuffle=cfg.dataset.shuffle, + streaming_dataset, + drop_last=drop_last, + shuffle=dataset_cfg['shuffle'], num_replicas=dist.get_world_size() // replication_factor if replication_factor > 1 else None, rank=dist.get_global_rank() // replication_factor if replication_factor > 1 else None, ) - assert dataset is not None # for pyright + assert streaming_dataset is not None # for pyright dl = DataLoader( - dataset, + streaming_dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, + drop_last=drop_last, sampler=sampler, - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + timeout=timeout, ) return construct_from_registry( @@ -302,12 +331,25 @@ def build_finetuning_dataloader( partial_function=False, kwargs={ 'dl': dl, - 'dataset_cfg': cfg.dataset, + 'dataset_cfg': dataset_cfg, }, ) -def _validate_config(dataset_cfg: DictConfig) -> None: +def _validate_config( + max_seq_len: int, + decoder_only_format: bool = False, + hf_name: Optional[str] = None, + local: Optional[str] = None, + remote: Optional[str] = None, + hf_kwargs: Optional[Dict[str, Any]] = None, + preprocessing_fn: Optional[str] = None, + safe_load: Optional[bool] = None, + streams: Optional[Dict[str, Any]] = None, + target_prompts: Optional[str] = None, + target_responses: Optional[str] = None, + **kwargs: Dict[str, Any], +) -> None: """Validates the dataset configuration. Makes sure that the dataset is properly configured for either @@ -320,14 +362,50 @@ def _validate_config(dataset_cfg: DictConfig) -> None: Raises: ValueError: If the dataset configuration does not meet the requirements. 
""" - if dataset_cfg.get('hf_name') is not None: + # Check for extraneous keys in the dataset config + allowed_additional_kwargs = { + 'local', + 'remote', + 'split', + 'download_retry', + 'download_timeout', + 'validate_hash', + 'keep_zip', + 'epoch_size', + 'predownload', + 'cache_limit', + 'partition_algo', + 'num_canonical_nodes', + 'batch_size', + 'shuffle', + 'shuffle_algo', + 'shuffle_seed', + 'shuffle_block_size', + 'sampling_method', + 'sampling_granularity', + 'batching_method', + 'max_seq_len', + 'allow_unsafe_types', + 'replication', + 'packing_ratio', + 'allow_pad_trimming', + 'seq_parallel_replication', + 'auto_packing_replication', + } + if not set(kwargs.keys()).issubset(allowed_additional_kwargs): + raise ValueError( + 'The dataset config contains the following extraneous keys: ' +\ + ', '.join(set(kwargs.keys()) - allowed_additional_kwargs), + ) + + if hf_name is not None: # Using the HuggingFace dataset codepath illegal_keys = ['local', 'remote'] - discovered_illegal_keys = [ - '`' + key + '`' - for key in illegal_keys - if dataset_cfg.get(key) is not None - ] + discovered_illegal_keys = [] + if local is not None: + discovered_illegal_keys.append('`local`') + if remote is not None: + discovered_illegal_keys.append('`remote`') if discovered_illegal_keys: raise ValueError( 'The dataset config sets a value for `hf_name` as well as the ' +\ @@ -335,12 +413,17 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'Those keys are used when building from a streaming dataset, but ' +\ 'setting `hf_name` instructs the dataset to build from a HuggingFace dataset.', ) - elif dataset_cfg.get('remote') is not None: + elif remote is not None: # Using the streaming dataset codepath - illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] + illegal_keys = { + 'hf_name': hf_name, + 'hf_kwargs': hf_kwargs, + 'preprocessing_fn': preprocessing_fn, + 'safe_load': safe_load, + } discovered_illegal_keys = [] - for key in illegal_keys: - if dataset_cfg.get(key) is not None: + for key, value in illegal_keys.items(): + if value is not None: discovered_illegal_keys.append('`' + key + '`') if discovered_illegal_keys: raise ValueError( @@ -349,17 +432,22 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'Those keys are used when building from a HuggingFace dataset, but ' +\ 'setting `remote` instructs the dataset to build from a streaming dataset.', ) - if dataset_cfg.get('local') is None: + if local is None: raise ValueError( 'Using a streaming dataset requires setting both `remote` and `local`, ' +\ 'but dataset.local is None.', ) - elif dataset_cfg.get('streams') is not None: + elif streams is not None: # Using the streaming dataset codepath - illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load'] + illegal_keys = { + 'hf_name': hf_name, + 'hf_kwargs': hf_kwargs, + 'preprocessing_fn': preprocessing_fn, + 'safe_load': safe_load, + } discovered_illegal_keys = [] - for key in illegal_keys: - if dataset_cfg.get(key) is not None: + for key, value in illegal_keys.items(): + if value is not None: discovered_illegal_keys.append('`' + key + '`') if discovered_illegal_keys: raise ValueError( @@ -368,10 +456,10 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'Those keys are used when building from a HuggingFace dataset, but ' +\ 'setting `streams` instructs the dataset to build from a streaming dataset.', ) - illegal_keys = ['remote', 'local'] + illegal_keys = {'remote': remote, 'local': local} discovered_illegal_keys = [] - for key in illegal_keys: 
- if dataset_cfg.get(key) is not None: + for key, value in illegal_keys.items(): + if value is not None: discovered_illegal_keys.append('`' + key + '`') if discovered_illegal_keys: raise ValueError( @@ -387,25 +475,15 @@ def _validate_config(dataset_cfg: DictConfig) -> None: 'dataset, or set `remote` to use a streaming dataset, or set ' +\ '`streams` to use multiple streaming datasets, but all were None.', ) - max_seq_len = dataset_cfg.get('max_seq_len') - if max_seq_len is None: - raise ValueError( - 'In the dataset config, you must set the `max_seq_len`', - ) - - if max_seq_len != int(max_seq_len): - raise ValueError('max_seq_len must be an integer') - dataset_cfg['max_seq_len'] = int(max_seq_len) # Raise an error if the target_prompts + target_responses + decoder_only_format settings # are invalid - target_responses = str( - dataset_cfg.get('target_responses', _DEFAULT_TARGET_RESPONSES), - ).lower() - target_prompts = str( - dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS), - ).lower() - decoder_only_format = dataset_cfg.decoder_only_format + if target_prompts is None: + target_prompts = _DEFAULT_TARGET_PROMPTS + if target_responses is None: + target_responses = _DEFAULT_TARGET_RESPONSES + target_prompts, target_responses = target_prompts.lower( + ), target_responses.lower() validate_target_settings( target_prompts, target_responses, @@ -465,8 +543,7 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: except FileNotFoundError as e: if extension == SUPPORTED_EXTENSIONS[-1]: files_searched = [ - f'{cfg.dataset.hf_name}/{cfg.dataset.split}{ext}' - for ext in SUPPORTED_EXTENSIONS + f'{name}/{split}{ext}' for ext in SUPPORTED_EXTENSIONS ] raise FileNotFoundError( f'Could not find a file with any of ' + \ @@ -498,26 +575,28 @@ def _download_remote_hf_dataset(remote_path: str, split: str) -> str: def build_collate_fn( - dataloader_cfg: DictConfig, + dataloader_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, device_batch_size: int, ) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: - dataset_cfg = dataloader_cfg.dataset - max_seq_len = dataset_cfg.max_seq_len + # These `.get` calls are safe because the dataset_cfg is validated for extra keys + dataset_cfg = dataloader_cfg['dataset'] + target_responses = dataset_cfg.get( + 'target_responses', + _DEFAULT_TARGET_RESPONSES, + ) + target_prompts = dataset_cfg.get('target_prompts', _DEFAULT_TARGET_PROMPTS) + max_seq_len = dataset_cfg['max_seq_len'] + decoder_only_format = dataset_cfg['decoder_only_format'] + allow_pad_trimming = dataset_cfg.get('allow_pad_trimming', False) collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, max_seq_len=max_seq_len, - decoder_only_format=dataset_cfg.decoder_only_format, - target_responses=dataset_cfg.get( - 'target_responses', - _DEFAULT_TARGET_RESPONSES, - ), - target_prompts=dataset_cfg.get( - 'target_prompts', - _DEFAULT_TARGET_PROMPTS, - ), - allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False), + decoder_only_format=decoder_only_format, + target_responses=target_responses, + target_prompts=target_prompts, + allow_pad_trimming=allow_pad_trimming, ) packing_ratio = dataset_cfg.get('packing_ratio') @@ -532,9 +611,9 @@ def build_collate_fn( if packing_ratio == 'auto': packing_ratio = auto_packing_ratio( - dataloader_cfg, - tokenizer, - device_batch_size, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, ) if isinstance(packing_ratio, str): @@ -550,7 +629,7 @@ def build_collate_fn( elif 
packing_ratio < 1.0: raise ValueError('packing_ratio must be >= 1, if supplied') - if not dataset_cfg.decoder_only_format: + if not decoder_only_format: raise NotImplementedError( 'On-the-fly packing is currently only supported for decoder-only formats.', ) @@ -611,9 +690,9 @@ def build_collate_fn( device_batch_size = 1 dataloader = build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, ).dataloader packing = cfg.dataset.get('packing_ratio') is not None diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index be5c703068..a6fdf34953 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -3,12 +3,11 @@ import logging import tempfile -from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple import numpy as np import torch from composer.utils import dist -from omegaconf import DictConfig from transformers import PreTrainedTokenizerBase log = logging.getLogger(__name__) @@ -318,7 +317,7 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int): def auto_packing_ratio( - dataloader_cfg: DictConfig, + dataloader_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, device_batch_size: int, num_packing_ratios: int = 20, @@ -352,20 +351,21 @@ def auto_packing_ratio( # Set the seed so that auto packing is deterministic. reproducibility.seed_all(0) - max_seq_len = dataloader_cfg.dataset.max_seq_len # If max_seq_len is very small, skip profiling and select packing ratio of 1. + dataset_config = dataloader_cfg['dataset'] + max_seq_len = dataset_config.get('max_seq_len') if max_seq_len <= 100: return 1 min_ratio = 1 max_ratio = max_seq_len / 100 profiling_results = profile_packing( - dataloader_cfg, - tokenizer, - min_ratio, - max_ratio, - num_packing_ratios, - device_batch_size, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + min_ratio=min_ratio, + max_ratio=max_ratio, + num_packing_ratios=num_packing_ratios, + device_batch_size=device_batch_size, ) # Obtain the maximum packing_ratio/minimum padding that has no waste. 
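# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the upstream diff): after
# this refactor, `auto_packing_ratio` consumes a plain dict rather than an
# omegaconf DictConfig. The dict mirrors the `dataloader_cfg` assembled in
# `build_finetuning_dataloader` above; the keys below are taken from that
# code, while the concrete values (dataset name, worker count, batch size)
# are assumptions for illustration only.
#
#     dataloader_cfg = {
#         'name': 'finetuning',
#         'dataset': {
#             'hf_name': 'my-org/my-finetuning-set',  # hypothetical dataset
#             'split': 'train',
#             'max_seq_len': 2048,
#             'decoder_only_format': True,
#             'packing_ratio': 'auto',
#         },
#         'drop_last': False,
#         'num_workers': 8,
#         'pin_memory': True,
#         'prefetch_factor': 2,
#         'persistent_workers': True,
#         'timeout': 0,
#     }
#     # `tokenizer` is assumed to be an already-built PreTrainedTokenizerBase.
#     ratio = auto_packing_ratio(
#         dataloader_cfg=dataloader_cfg,
#         tokenizer=tokenizer,
#         device_batch_size=8,
#     )
# ---------------------------------------------------------------------------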
@@ -392,7 +392,7 @@ def auto_packing_ratio( def profile_packing( - dataloader_cfg: DictConfig, + dataloader_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, min_ratio: float, max_ratio: float, @@ -416,39 +416,40 @@ def profile_packing( from llmfoundry.data.dataloader import build_dataloader - max_seq_len = dataloader_cfg.dataset.get('max_seq_len') - max_leftovers_to_keep = dataloader_cfg.dataset.get( - 'max_leftovers_to_keep', - None, - ) + dataset_cfg = dataloader_cfg['dataset'] + max_seq_len = dataset_cfg.get('max_seq_len') + max_leftovers_to_keep = dataset_cfg.get('max_leftovers_to_keep', None) # Turn off packing and sequence parallelism for the dataloader (we want raw, pre-packed, full-length examples) dataloader_cfg = copy.deepcopy(dataloader_cfg) - dataloader_cfg.dataset.packing_ratio = 1.0 - dataloader_cfg.dataset.auto_packing_replication = dataloader_cfg.dataset.get( - 'seq_parallel_replication', - 1, - ) or 1 - dataloader_cfg.dataset.seq_parallel_replication = 1 - dataloader_cfg.drop_last = False - dataloader_cfg.num_workers = 0 - dataloader_cfg.prefetch_factor = None - dataloader_cfg.persistent_workers = False + dataloader_cfg.update({ + 'drop_last': False, + 'num_workers': 0, + 'prefetch_factor': None, + 'persistent_workers': False, + }) + dataloader_cfg['dataset']['packing_ratio'] = 1.0 + dataloader_cfg['dataset']['auto_packing_replication' + ] = dataloader_cfg['dataset'].get( + 'seq_parallel_replication', + 1, + ) or 1 + dataloader_cfg['dataset']['seq_parallel_replication'] = 1 # If streaming dataset, use a temporary local folder for profiling local_rank_zero = dist.get_global_rank() - dist.get_local_rank() - if dataloader_cfg.dataset.get('remote') is not None: + if dataloader_cfg['dataset'].get('remote') is not None: tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - dataloader_cfg.dataset.local = tmp_path + dataloader_cfg['dataset']['local'] = tmp_path - if dataloader_cfg.dataset.get('streams') is not None: - for stream_config in dataloader_cfg.dataset.streams.values(): + if dataloader_cfg['dataset'].get('streams') is not None: + for stream_config in dataloader_cfg['dataset']['streams'].values(): tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - stream_config.local = tmp_path + stream_config['local'] = tmp_path # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 07a3e79fbe..60b81cd145 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -20,8 +20,6 @@ import numpy as np import torch from composer.core.data_spec import DataSpec -from omegaconf import DictConfig -from omegaconf import OmegaConf as om from streaming import Stream, StreamingDataset from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerBase @@ -268,39 +266,45 @@ def get_sequence_id_from_batch( return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1) -def build_streams(dataset_cfg: DictConfig): - streams_dict = dataset_cfg.pop('streams', None) +def build_streams(streams: Optional[Dict[str, Any]] = None,): + streams_dict = streams # build streams - streams = None + streams_ret = [] if streams_dict is not None: - streams = [] - for stream in streams_dict.values(): - # stream is the streams kwargs - # fwd 
all kwargs with **stream allows streaming to check args - streams.append(Stream(**stream)) - return streams + streams_ret = [Stream(**stream) for stream in streams_dict.values()] + return streams_ret def build_text_dataloader( - cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - device_batch_size: Union[int, float], + device_batch_size: int, + dataset: Dict[str, Any], + drop_last: bool, + num_workers: int, + pin_memory: bool = True, + prefetch_factor: int = 2, + persistent_workers: bool = True, + timeout: int = 0, ) -> DataSpec: - assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' + + dataset_cfg = dataset # get kwargs - cfg.dataset['replication'], dataset_batch_size = construct_from_registry( + dataset_cfg['replication'], dataset_batch_size = construct_from_registry( name='dataset_replication_validator', registry=registry.dataset_replication_validators, partial_function=False, kwargs={ - 'cfg': cfg, + 'dataset_cfg': dataset_cfg, 'tokenizer': tokenizer, 'device_batch_size': device_batch_size, }, ) - streams = build_streams(cfg.dataset) + streams = build_streams( + streams=dataset_cfg.pop('streams') + if 'streams' in dataset_cfg else None, + ) valid_streaming_text_dataset_parameters = inspect.signature( StreamingTextDataset, @@ -308,39 +312,50 @@ def build_text_dataloader( dataset_config_subset_for_streaming_text_dataset = { k: v - for k, v in cfg.dataset.items() + for k, v in dataset_cfg.items() if k in valid_streaming_text_dataset_parameters } # build dataset potentially with streams - dataset = StreamingTextDataset( + text_dataset = StreamingTextDataset( tokenizer=tokenizer, streams=streams, batch_size=dataset_batch_size, **dataset_config_subset_for_streaming_text_dataset, ) + dataloader_cfg = { + 'name': 'text', + 'dataset': dataset_cfg, + 'drop_last': drop_last, + 'num_workers': num_workers, + 'pin_memory': pin_memory, + 'prefetch_factor': prefetch_factor, + 'persistent_workers': persistent_workers, + 'timeout': timeout, + } + collate_fn, dataloader_batch_size = construct_from_registry( name='text_collator', registry=registry.collators, partial_function=False, kwargs={ - 'cfg': cfg, - 'tokenizer': dataset.tokenizer, + 'dataloader_cfg': dataloader_cfg, + 'tokenizer': tokenizer, 'dataset_batch_size': dataset_batch_size, }, ) dl = DataLoader( - dataset, + text_dataset, collate_fn=collate_fn, batch_size=dataloader_batch_size, - drop_last=cfg.drop_last, - num_workers=cfg.num_workers, - pin_memory=cfg.get('pin_memory', True), - prefetch_factor=cfg.get('prefetch_factor', 2), - persistent_workers=cfg.get('persistent_workers', True), - timeout=cfg.get('timeout', 0), + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, + timeout=timeout, ) return construct_from_registry( @@ -349,7 +364,7 @@ def build_text_dataloader( partial_function=False, kwargs={ 'dl': dl, - 'dataset_cfg': cfg.dataset, + 'dataset_cfg': dataset_cfg, }, ) @@ -415,14 +430,17 @@ def build_text_dataloader( 'drop_last': False, 'num_workers': 4, } - cfg = om.create(cfg) device_batch_size = 2 tokenizer_name = args.tokenizer tokenizer_kwargs = {'model_max_length': args.max_seq_len} tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - loader = build_text_dataloader(cfg, tokenizer, device_batch_size).dataloader + loader = build_text_dataloader( + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + ).dataloader assert isinstance(loader, DataLoader) assert 
isinstance(loader.dataset, StreamingTextDataset) tokenizer = loader.dataset.tokenizer diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 644eb1417d..a5fe3a1022 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -2,13 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Callable, Iterable, Mapping, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Mapping, Tuple, Union import torch import transformers from composer.core.data_spec import DataSpec from composer.core.types import Batch -from omegaconf import DictConfig from torch.utils.data import DataLoader as TorchDataloader from transformers import PreTrainedTokenizerBase @@ -20,9 +19,12 @@ log = logging.getLogger(__name__) -def _validate_cfg(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): - eos_token_id = cfg.dataset.get('eos_token_id', None) - bos_token_id = cfg.dataset.get('bos_token_id', None) +def _validate_cfg( + dataset_cfg: Dict[str, Any], + tokenizer: PreTrainedTokenizerBase, +): + eos_token_id = dataset_cfg.get('eos_token_id', None) + bos_token_id = dataset_cfg.get('bos_token_id', None) if eos_token_id is None and bos_token_id is None and ( hasattr(tokenizer, 'eos_token_id') or @@ -35,7 +37,7 @@ def _validate_cfg(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): tokenizer_eos_token_id = getattr(tokenizer, 'eos_token_id', None) if eos_token_id is not None and eos_token_id != tokenizer_eos_token_id: eos_mismatch_str = f'Provided {eos_token_id=} does not match the eos_token_id of the tokenizer={tokenizer_eos_token_id}.' - if cfg.dataset.pop('override_eos_token_id_mismatch_error', False): + if dataset_cfg.pop('override_eos_token_id_mismatch_error', False): log.warning(eos_mismatch_str) else: raise ValueError( @@ -46,7 +48,7 @@ def _validate_cfg(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): tokenizer_bos_token_id = getattr(tokenizer, 'bos_token_id', None) if bos_token_id is not None and bos_token_id != tokenizer_bos_token_id: bos_mismatch_str = f'Provided {bos_token_id=} does not match the bos_token_id of the tokenizer={tokenizer_bos_token_id}.' 
- if cfg.dataset.pop('override_bos_token_id_mismatch_error', False): + if dataset_cfg.pop('override_bos_token_id_mismatch_error', False): log.warning(bos_mismatch_str) else: raise ValueError( @@ -54,20 +56,19 @@ def _validate_cfg(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase): ' To override this error, set the override_bos_token_id_mismatch_error flag to True in the dataset config section of the YAML.', ) - max_seq_len = cfg.dataset.get('max_seq_len') + max_seq_len = dataset_cfg.get('max_seq_len') if max_seq_len is not None: if max_seq_len != int(max_seq_len): raise ValueError('max_seq_len must be an integer') - cfg.dataset['max_seq_len'] = int(max_seq_len) + dataset_cfg['max_seq_len'] = int(max_seq_len) def validate_ds_replication( - cfg: DictConfig, + dataset_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, device_batch_size: Union[int, float], ) -> Tuple[int, int]: - _validate_cfg(cfg, tokenizer) - dataset_cfg = cfg.dataset + _validate_cfg(dataset_cfg, tokenizer) if (dataset_cfg.get('seq_parallel_replication', 1) or 1) > 1: raise NotImplementedError('Sequence parallelism is not supported.') if not isinstance(device_batch_size, int): @@ -77,7 +78,7 @@ def validate_ds_replication( def get_data_spec( dl: Union[Iterable, TorchDataloader], - dataset_cfg: DictConfig, + dataset_cfg: Dict[str, Any], ) -> DataSpec: del dataset_cfg token_counting_func = get_tokens_per_batch_func() @@ -134,14 +135,16 @@ def get_num_tokens_in_batch(batch: Batch) -> int: def get_text_collator( - cfg: DictConfig, + dataloader_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, dataset_batch_size: int, ) -> Tuple[Union[transformers.DataCollatorForLanguageModeling, ConcatenatedSequenceCollatorWrapper], int]: - eos_token_id = cfg.dataset.get('eos_token_id', None) - bos_token_id = cfg.dataset.get('bos_token_id', None) - mlm_probability = cfg.dataset.pop('mlm_probability', None) + dataset_cfg = dataloader_cfg.get('dataset') + assert isinstance(dataset_cfg, dict) + eos_token_id = dataset_cfg.get('eos_token_id', None) + bos_token_id = dataset_cfg.get('bos_token_id', None) + mlm_probability = dataset_cfg.pop('mlm_probability', None) collate_fn = transformers.DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=mlm_probability is not None, @@ -160,8 +163,8 @@ def get_text_collator( def get_finetuning_collator( - cfg: DictConfig, + dataloader_cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, dataset_batch_size: int, ) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: - return build_collate_fn(cfg, tokenizer, dataset_batch_size) + return build_collate_fn(dataloader_cfg, tokenizer, dataset_batch_size) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 7182c47d2a..5f3a53ed18 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -6,11 +6,19 @@ import logging import os import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) from composer.models.huggingface import peft_installed from composer.utils import dist -from omegaconf import DictConfig from torchmetrics import Metric from transformers import ( AutoConfig, @@ -28,7 +36,7 @@ from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithFSDP from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.utils import init_empty_weights -from llmfoundry.utils.config_utils import 
get_hf_config_value, pop_config +from llmfoundry.utils.config_utils import get_hf_config_value if TYPE_CHECKING: from peft import PeftConfig, PeftModel @@ -42,60 +50,79 @@ class ComposerHFCausalLM(HuggingFaceModelWithFSDP): """Configures a :class:`.HuggingFaceModel` around a Causal LM. Args: - om_model_config (DictConfig): An OmegaConf DictConfig specifying the configuration options - cfg.pretrained_model_name_or_path (str): The name of or local path to - the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel). - cfg.config_overrides (dict, optional): An optional dictionary of keyword - arguments that override the default configuration associated with - cfg.pretrained_model_name_or_path. - cfg.pretrained (bool): Whether to instantiate the model with pre-trained - weights coming from cfg.pretrained_model_name_or_path. If ``True``, - cfg.config_overrides must be compatible with the pre-trained weights. - cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to - initialize the model on. Currently, `meta` is only supported when - cfg.pretrained is ``False``. Default: ``'cpu'``. - cfg.peft_config (dict, optional): An optional dictionary of keyword arguments to be - passed to the PeftConfig constructor. If provided, the model will be wrapped in a PeftModel. - cfg.trust_remote_code (bool, optional): Whether to trust remote code when loading from Hugging Face - Hub. Default: ``True``. - cfg.use_auth_token (bool, optional): Whether to use the Hugging Face authentication token when - loading from Hugging Face Hub. Default: ``False``. - cfg.use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``. - cfg.load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``. - cfg.init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``. - cfg.use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``. + pretrained_model_name_or_path (str): The name of or local path to + the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel). + config_overrides (dict, optional): An optional dictionary of keyword + arguments that override the default configuration associated with + cfg.pretrained_model_name_or_path. + pretrained (bool): Whether to instantiate the model with pre-trained + weights coming from cfg.pretrained_model_name_or_path. If ``True``, + cfg.config_overrides must be compatible with the pre-trained weights. + init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to + initialize the model on. Currently, `meta` is only supported when + cfg.pretrained is ``False``. Default: ``'cpu'``. + peft_config (dict, optional): An optional dictionary of keyword arguments to be + passed to the PeftConfig constructor. If provided, the model will be wrapped in a PeftModel. + trust_remote_code (bool, optional): Whether to trust remote code when loading from Hugging Face + Hub. Default: ``True``. + use_auth_token (bool, optional): Whether to use the Hugging Face authentication token when + loading from Hugging Face Hub. Default: ``False``. + use_train_metrics (bool, optional): Whether to use training metrics. Default: ``True``. + load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Default: ``False``. + init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``. + use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``. 
tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ def __init__( self, - om_model_config: DictConfig, tokenizer: PreTrainedTokenizerBase, + pretrained_model_name_or_path: str, + pretrained: bool = True, + pretrained_lora_id_or_path: Optional[str] = None, + trust_remote_code: bool = True, + use_auth_token: bool = False, + use_flash_attention_2: bool = False, + load_in_8bit: bool = False, + init_device: str = 'cpu', + config_overrides: Optional[Dict[str, Any]] = None, + peft_config: Optional[Dict[str, Any]] = None, + use_train_metrics: bool = True, + additional_train_metrics: Optional[List] = None, + additional_eval_metrics: Optional[List] = None, ): - model = ComposerHFCausalLM.build_inner_model(om_model_config) - train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( - om_model_config, + config_overrides = config_overrides or {} + + model = ComposerHFCausalLM.build_inner_model( + pretrained_model_name_or_path=pretrained_model_name_or_path, + pretrained_lora_id_or_path=pretrained_lora_id_or_path, + trust_remote_code=trust_remote_code, + init_device=init_device, + use_flash_attention_2=use_flash_attention_2, + use_auth_token=use_auth_token, + config_overrides=config_overrides, + load_in_8bit=load_in_8bit, + pretrained=pretrained, + prepare_for_fsdp=True, ) - peft_config_dict = pop_config( - om_model_config, - 'peft_config', - must_exist=False, - convert=True, + train_metrics, eval_metrics = ComposerHFCausalLM.build_metrics( + use_train_metrics=use_train_metrics, + additional_train_metrics=additional_train_metrics, + additional_eval_metrics=additional_eval_metrics, ) - if peft_config_dict is not None and not peft_installed: + + if peft_config is not None and not peft_installed: raise ValueError( 'PEFT is not installed, but peft_config was passed. Please install LLM Foundry with the peft extra to use peft_config.', ) - peft_config = None - if peft_config_dict is not None: - peft_config = self._get_peft_config(peft_config_dict) + peft_config_object = None + if peft_config is not None: + peft_config_object = self._get_peft_config(peft_config) # Set up config args for the model construction and base classes - init_device = om_model_config.get('init_device', 'cpu') - super().__init__( model=model, shift_labels=True, @@ -103,31 +130,35 @@ def __init__( metrics=train_metrics, eval_metrics=eval_metrics, init_device=init_device, - peft_config=peft_config, + peft_config=peft_config_object, ) @staticmethod def build_metrics( - om_model_config: DictConfig, + use_train_metrics: bool, + additional_train_metrics: Optional[List[str]] = None, + additional_eval_metrics: Optional[List[str]] = None, ) -> Tuple[List[Metric], List[Metric]]: """Builds the training and evaluation metrics for the model. Args: - om_model_config (DictConfig): The model configuration. See `__init__` for details on allowed keys. + use_train_metrics (bool): Whether to use training metrics. + additional_train_metrics (Optional[List[str]]): Additional training metrics to include. + additional_eval_metrics (Optional[List[str]]): Additional evaluation metrics to include. + + Returns: + Tuple[List[Metric], List[Metric]]: A tuple containing the list of training metrics and evaluation metrics. 
""" from llmfoundry.utils.builders import build_metric - use_train_metrics = om_model_config.get('use_train_metrics', True) - train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + om_model_config.get( - 'additional_train_metrics', - [], + train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + ( + additional_train_metrics or [] ) train_metrics = [ build_metric(metric, {}) for metric in train_metric_names ] if use_train_metrics else [] - eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + om_model_config.get( - 'additional_eval_metrics', - [], + eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + ( + additional_eval_metrics or [] ) eval_metrics = [ build_metric(metric, {}) for metric in eval_metric_names @@ -137,42 +168,42 @@ def build_metrics( @staticmethod def build_inner_model( - om_model_config: DictConfig, + pretrained_model_name_or_path: str, + pretrained_lora_id_or_path: Optional[str], + trust_remote_code: bool, + init_device: str, + use_flash_attention_2: bool, + use_auth_token: bool, + config_overrides: Dict[str, Any], + load_in_8bit: bool, + pretrained: bool, prepare_for_fsdp: bool = False, ) -> Union[PreTrainedModel, 'PeftModel']: """Builds the inner model for the ComposerHFCausalLM. Args: - om_model_config (DictConfig): The model configuration. See `__init__` for details on allowed keys. + pretrained_model_name_or_path (str): The pretrained model name or path. + pretrained_lora_id_or_path (Optional[str]): The pretrained LORA ID or path. + trust_remote_code (bool): Whether to trust remote code. + init_device (str): The initialization device. + use_flash_attention_2 (bool): Whether to use flash attention 2. + use_auth_token (bool): Whether to use an authentication token. + config_overrides (Dict[str, Any]): The configuration overrides. + load_in_8bit (bool): Whether to load in 8-bit. + prepare_for_fsdp (bool, optional): Whether to prepare the model for FSDP wrapping. Default: False. + + Returns: + Union[PreTrainedModel, 'PeftModel']: The built inner model. prepare_for_fsdp (bool): Whether to prepare the model for FSDP wrapping. Default: ``False``. """ - pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path - pretrained_lora_id_or_path = om_model_config.get( - 'pretrained_lora_id_or_path', - None, - ) - - if not om_model_config.get( - 'trust_remote_code', - True, - ) and pretrained_model_name_or_path.startswith('mosaicml/mpt'): + if not trust_remote_code and pretrained_model_name_or_path.startswith( + 'mosaicml/mpt', + ): raise ValueError( 'trust_remote_code must be set to True for MPT models. 
Without this, the MPT model code will come from the transformers library, ' + 'which is significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.', ) - - # Set up Hugging Face args - trust_remote_code = om_model_config.get('trust_remote_code', True) - use_auth_token = om_model_config.get('use_auth_token', False) - use_flash_attention_2 = om_model_config.get( - 'use_flash_attention_2', - False, - ) - load_in_8bit = om_model_config.get('load_in_8bit', False) - - # Set up config args for the model construction and base classes - init_device = om_model_config.get('init_device', 'cpu') # Resolve "mixed" init device to either "cpu" or "meta" resolved_init_device = hf_get_init_device(init_device) requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager' @@ -211,7 +242,7 @@ def _autoset_attn_implementation_monkeypatch( ) # set config overrides - for k, v in om_model_config.get('config_overrides', {}).items(): + for k, v in config_overrides.items(): if not hasattr(config, k): raise ValueError( f'config does not have attribute "{k}" to override ({k}: {v}).', @@ -258,7 +289,7 @@ def _autoset_attn_implementation_monkeypatch( # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately if dist.get_local_rank() != 0 and init_device == 'mixed': - om_model_config.pretrained = False + pretrained = False # If the HuggingFace model is coming from a local folder, Hugging Face copies the modules into the # transformers modules cache. On particular systems, this operation seems to cause contention between @@ -280,7 +311,7 @@ def _autoset_attn_implementation_monkeypatch( # initialize the model on the correct device if resolved_init_device == 'cpu': - if om_model_config.pretrained: + if pretrained: model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, @@ -294,7 +325,7 @@ def _autoset_attn_implementation_monkeypatch( trust_remote_code=trust_remote_code, ) elif resolved_init_device == 'meta': - if om_model_config.pretrained: + if pretrained: raise ValueError( 'Setting cfg.pretrained=True is not supported when init_device="meta".', ) diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py index 6520fe7426..f54b7c42ec 100644 --- a/llmfoundry/models/hf/hf_t5.py +++ b/llmfoundry/models/hf/hf_t5.py @@ -5,10 +5,9 @@ from __future__ import annotations -from typing import Mapping +from typing import List, Mapping, Optional from composer.utils import dist -from omegaconf import DictConfig from transformers import ( AutoConfig, PreTrainedTokenizerBase, @@ -32,36 +31,45 @@ class ComposerHFT5(HuggingFaceModelWithFSDP): will expand support to more general classes of HF Encoder-Decoder models. Args: - cfg (DictConfig): An omegaconf dictionary used to configure the model: - cfg.pretrained_model_name_or_path (str): The name of or local path to - the HF model (e.g., `t5-base` to instantiate a T5 using the base config). - cfg.config_overrides (dict, optional): An optional dictionary of keyword - arguments that override the default configuration associated with - cfg.pretrained_model_name_or_path. Default: ``{}``. - cfg.pretrained (bool): Whether to instantiate the model with pre-trained - weights coming from cfg.pretrained_model_name_or_path. If ``True``, - cfg.config_overrides must be compatible with the pre-trained weights. 
- cfg.init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to - initialize the model on. Currently, `meta` is only supported when - cfg.pretrained is ``False``. Default: ``'cpu'``. + pretrained_model_name_or_path (str): The name of or local path to + the HF model (e.g., `t5-base` to instantiate a T5 using the base config). + config_overrides (dict, optional): An optional dictionary of keyword + arguments that override the default configuration associated with + cfg.pretrained_model_name_or_path. Default: ``{}``. + pretrained (bool): Whether to instantiate the model with pre-trained + weights coming from cfg.pretrained_model_name_or_path. If ``True``, + cfg.config_overrides must be compatible with the pre-trained weights. + init_device ('cpu' | 'meta'): Which device, 'cpu' or 'meta', to + initialize the model on. Currently, `meta` is only supported when + cfg.pretrained is ``False``. Default: ``'cpu'``. tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ def __init__( self, - om_model_config: DictConfig, tokenizer: PreTrainedTokenizerBase, + pretrained_model_name_or_path: str, + pretrained: Optional[bool] = True, + trust_remote_code: bool = True, + use_auth_token: bool = False, + config_overrides: Optional[Mapping] = None, + init_device: str = 'cpu', + additional_train_metrics: Optional[List] = None, + name: Optional[str] = None, ): from llmfoundry.utils.builders import build_metric + config_overrides = config_overrides or {} + additional_train_metrics = additional_train_metrics or [] + config = AutoConfig.from_pretrained( - om_model_config.pretrained_model_name_or_path, - trust_remote_code=om_model_config.get('trust_remote_code', True), - use_auth_token=om_model_config.get('use_auth_token', False), + pretrained_model_name_or_path, + trust_remote_code=trust_remote_code, + use_auth_token=use_auth_token, ) # set config overrides - for k, v in om_model_config.get('config_overrides', {}).items(): + for k, v in config_overrides.items(): if not hasattr(config, k): raise ValueError( f'config does not have attribute "{k}" to override ({k}: {v}).', @@ -84,8 +92,6 @@ def __init__( raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\ f'using configs where `is_encoder_decoder` is ``True``.') - init_device = om_model_config.get('init_device', 'cpu') - # Get the device we want to initialize, and use the # resolved version to initialize the HF model resolved_init_device = hf_get_init_device(init_device) @@ -93,18 +99,18 @@ def __init__( # We need to have all non-zero local ranks be not-pretrained # Rank 0 will still be pretrained, and distribute the weights appropriately if dist.get_local_rank() != 0 and init_device == 'mixed': - om_model_config.pretrained = False + pretrained = False if resolved_init_device == 'cpu': - if om_model_config.pretrained: + if pretrained: model = T5ForConditionalGeneration.from_pretrained( - om_model_config.pretrained_model_name_or_path, + pretrained_model_name_or_path, config=config, ) else: model = T5ForConditionalGeneration(config) elif resolved_init_device == 'meta': - if om_model_config.pretrained: + if pretrained: raise ValueError( 'Setting cfg.pretrained=True is not supported when init_device="meta".', ) @@ -116,8 +122,8 @@ def __init__( ) metrics = [ - build_metric(metric, {}) for metric in DEFAULT_ENC_DEC_METRICS + - om_model_config.get('additional_train_metrics', []) + build_metric(metric, {}) + for metric in DEFAULT_ENC_DEC_METRICS + additional_train_metrics ] super().__init__( diff --git 
a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index bdf6cff925..15f1440b47 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -39,8 +39,6 @@ except Exception as e: raise e -from omegaconf import DictConfig -from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import ( BaseModelOutputWithPast, @@ -1052,8 +1050,11 @@ class ComposerMPTCausalLM(HuggingFaceModel): def __init__( self, - om_model_config: DictConfig, tokenizer: Optional[PreTrainedTokenizerBase] = None, + use_train_metrics: Optional[bool] = True, + additional_train_metrics: Optional[List] = None, + loss_fn: Optional[Union[str, Dict]] = 'fused_crossentropy', + **kwargs: Dict[str, Any], ): from llmfoundry.metrics import ( DEFAULT_CAUSAL_LM_EVAL_METRICS, @@ -1061,27 +1062,18 @@ def __init__( ) from llmfoundry.utils.builders import build_metric - resolved_om_model_config = om.to_container( - om_model_config, - resolve=True, - ) - assert isinstance(resolved_om_model_config, dict) - - hf_config = MPTConfig.from_dict(resolved_om_model_config) - model = MPTForCausalLM(hf_config) + additional_train_metrics = additional_train_metrics or [] - use_train_metrics = om_model_config.get('use_train_metrics', True) - train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + resolved_om_model_config.get( - 'additional_train_metrics', - [], + model = MPTForCausalLM( + MPTConfig(use_train_metrics=use_train_metrics, **kwargs), ) + + use_train_metrics = use_train_metrics + train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics train_metrics = [ build_metric(metric, {}) for metric in train_metric_names ] if use_train_metrics else [] - eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + resolved_om_model_config.get( - 'additional_eval_metrics', - [], - ) + eval_metric_names = DEFAULT_CAUSAL_LM_EVAL_METRICS + additional_train_metrics eval_metrics = [ build_metric(metric, {}) for metric in eval_metric_names ] @@ -1096,7 +1088,7 @@ def __init__( allow_embedding_resizing=True, ) - loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') + loss_fn_config = loss_fn if loss_fn_config == 'fused_crossentropy': try: from flash_attn.losses.cross_entropy import \ diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 68b2ffad50..0c8e64b759 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -1,12 +1,11 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Iterable, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, Tuple, Type, Union from composer.core import Algorithm, Callback, DataSpec from composer.loggers import LoggerDestination from composer.models import ComposerModel from composer.optim import ComposerScheduler -from omegaconf import DictConfig from torch.optim import Optimizer from torch.utils.data import DataLoader as TorchDataloader from torchmetrics import Metric @@ -133,7 +132,9 @@ dataloaders = create_registry( 'llmfoundry', 'dataloaders', - generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], DataSpec], + generic_type=Callable[ + ..., + DataSpec], # The arguments to the dataloader may vary depending on the contents of the config. 
entry_points=True, description=_dataloaders_description, ) @@ -152,8 +153,8 @@ 'llmfoundry', 'dataset_replication_validators', generic_type=Callable[ - [DictConfig, PreTrainedTokenizerBase, Union[int, float]], Tuple[int, - int]], + [Dict[str, Any], PreTrainedTokenizerBase, Union[int, float]], + Tuple[int, int]], entry_points=True, description=_dataset_replication_validators_description, ) @@ -171,7 +172,7 @@ collators = create_registry( 'llmfoundry', 'collators', - generic_type=Callable[[DictConfig, PreTrainedTokenizerBase, int], + generic_type=Callable[[Dict[str, Any], PreTrainedTokenizerBase, int], Tuple[Any, int]], entry_points=True, description=_collators_description, @@ -188,7 +189,7 @@ data_specs = create_registry( 'llmfoundry', 'data_specs', - generic_type=Callable[[Union[Iterable, TorchDataloader], DictConfig], + generic_type=Callable[[Union[Iterable, TorchDataloader], Dict[str, Any]], DataSpec], entry_points=True, description=_data_specs_description, diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 39025b8066..73eb026d98 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -24,7 +24,7 @@ from composer.models import ComposerModel from composer.optim.scheduler import ComposerScheduler from composer.utils import dist -from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.optim.optimizer import Optimizer from torchmetrics import Metric @@ -36,6 +36,7 @@ from llmfoundry.eval.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper +from llmfoundry.utils.config_utils import to_dict_container, to_list_container from llmfoundry.utils.registry_utils import construct_from_registry log = logging.getLogger(__name__) @@ -56,9 +57,9 @@ def build_evaluators( - eval_loader_config: Optional[Union[DictConfig, ListConfig]], - icl_tasks_config: Optional[Union[str, ListConfig]], - eval_gauntlet_config: Optional[Union[str, DictConfig]], + eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], + icl_tasks_config: Optional[Union[str, List[Dict[str, Any]]]], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], *, tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, @@ -91,26 +92,31 @@ def build_evaluators( def build_eval_loaders( - eval_loader_config: Union[DictConfig, ListConfig], + eval_loader_config: Union[Dict[str, Any], List[Dict[str, Any]]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, ) -> List[Evaluator]: evaluators: List[Evaluator] = [] - if isinstance(eval_loader_config, ListConfig): - eval_configs: ListConfig = eval_loader_config + if isinstance(eval_loader_config, list): + eval_configs = eval_loader_config is_multi_eval = True - else: - eval_configs = ListConfig([eval_loader_config]) + elif isinstance(eval_loader_config, dict): + eval_configs = [eval_loader_config] is_multi_eval = False + else: + raise ValueError( + f'Got invalid type for eval_loader_config: {type(eval_loader_config)}, {eval_loader_config=}', + ) for eval_config in eval_configs: + label = eval_config.pop('label') if is_multi_eval else None eval_dataloader = build_dataloader( eval_config, tokenizer, device_eval_batch_size, ) eval_loader: Evaluator = Evaluator( - label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', + label=f'eval/{label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, # Load the eval data to fail fast. 
metrics will get added # later in add_metrics_to_eval_loaders, after the model is loaded @@ -138,8 +144,8 @@ def add_metrics_to_eval_loaders( def build_icl_data_and_gauntlet( - icl_tasks_config: Union[str, ListConfig], - eval_gauntlet_config: Optional[Union[str, DictConfig]], + icl_tasks_config: Union[str, List[Dict[str, Any]]], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, @@ -157,15 +163,18 @@ def build_icl_data_and_gauntlet( if isinstance(eval_gauntlet_config, str): with open(eval_gauntlet_config, 'r') as icl_f: eval_gauntlet_cfg = om.load(icl_f) - eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet - elif isinstance(eval_gauntlet_config, DictConfig): # pyright: ignore + assert isinstance(eval_gauntlet_cfg, DictConfig) + eval_gauntlet = to_dict_container( + eval_gauntlet_cfg['eval_gauntlet'], + ) + elif isinstance(eval_gauntlet_config, dict): # pyright: ignore eval_gauntlet = eval_gauntlet_config else: raise ValueError( f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}', ) - eval_gauntlet.logger_keys = logger_keys - eval_gauntlet.benchmark_sizes = { + eval_gauntlet['logger_keys'] = logger_keys + eval_gauntlet['benchmark_sizes'] = { e.label: e.dataloader.num_samples for e in icl_evaluators } eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet) @@ -174,7 +183,7 @@ def build_icl_data_and_gauntlet( def build_composer_model( name: str, - cfg: DictConfig, + cfg: Dict[str, Any], tokenizer: PreTrainedTokenizerBase, init_context: Optional[ContextManager] = None, master_weights_dtype: Optional[str] = None, @@ -201,7 +210,7 @@ def build_composer_model( pre_validation_function=ComposerModel, post_validation_function=None, kwargs={ - 'om_model_config': cfg, + **cfg, 'tokenizer': tokenizer, }, ) @@ -400,14 +409,12 @@ def _extract_param_groups( def build_optimizer( model: torch.nn.Module, name: str, - optimizer_config: Optional[Dict[str, Any]] = None, + optimizer_config: Dict[str, Any], ) -> Optimizer: params = _extract_param_groups(model, optimizer_config) - kwargs = optimizer_config + kwargs = {**optimizer_config} - if kwargs is None: - kwargs = {} if 'params' in kwargs: raise ValueError( 'The `params` will be automatically extracted from the model and ' + @@ -490,7 +497,7 @@ def build_tokenizer( def build_icl_evaluators( - icl_tasks: Union[str, ListConfig], + icl_tasks: Union[str, List[Dict[str, Any]]], tokenizer: PreTrainedTokenizerBase, default_max_seq_len: int, default_batch_size: int, @@ -508,52 +515,52 @@ def build_icl_evaluators( log.info(f'Extracting ICL task config from path: {icl_tasks}') with open(icl_tasks, 'r') as icl_f: icl_task_cfg = om.load(icl_f) - icl_tasks_list = icl_task_cfg.icl_tasks + icl_tasks_list = to_list_container(icl_task_cfg.icl_tasks) else: icl_tasks_list = icl_tasks - def _validate_cfg(icl_cfg: DictConfig): + def _validate_cfg(icl_cfg: Dict[str, Any]): assert 'label' in icl_cfg - assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None + assert 'dataset_uri' in icl_cfg and icl_cfg['dataset_uri'] is not None assert 'icl_task_type' in icl_cfg assert 'num_fewshot' in icl_cfg if 'metric_names' not in icl_cfg: - if icl_cfg.icl_task_type == 'language_modeling': - icl_cfg.metric_names = ['InContextLearningLMAccuracy'] - elif icl_cfg.icl_task_type == 'multiple_choice': - icl_cfg.metric_names = [ + if icl_cfg['icl_task_type'] == 'language_modeling': + icl_cfg['metric_names'] = ['InContextLearningLMAccuracy'] + elif icl_cfg['icl_task_type'] == 
'multiple_choice': + icl_cfg['metric_names'] = [ 'InContextLearningMultipleChoiceAccuracy', ] - elif icl_cfg.icl_task_type == 'schema': - icl_cfg.metric_names = [ + elif icl_cfg['icl_task_type'] == 'schema': + icl_cfg['metric_names'] = [ 'InContextLearningMultipleChoiceAccuracy', ] - elif icl_cfg.icl_task_type == 'generation_task_with_answers': - icl_cfg.metric_names = [ + elif icl_cfg['icl_task_type'] == 'generation_task_with_answers': + icl_cfg['metric_names'] = [ 'InContextLearningGenerationExactMatchAccuracy', ] else: raise ValueError( - f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.', + f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg["icl_task_type"]}.', ) if 'prompt_string' not in icl_cfg: - icl_cfg.prompt_string = '' + icl_cfg['prompt_string'] = '' if 'example_delimiter' not in icl_cfg: - icl_cfg.example_delimiter = '\n' + icl_cfg['example_delimiter'] = '\n' if 'continuation_delimiter' not in icl_cfg: - icl_cfg.continuation_delimiter = ' ' + icl_cfg['continuation_delimiter'] = ' ' if 'max_seq_len' not in icl_cfg: - icl_cfg.max_seq_len = default_max_seq_len + icl_cfg['max_seq_len'] = default_max_seq_len if 'batch_size' not in icl_cfg: - icl_cfg.batch_size = default_batch_size + icl_cfg['batch_size'] = default_batch_size if 'pass_at_k' not in icl_cfg: - icl_cfg.pass_at_k = 1 + icl_cfg['pass_at_k'] = 1 if 'fewshot_random_seed' not in icl_cfg: - icl_cfg.fewshot_random_seed = 1234 + icl_cfg['fewshot_random_seed'] = 1234 if 'generations_per_sample' not in icl_cfg: - icl_cfg.generations_per_sample = 1 + icl_cfg['generations_per_sample'] = 1 if 'num_beams' in icl_cfg: raise ValueError( @@ -561,18 +568,21 @@ def _validate_cfg(icl_cfg: DictConfig): 'Please use generation_kwargs.num_beams instead.') for icl_cfg in icl_tasks_list: - assert isinstance(icl_cfg, DictConfig) + assert isinstance( + icl_cfg, + dict, + ), f'Expected dict, got {type(icl_cfg)}, {icl_cfg=}' _validate_cfg(icl_cfg) - for num_fewshot in list(icl_cfg.num_fewshot): + for num_fewshot in list(icl_cfg['num_fewshot']): if tokenizer.pad_token_id is None: # Current workaround to support GPT2 tokenizer with `pad_token_id = None` pad_tok_id = tokenizer.eos_token_id else: pad_tok_id = tokenizer.pad_token_id - label = f'{icl_cfg.label}/{num_fewshot}-shot' - metric_names = list(icl_cfg.metric_names) + label = f'{icl_cfg["label"]}/{num_fewshot}-shot' + metric_names = list(icl_cfg['metric_names']) # TODO: fix Composer bug when copying local paths and destination exists - destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl' + destination_path = f'{destination_dir}/{icl_cfg["label"]}-{num_fewshot}.jsonl' if dist.get_local_rank() == 0 and os.path.exists(destination_path): os.remove(destination_path) dist.barrier() @@ -584,42 +594,36 @@ def _validate_cfg(icl_cfg: DictConfig): 'early_stopping_criteria', None, ) - if isinstance(early_stopping_criteria, ListConfig): - early_stopping_criteria = om.to_container( - early_stopping_criteria, - ) assert early_stopping_criteria is None or isinstance( early_stopping_criteria, list, ) dataloaders = get_icl_task_dataloader( - icl_cfg.icl_task_type, - icl_cfg.dataset_uri, + icl_cfg['icl_task_type'], + icl_cfg['dataset_uri'], tokenizer, - batch_size=icl_cfg.batch_size, - max_seq_len=icl_cfg.max_seq_len, + batch_size=icl_cfg['batch_size'], + max_seq_len=icl_cfg['max_seq_len'], pad_tok_id=pad_tok_id, num_fewshot=num_fewshot, - prompt_string=icl_cfg.prompt_string, - 
example_delimiter=icl_cfg.example_delimiter, + prompt_string=icl_cfg['prompt_string'], + example_delimiter=icl_cfg['example_delimiter'], hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, - continuation_delimiter=icl_cfg.continuation_delimiter, + continuation_delimiter=icl_cfg['continuation_delimiter'], question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, - fewshot_random_seed=icl_cfg.fewshot_random_seed, - pass_at_k=icl_cfg.pass_at_k, - generations_per_sample=icl_cfg.generations_per_sample, + fewshot_random_seed=icl_cfg['fewshot_random_seed'], + pass_at_k=icl_cfg['pass_at_k'], + generations_per_sample=icl_cfg['generations_per_sample'], has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True), ) - if hasattr( - icl_cfg, - 'has_categories', - ) and icl_cfg.has_categories and isinstance(dataloaders, dict): + if 'has_categories' in icl_cfg and icl_cfg[ + 'has_categories'] and isinstance(dataloaders, dict): for category in dataloaders.keys(): logger_keys.extend([ f'metrics/{label}/{category}/{m}' for m in metric_names diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index db180e3168..211ed08d3e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -2,15 +2,29 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import copy import logging import math import os import warnings -from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union +from dataclasses import dataclass, fields +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import mlflow from composer.utils import dist, parse_uri -from omegaconf import DictConfig, ListConfig +from omegaconf import MISSING, DictConfig, ListConfig, MissingMandatoryValue from omegaconf import OmegaConf as om from transformers import PretrainedConfig @@ -28,8 +42,277 @@ ] -def pop_config( +@dataclass +class EvalConfig: + # Eval Config required parameters: + models: List[Dict[str, Any]] = MISSING + max_seq_len: int = MISSING + device_eval_batch_size: int = MISSING + + # Eval Config optional parameters: + code_paths: Optional[List[str]] = None + + # Eval hyperparameters + eval_gauntlet: Optional[Dict[str, Any]] = None + eval_gauntlet_str: Optional[str] = None + eval_loader: Optional[Dict[str, Any]] = None + eval_loaders: Optional[List[Dict[str, Any]]] = None + eval_subset_num_batches: int = -1 + icl_subset_num_batches: Optional[int] = None + # One of icl_tasks or icl_tasks_str must be specified + icl_tasks: Optional[List[Dict[str, Any]]] = None + icl_tasks_str: Optional[str] = None + + # Logging parameters + python_log_level: Optional[str] = 'debug' + loggers: Optional[Dict[str, Any]] = None + log_config: bool = True + + # Model/run parameters + seed: int = 17 + precision: str = 'amp_bf16' + run_name: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + + # Distributed parameters + dist_timeout: Union[float, int] = 600.0 + fsdp_config: Optional[Dict[str, Any]] = None + + # Callback parameters + callbacks: Optional[Dict[str, Any]] = None + + # Variables to ignore + variables: Optional[Dict[str, Any]] = None + + +EVAL_CONFIG_KEYS = {field.name for field in fields(EvalConfig)} + + +@dataclass +class TrainConfig: + """Dataclass for 
training configuration.""" + + # Mandatory model training parameters + model: Dict[str, Any] = MISSING + tokenizer: Dict[str, Any] = MISSING + optimizer: Dict[str, Any] = MISSING + scheduler: Dict[str, Any] = MISSING + train_loader: Dict[str, Any] = MISSING + device_train_batch_size: int = MISSING + device_eval_batch_size: int = MISSING + max_duration: Union[int, str] = MISSING + eval_interval: Union[int, str] = MISSING + max_seq_len: int = MISSING + seed: int = MISSING + + # Precision + precision: str = 'amp_bf16' + + # Code paths to import + code_paths: Optional[List[str]] = None + + # Cuda allocation configuration + max_split_size_mb: Optional[int] = None + expandable_segments: bool = False + cuda_load_lazy: bool = False + + # Distributed training parameters + dist_timeout: Union[int, float] = 600.0 + fsdp_config: Optional[Dict[str, Any]] = None + + # Evaluation parameters + eval_loader: Optional[Dict[str, Any]] = None + eval_loaders: Optional[List[Dict[str, Any]] + ] = None # should not be set by the user + icl_tasks: Optional[List[Dict[str, Any]]] = None + icl_tasks_str: Optional[str] = None # should not be set by the user + eval_gauntlet: Optional[Dict[str, Any]] = None + eval_gauntlet_str: Optional[str] = None # should not be set by the user + icl_subset_num_batches: Optional[int] = None + icl_seq_len: Optional[int] = None + + # Logging + loggers: Optional[Dict[str, Any]] = None + progress_bar: bool = False + log_to_console: bool = True + python_log_level: Optional[str] = 'debug' + console_log_interval: Union[int, str] = '1ba' + log_config: bool = True + + # Callbacks + callbacks: Optional[Dict[str, Any]] = None + algorithms: Optional[Dict[str, Any]] = None + + # Checkpoints + save_folder: Optional[str] = None + save_latest_filename: Optional[str] = None + save_overwrite: bool = False + save_weights_only: bool = False + save_filename: Optional[str] = None + save_interval: Union[str, int] = '1000ba' + save_num_checkpoints_to_keep: int = -1 + load_path: Optional[str] = None + load_weights_only: bool = False + load_strict_model_weights: bool = True + load_ignore_keys: Optional[List[str]] = None + save_ignore_keys: Optional[List[str]] = None + + # Dataloader + device_train_microbatch_size: Union[str, int] = 'auto' + global_train_batch_size: Optional[int] = None + + # Eval dataloader + eval_subset_num_batches: int = -1 + eval_first: bool = False + compile_config: Optional[Dict[str, Any]] = None + + # Metadata + metadata: Optional[Dict[str, Any]] = None + run_name: Optional[str] = None + + # Resumption + autoresume: bool = False + + # Profiling + profiler: Optional[Dict[str, Any]] = None + + # Variables to ignore + variables: Optional[Dict[str, Any]] = None + + +TRAIN_CONFIG_KEYS = {field.name for field in fields(TrainConfig)} + + +def forbid_config_key(cfg_dict: Dict[str, Any], key: str): + if key in cfg_dict: + raise ValueError( + f'Config key `{key}` should not be set. 
Please remove it from the config.', + ) + + +def to_dict_container(cfg: Union[DictConfig, Dict[str, Any]]) -> Dict[str, Any]: + maybe_dict = to_container(cfg) + if isinstance(maybe_dict, dict): + return maybe_dict + else: + raise ValueError(f'Expected a dict-like type, got {type(maybe_dict)}') + + +def to_list_container( + cfg: Union[ListConfig, List[Dict[str, Any]]], +) -> List[Dict[str, Any]]: + maybe_list = to_container(cfg) + if isinstance(maybe_list, list): + return maybe_list + else: + raise ValueError(f'Expected a list-like type, got {type(maybe_list)}') + + +def to_container( + cfg: Optional[Union[DictConfig, ListConfig, Dict[str, Any], + List[Dict[str, Any]]]], +) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + """Converts a DictConfig or ListConfig to a dict or list. + + `omegaconf.to_container` does not handle nested DictConfig or ListConfig + objects, so this function is used to convert them to dicts or lists. + """ + if isinstance(cfg, DictConfig): + ret = om.to_container(cfg, resolve=True) + assert isinstance(ret, dict) + return ret # type: ignore (return type is correct and converting all keys to str would be unnecessarily costly) + elif isinstance(cfg, ListConfig): + ret = om.to_container(cfg, resolve=True) + assert isinstance(ret, list) + return ret # type: ignore (see above) + else: + return cfg # type: ignore (dicts and lists are already in the correct format) + + +T = TypeVar('T') + + +def make_dataclass_and_log_config( cfg: DictConfig, + dataclass_constructor: Callable[..., T], + dataclass_fields: Set[str], + transforms: Optional[List[Callable[[Dict[str, Any]], Dict[str, + Any]]]] = None, + icl_tasks_required: bool = False, +) -> Tuple[Dict[str, Any], T]: + """Converts a DictConfig to a dataclass and creates a logged config.""" + # Resolve all interpolation variables as early as possible + unstructured_config = om.to_container(cfg, resolve=True) + assert isinstance(unstructured_config, dict) + assert all(isinstance(k, str) for k in unstructured_config.keys()) + unstructured_config = {str(k): v for k, v in unstructured_config.items()} + + # Flatten union types before creating structured config: + if 'eval_gauntlet' in unstructured_config: + forbid_config_key(unstructured_config, 'eval_gauntlet_str') + if isinstance(unstructured_config['eval_gauntlet'], str): + unstructured_config['eval_gauntlet_str'] = unstructured_config.pop( + 'eval_gauntlet', + ) + if (loader := unstructured_config.get('eval_loader', None)) is not None: + forbid_config_key(unstructured_config, 'eval_loaders') + if isinstance(loader, list): + unstructured_config['eval_loaders'] = unstructured_config.pop( + 'eval_loader', + ) + if 'icl_tasks' in unstructured_config: + forbid_config_key(unstructured_config, 'icl_tasks_str') + if isinstance(unstructured_config['icl_tasks'], str): + unstructured_config['icl_tasks_str'] = unstructured_config.pop( + 'icl_tasks', + ) + else: + if icl_tasks_required: + raise MissingMandatoryValue( + 'icl_tasks must be specified in the config', + ) + + # Create copy of config for logging + logged_cfg: Dict[str, Any] = copy.deepcopy(unstructured_config) + + # Apply transforms to the unstructured config before constructing dataclass + for transform in transforms or []: + unstructured_config = transform(unstructured_config) + + logged_cfg.update(unstructured_config, merge=True) + + arg_config_keys = set(unstructured_config.keys()) + extraneous_keys = set.difference(arg_config_keys, dataclass_fields) + + if 'variables' not in unstructured_config: + 
unstructured_config['variables'] = {} + + for key in extraneous_keys: + warnings.warn( + f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary. Interpreting {key} as a variable for logging purposes. Top-level variables are deprecated and will not be supported in future releases. Please place any variables under the `variables` key.', + category=DeprecationWarning, + ) + unstructured_config['variables'][key] = unstructured_config.pop(key) + + dataclass_dict_config: DictConfig = om.structured( + dataclass_constructor(**unstructured_config), + ) + + # Error on missing mandatory values: + for key in dataclass_fields: + _ = dataclass_dict_config[key] + + # Convert DictConfig to dict for dataclass constructor so that child + # configs are not DictConfigs + dataclass_config: T = dataclass_constructor( + **to_dict_container(dataclass_dict_config), + ) + + return logged_cfg, dataclass_config + + +def pop_config( + cfg: Union[Dict[str, Any], DictConfig], key: str, must_exist: bool = True, default_value: Any = None, @@ -106,40 +389,41 @@ def calculate_batch_size_info( # Coming soon: this conversion math will be done inside Composer Trainer -def update_batch_size_info(cfg: DictConfig) -> DictConfig: +def update_batch_size_info(cfg: Dict[str, Any]) -> Dict[str, Any]: device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( - cfg.global_train_batch_size, - cfg.device_train_microbatch_size, + cfg['global_train_batch_size'], + cfg['device_train_microbatch_size'], ) - cfg.n_gpus = dist.get_world_size() - cfg.device_train_batch_size = device_train_batch_size - cfg.device_train_microbatch_size = device_train_microbatch_size - cfg.device_train_grad_accum = device_train_grad_accum + cfg['n_gpus'] = dist.get_world_size() + cfg['device_train_batch_size'] = device_train_batch_size + cfg['device_train_microbatch_size'] = device_train_microbatch_size + cfg['device_train_grad_accum'] = device_train_grad_accum # Safely set `device_eval_batch_size` if not provided by user if 'device_eval_batch_size' not in cfg: - if cfg.device_train_microbatch_size == 'auto': - cfg.device_eval_batch_size = 1 # TODO debug auto eval microbatching + if cfg['device_train_microbatch_size'] == 'auto': + cfg['device_eval_batch_size' + ] = 1 # TODO debug auto eval microbatching else: - cfg.device_eval_batch_size = cfg.device_train_microbatch_size + cfg['device_eval_batch_size'] = cfg['device_train_microbatch_size'] return cfg -def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): +def process_init_device(model_cfg: Dict[str, Any], fsdp_config: Optional[Dict]): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors # when multiple GPUs are available. # Also 'meta' is only valid when using FSDP init_context = contextlib.nullcontext() if 'init_device' in model_cfg: - assert model_cfg.init_device in ['meta', 'cpu', 'mixed'] - if fsdp_config is None and model_cfg.init_device == 'meta': + assert model_cfg['init_device'] in ['meta', 'cpu', 'mixed'] + if fsdp_config is None and model_cfg['init_device'] == 'meta': warnings.warn( "Using `cfg.model.init_device='meta'` is only valid when using FSDP! 
" +\ "Reverting to `cfg.model.init_device='cpu'`.") - model_cfg.init_device = 'cpu' - if model_cfg.init_device == 'meta': + model_cfg['init_device'] = 'cpu' + if model_cfg['init_device'] == 'meta': init_context = init_empty_weights() - if model_cfg.init_device == 'mixed': + if model_cfg['init_device'] == 'mixed': if fsdp_config is None: raise NotImplementedError( 'Using init_device `mixed` is only supported with FSDP. ' + @@ -168,7 +452,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): raise ValueError( 'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.', ) - model_cfg.ffn_config.device_mesh = fsdp_config['device_mesh'] + model_cfg['ffn_config']['device_mesh'] = fsdp_config['device_mesh'] # No mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') @@ -197,7 +481,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): return init_context -def log_config(cfg: DictConfig) -> None: +def log_config(cfg: Dict[str, Any]) -> None: """Logs the current config and updates the wandb and mlflow configs. This function can be called multiple times to update the wandb and MLflow @@ -210,14 +494,14 @@ def log_config(cfg: DictConfig) -> None: except ImportError as e: raise e if wandb.run: - wandb.config.update(om.to_container(cfg, resolve=True)) + wandb.config.update(cfg) if 'mlflow' in cfg.get('loggers', {}) and mlflow.active_run(): mlflow.log_params(params=om.to_container(cfg, resolve=True)) _log_dataset_uri(cfg) -def _parse_source_dataset(cfg: DictConfig) -> List[Tuple[str, str, str]]: +def _parse_source_dataset(cfg: Dict[str, Any]) -> List[Tuple[str, str, str]]: """Parse a run config for dataset information. Given a config dictionary, parse through it to determine what the datasource @@ -233,7 +517,7 @@ def _parse_source_dataset(cfg: DictConfig) -> List[Tuple[str, str, str]]: data_paths = [] # Handle train loader if it exists - train_dataset = cfg.get('train_loader', {}).get('dataset', {}) + train_dataset: Dict = cfg.get('train_loader', {}).get('dataset', {}) train_split = train_dataset.get('split', None) train_source_path = cfg.get('source_dataset_train', None) _process_data_source( @@ -246,13 +530,14 @@ def _parse_source_dataset(cfg: DictConfig) -> List[Tuple[str, str, str]]: # Handle eval_loader which might be a list or a single dictionary eval_data_loaders = cfg.get('eval_loader', {}) - if not isinstance(eval_data_loaders, ListConfig): + if not isinstance(eval_data_loaders, list): eval_data_loaders = [ eval_data_loaders, ] # Normalize to list if it's a single dictionary for eval_data_loader in eval_data_loaders: - eval_dataset = eval_data_loader.get('dataset', {}) + assert isinstance(eval_data_loader, dict) # pyright type check + eval_dataset: Dict = eval_data_loader.get('dataset', {}) eval_split = eval_dataset.get('split', None) eval_source_path = cfg.get('source_dataset_eval', None) _process_data_source( @@ -320,7 +605,7 @@ def _process_data_source( log.warning('DataSource Not Found.') -def _log_dataset_uri(cfg: DictConfig) -> None: +def _log_dataset_uri(cfg: Dict[str, Any]) -> None: """Logs dataset tracking information to MLflow. 
Args: diff --git a/llmfoundry/utils/mosaicml_logger_utils.py b/llmfoundry/utils/mosaicml_logger_utils.py index a65ebd9454..b01170ff0f 100644 --- a/llmfoundry/utils/mosaicml_logger_utils.py +++ b/llmfoundry/utils/mosaicml_logger_utils.py @@ -10,7 +10,6 @@ MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR, ) -from omegaconf import DictConfig, ListConfig __all__ = [ 'maybe_create_mosaicml_logger', @@ -50,9 +49,9 @@ def find_mosaicml_logger( def log_eval_analytics( mosaicml_logger: MosaicMLLogger, - model_configs: ListConfig, - icl_tasks: Union[str, ListConfig], - eval_gauntlet_config: Optional[Union[str, DictConfig]], + model_configs: List[Dict[str, Any]], + icl_tasks: Union[str, List[Dict[str, Any]]], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], ): """Logs analytics for runs using the `eval.py` script.""" metrics: Dict[str, Any] = { @@ -84,14 +83,14 @@ def log_eval_analytics( def log_train_analytics( mosaicml_logger: MosaicMLLogger, - model_config: DictConfig, - train_loader_config: DictConfig, - eval_loader_config: Optional[Union[DictConfig, ListConfig]], - callback_configs: Optional[DictConfig], + model_config: Dict[str, Any], + train_loader_config: Dict[str, Any], + eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], + callback_configs: Optional[Dict[str, Any]], tokenizer_name: str, load_path: Optional[str], - icl_tasks_config: Optional[Union[ListConfig, str]], - eval_gauntlet: Optional[Union[DictConfig, str]], + icl_tasks_config: Optional[Union[List[Dict[str, Any]], str]], + eval_gauntlet: Optional[Union[Dict[str, Any], str]], ): """Logs analytics for runs using the `train.py` script.""" train_loader_dataset = train_loader_config.get('dataset', {}) @@ -128,10 +127,10 @@ def log_train_analytics( if eval_loader_config is not None: metrics['llmfoundry/eval_loaders'] = [] - if isinstance(eval_loader_config, ListConfig): - eval_loader_configs: ListConfig = eval_loader_config + if isinstance(eval_loader_config, list): + eval_loader_configs: list = eval_loader_config else: - eval_loader_configs = ListConfig([eval_loader_config]) + eval_loader_configs = [eval_loader_config] for loader_config in eval_loader_configs: eval_loader_info = {} diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index b925385e3e..a1930b3045 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,12 +1,10 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import copy import logging import os import sys import time -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd @@ -15,7 +13,7 @@ from composer.loggers.logger_destination import LoggerDestination from composer.trainer import Trainer from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig from omegaconf import OmegaConf as om from rich.traceback import install @@ -35,8 +33,10 @@ build_tokenizer, ) from llmfoundry.utils.config_utils import ( + EVAL_CONFIG_KEYS, + EvalConfig, log_config, - pop_config, + make_dataclass_and_log_config, process_init_device, ) from llmfoundry.utils.registry_utils import import_file @@ -45,34 +45,33 @@ def evaluate_model( - model_cfg: DictConfig, + tokenizer: Dict[str, Any], + model_name: str, + model: Dict[str, Any], dist_timeout: Union[float, int], run_name: str, seed: int, - icl_tasks: Union[str, ListConfig], + icl_tasks: Union[str, List[Dict[str, Any]]], max_seq_len: int, device_eval_batch_size: int, - 
eval_gauntlet_config: Optional[Union[str, DictConfig]], - eval_loader_config: Optional[Union[DictConfig, ListConfig]], - fsdp_config: Optional[Dict], + eval_gauntlet_config: Optional[Union[str, Dict[str, Any]]], + eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], + fsdp_config: Optional[Dict[str, Any]], loggers: List[LoggerDestination], python_log_level: Optional[str], precision: str, eval_gauntlet_df: Optional[pd.DataFrame], eval_subset_num_batches: int, icl_subset_num_batches: Optional[int], - callback_configs: Optional[DictConfig], + callback_configs: Optional[Dict[str, Any]], metadata: Optional[Dict[str, str]], - logged_config: DictConfig, + logged_config: Dict[str, Any], should_log_config: bool = True, + load_path: Optional[str] = None, ): - - log.info(f'Evaluating model: {model_cfg.model_name}') + log.info(f'Evaluating model: {model_name}') # Build tokenizer and model - tokenizer_cfg: Dict[str, Any] = om.to_container( - model_cfg.tokenizer, - resolve=True, - ) # type: ignore + tokenizer_cfg = tokenizer tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) @@ -104,19 +103,20 @@ def evaluate_model( mosaicml_logger.log_metrics(metadata) mosaicml_logger._flush_metadata(force_flush=True) - if fsdp_config and model_cfg.model.get('load_in_8bit', False): + if fsdp_config and model.get('load_in_8bit', False): raise ValueError( 'The FSDP config block is not supported when loading ' + 'Hugging Face models in 8bit.', ) - init_context = process_init_device(model_cfg.model, fsdp_config) + init_context = process_init_device(model, fsdp_config) + name = model.pop('name') composer_model = build_composer_model( - name=model_cfg.model.name, - cfg=model_cfg.model, + name=name, tokenizer=tokenizer, init_context=init_context, + cfg=model, ) # Now add the eval metrics @@ -130,11 +130,10 @@ def evaluate_model( if eval_gauntlet_df is None and eval_gauntlet_callback is not None: eval_gauntlet_df = pd.DataFrame( columns=['model_name'] + list(eval_gauntlet_callback.averages) + - [t.name for t in eval_gauntlet_callback.categories], + [t['name'] for t in eval_gauntlet_callback.categories], ) - load_path = model_cfg.get('load_path', None) - if model_cfg.model.name == 'mpt_causal_lm' and load_path is None: + if name == 'mpt_causal_lm' and load_path is None: raise ValueError( 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' 
+ @@ -143,7 +142,7 @@ def evaluate_model( assert composer_model is not None - log.info(f'Building trainer for {model_cfg.model_name}...') + log.info(f'Building trainer for {model_name}...') trainer = Trainer( run_name=run_name, seed=seed, @@ -164,7 +163,7 @@ def evaluate_model( log.info('Evaluation config:') log_config(logged_config) - log.info(f'Starting eval for {model_cfg.model_name}...') + log.info(f'Starting eval for {model_name}...') if torch.cuda.is_available(): torch.cuda.synchronize() a = time.time() @@ -176,148 +175,61 @@ def evaluate_model( torch.cuda.synchronize() b = time.time() - log.info(f'Ran {model_cfg.model_name} eval in: {b-a} seconds') + log.info(f'Ran {model_name} eval in: {b-a} seconds') return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: # Run user provided code if specified - code_paths = pop_config( - cfg, - 'code_paths', - must_exist=False, - default_value=[], - convert=True, - ) - for code_path in code_paths: + for code_path in cfg.get('code_paths', []): import_file(code_path) - om.resolve(cfg) - - # Create copy of config for logging - logged_cfg: DictConfig = copy.deepcopy(cfg) - - model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) - eval_gauntlet_config: Optional[ - Union[str, DictConfig] - ] = pop_config(cfg, 'eval_gauntlet', must_exist=False, default_value=None) - - fsdp_dict_cfg: Optional[DictConfig] = pop_config( + logged_cfg, eval_config = make_dataclass_and_log_config( cfg, - 'fsdp_config', - must_exist=False, - default_value=None, + EvalConfig, + EVAL_CONFIG_KEYS, + icl_tasks_required=True, ) - fsdp_config: Optional[Dict] = om.to_container( - fsdp_dict_cfg, - resolve=True, - ) if fsdp_dict_cfg is not None else None # type: ignore - assert isinstance(fsdp_config, Dict) or fsdp_config is None + + model_configs = eval_config.models + eval_gauntlet_config = eval_config.eval_gauntlet + + fsdp_config = eval_config.fsdp_config # Mandatory Evaluation Parameters - icl_tasks: Union[ - str, ListConfig] = pop_config(cfg, 'icl_tasks', must_exist=True) - max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) - device_eval_batch_size: int = pop_config( - cfg, - 'device_eval_batch_size', - must_exist=True, - ) - precision: str = pop_config( - cfg, - 'precision', - must_exist=False, - default_value=None, - ) - python_log_level: Optional[str] = pop_config( - cfg, - 'python_log_level', - must_exist=False, - default_value='debug', - ) + icl_tasks = eval_config.icl_tasks or eval_config.icl_tasks_str + if icl_tasks is None: + raise ValueError('icl_tasks must be specified in the config') # Optional Evaluation Parameters with default values - eval_loader_config: Optional[ - Union[DictConfig, ListConfig] - ] = pop_config(cfg, 'eval_loader', must_exist=False, default_value=None) - seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17) - dist_timeout: Union[float, int] = pop_config( - cfg, - 'dist_timeout', - must_exist=False, - default_value=600.0, - ) + eval_loader_config = eval_config.eval_loader or eval_config.eval_loaders default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = pop_config( - cfg, - 'run_name', - must_exist=False, - default_value=default_run_name, - ) - loggers_cfg: Dict[ - str, - Any] = pop_config(cfg, 'loggers', must_exist=False, default_value={}) - eval_subset_num_batches: int = pop_config( - cfg, - 'eval_subset_num_batches', - must_exist=False, - default_value=-1, - ) - icl_subset_num_batches: 
Optional[int] = pop_config( - cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None, - ) - metadata: Optional[Dict[str, str]] = pop_config( - cfg, - 'metadata', - must_exist=False, - default_value=None, - convert=True, - ) - should_log_config: bool = pop_config( - cfg, - 'log_config', - must_exist=False, - default_value=True, - ) - - # Pop out interpolation variables. - pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None) - callback_configs: Optional[DictConfig] = pop_config( - cfg, - 'callbacks', - must_exist=False, - default_value=None, - ) - - # Warn for unused parameters - for key in cfg: - warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.', - ) + run_name = eval_config.run_name if eval_config.run_name else default_run_name - reproducibility.seed_all(seed) - dist.initialize_dist(get_device(None), timeout=dist_timeout) + reproducibility.seed_all(eval_config.seed) + dist.initialize_dist(get_device(None), timeout=eval_config.dist_timeout) - if python_log_level is not None: + if eval_config.python_log_level is not None: logging.basicConfig( # Example of format string # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here format= f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', ) - logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) + logging.getLogger('llmfoundry').setLevel( + eval_config.python_log_level.upper(), + ) + # Default argument values for evaluate_model eval_gauntlet_df = None models_df = None composite_scores = None trainers = [] + # Build loggers loggers: List[LoggerDestination] = [ build_logger(name, logger_cfg) - for name, logger_cfg in loggers_cfg.items() + for name, logger_cfg in (eval_config.loggers or {}).items() ] mosaicml_logger = find_mosaicml_logger(loggers) @@ -338,7 +250,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: for model_cfg in model_configs: - attn_config = model_cfg.model.get('attn_config', None) + attn_config = model_cfg['model'].get('attn_config', None) if attn_config is not None: seq_parallel_world_size = attn_config.get( 'seq_parallel_world_size', @@ -351,26 +263,26 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) = evaluate_model( - model_cfg=model_cfg, - dist_timeout=dist_timeout, + dist_timeout=eval_config.dist_timeout, run_name=run_name, - seed=seed, + seed=eval_config.seed, icl_tasks=icl_tasks, - max_seq_len=max_seq_len, - device_eval_batch_size=device_eval_batch_size, + max_seq_len=eval_config.max_seq_len, + device_eval_batch_size=eval_config.device_eval_batch_size, eval_gauntlet_config=eval_gauntlet_config, eval_loader_config=eval_loader_config, fsdp_config=fsdp_config, loggers=loggers, - python_log_level=python_log_level, - precision=precision, + python_log_level=eval_config.python_log_level, + precision=eval_config.precision, eval_gauntlet_df=eval_gauntlet_df, - callback_configs=callback_configs, - eval_subset_num_batches=eval_subset_num_batches, - icl_subset_num_batches=icl_subset_num_batches, - metadata=metadata, + callback_configs=eval_config.callbacks, + eval_subset_num_batches=eval_config.eval_subset_num_batches, + icl_subset_num_batches=eval_config.icl_subset_num_batches, + metadata=eval_config.metadata, logged_config=logged_cfg, - should_log_config=should_log_config, + should_log_config=eval_config.log_config, + **model_cfg, ) 
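For reference, the EvalConfig plumbing used by `main` above can be exercised on its own. This is a minimal sketch in which every value (model block, tokenizer, dataset URI) is illustrative rather than a recommended configuration:

from omegaconf import OmegaConf as om

from llmfoundry.utils.config_utils import (
    EVAL_CONFIG_KEYS,
    EvalConfig,
    make_dataclass_and_log_config,
)

yaml_cfg = om.create({
    'models': [{
        'model_name': 'tiny-gpt2',  # label used in the results table
        'model': {
            'name': 'hf_causal_lm',
            'pretrained_model_name_or_path': 'gpt2',
            'pretrained': True,
        },
        'tokenizer': {'name': 'gpt2', 'kwargs': {'model_max_length': 1024}},
    }],
    'max_seq_len': 1024,
    'device_eval_batch_size': 4,
    'icl_tasks': [{
        'label': 'lambada',
        'dataset_uri': 'local_data/lambada_openai.jsonl',  # placeholder path
        'icl_task_type': 'language_modeling',
        'num_fewshot': [0],
    }],
})

# Unknown top-level keys are folded under `variables` with a deprecation
# warning, and missing mandatory fields raise before any model is built.
logged_cfg, eval_config = make_dataclass_and_log_config(
    yaml_cfg,
    EvalConfig,
    EVAL_CONFIG_KEYS,
    icl_tasks_required=True,
)
assert isinstance(eval_config, EvalConfig)
assert eval_config.seed == 17  # dataclass default, since the YAML omits it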
trainers.append(trainer) @@ -383,14 +295,15 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: benchmark_to_taxonomy = {} if eval_gauntlet_callback is not None: for t in eval_gauntlet_callback.categories: - for b in t.benchmarks: - benchmark_to_taxonomy[b.name] = t.name + for b in t['benchmarks']: + benchmark_to_taxonomy[b['name']] = t['name'] + assert 'model_name' in model_cfg, 'model_name must be specified in model config' model_results = calculate_markdown_results( logger_keys, trainer, benchmark_to_taxonomy, - model_cfg.model_name, + model_cfg['model_name'], ) if models_df is None: diff --git a/scripts/eval/yamls/hf_8bit_eval.yaml b/scripts/eval/yamls/hf_8bit_eval.yaml index 3bf3c23414..30da2e5ef3 100644 --- a/scripts/eval/yamls/hf_8bit_eval.yaml +++ b/scripts/eval/yamls/hf_8bit_eval.yaml @@ -1,22 +1,24 @@ -max_seq_len: 1024 +variables: + model_name_or_path: bigscience/bloom-1b7 + max_seq_len: 1024 + seed: 1 precision: amp_fp16 - -model_name_or_path: bigscience/bloom-1b7 +max_seq_len: ${variables.max_seq_len} models: - - model_name: ${model_name_or_path} + model_name: ${variables.model_name_or_path} model: name: hf_causal_lm - pretrained_model_name_or_path: ${model_name_or_path} + pretrained_model_name_or_path: ${variables.model_name_or_path} init_device: mixed pretrained: true load_in_8bit: true tokenizer: - name: ${model_name_or_path} + name: ${variables.model_name_or_path} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} device_eval_batch_size: 4 diff --git a/scripts/eval/yamls/hf_eval.yaml b/scripts/eval/yamls/hf_eval.yaml index 9eb0245f9a..708c871d88 100644 --- a/scripts/eval/yamls/hf_eval.yaml +++ b/scripts/eval/yamls/hf_eval.yaml @@ -1,23 +1,26 @@ -max_seq_len: 1024 -seed: 1 -precision: fp32 +variables: + # If you are using one model, put it here: + model_name_or_path: EleutherAI/gpt-neo-125m + # otherwise, write a block for each model you want to test in the `models` section -# If you are using one model, put it here: -model_name_or_path: EleutherAI/gpt-neo-125m -# otherwise, write a block for each model you want to test in the `models` section + precision: fp32 + max_seq_len: 1024 + +precision: ${variables.precision} +max_seq_len: ${variables.max_seq_len} models: - - model_name: ${model_name_or_path} + model_name: ${variables.model_name_or_path} model: name: hf_causal_lm - pretrained_model_name_or_path: ${model_name_or_path} + pretrained_model_name_or_path: ${variables.model_name_or_path} init_device: mixed pretrained: true tokenizer: - name: ${model_name_or_path} + name: ${variables.model_name_or_path} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # # if you are evaluating more than one model, list them all as YAML blocks without variable interpolation # - # model_name: mosaicml/mpt-7b @@ -27,11 +30,11 @@ models: # init_device: cpu # pretrained: true # config_overrides: -# max_seq_len: ${max_seq_len} +# max_seq_len: ${variables.max_seq_len} # tokenizer: # name: mosaicml/mpt-7b # kwargs: -# model_max_length: ${max_seq_len} +# model_max_length: ${variables.max_seq_len} device_eval_batch_size: 4 diff --git a/scripts/eval/yamls/hf_lora_eval.yml b/scripts/eval/yamls/hf_lora_eval.yml index 08861b8569..e1e87968bc 100644 --- a/scripts/eval/yamls/hf_lora_eval.yml +++ b/scripts/eval/yamls/hf_lora_eval.yml @@ -1,25 +1,27 @@ -max_seq_len: 2048 +variables: + model_name_or_path: facebook/opt-350m + # If you are using a separated lora weight, put it here: + # lora weights must be compatible
with the specified model + lora_id_or_path: ybelkada/opt-350m-lora # Example lora weights for opt-350m + max_seq_len: 2048 + seed: 1 precision: amp_fp16 - -model_name_or_path: facebook/opt-350m -# If you are using a seperated lora weight, put it here: -# lora weights must be compatible with the specified model -lora_id_or_path: ybelkada/opt-350m-lora # Example lora weights for opt-350m +max_seq_len: ${variables.max_seq_len} models: - - model_name: ${model_name_or_path} + model_name: ${variables.model_name_or_path} model: name: hf_causal_lm - pretrained_model_name_or_path: ${model_name_or_path} + pretrained_model_name_or_path: ${variables.model_name_or_path} init_device: mixed pretrained: true - pretrained_lora_id_or_path: ${lora_id_or_path} + pretrained_lora_id_or_path: ${variables.lora_id_or_path} tokenizer: - name: ${model_name_or_path} + name: ${variables.model_name_or_path} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} device_eval_batch_size: 4 diff --git a/scripts/eval/yamls/mpt_eval.yaml b/scripts/eval/yamls/mpt_eval.yaml index f59a73f15b..5274bd0b9d 100644 --- a/scripts/eval/yamls/mpt_eval.yaml +++ b/scripts/eval/yamls/mpt_eval.yaml @@ -1,16 +1,19 @@ -max_seq_len: 1024 -tokenizer_name: EleutherAI/gpt-neox-20b +variables: + tokenizer_name: EleutherAI/gpt-neox-20b + max_seq_len: 1024 + seed: 1 precision: amp_fp16 +max_seq_len: ${variables.max_seq_len} models: - model_name: mpt_test # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} model: name: mpt_causal_lm init_device: mixed @@ -19,7 +22,7 @@ models: n_heads: 12 n_layers: 12 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py index 463183442d..f85b895316 100644 --- a/scripts/inference/benchmarking/benchmark.py +++ b/scripts/inference/benchmarking/benchmark.py @@ -66,10 +66,11 @@ def main(config: DictConfig): tokenizer_name=tokenizer_name, tokenizer_kwargs=tokenizer_kwargs, ) + name = config.model.pop('name') composer_model = build_composer_model( - name=config.model.name, - cfg=config.model, + name=name, tokenizer=tokenizer, + cfg=config.model, ) model = composer_model.model model.eval() diff --git a/scripts/train/finetune_example/gpt2-arc-easy--cpu.yaml b/scripts/train/finetune_example/gpt2-arc-easy--cpu.yaml index 2b1821c92c..635313d4bc 100644 --- a/scripts/train/finetune_example/gpt2-arc-easy--cpu.yaml +++ b/scripts/train/finetune_example/gpt2-arc-easy--cpu.yaml @@ -1,5 +1,7 @@ +variables: + global_seed: 17 + max_seq_len: 512 -global_seed: 17 # Run Name run_name: # If left blank, will be read from env var $RUN_NAME diff --git a/scripts/train/train.py b/scripts/train/train.py index 77e7a49732..e8f5b8220a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,6 +1,5 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import copy import gc import logging import os @@ -10,7 +9,8 @@ from typing import Any, Dict, List, Optional, Union import torch -from composer import Trainer +import torch.distributed +from composer import ComposerModel, Trainer from composer.core.callback import Callback from composer.profiler import ( JSONTraceHandler, @@ -19,7 +19,7 @@ cyclic_schedule, ) from composer.utils import dist, get_device, reproducibility 
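The `benchmark.py` hunk above shows the new calling convention for models: pop `name` from the config dict and pass the remaining keys through `cfg`. A minimal sketch with illustrative, deliberately tiny MPT hyperparameters; `torch_crossentropy` and `attn_impl: torch` are chosen only so the sketch does not assume flash-attn is installed:

from llmfoundry.utils.builders import build_composer_model, build_tokenizer

model_cfg = {
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 2,
    'n_layers': 2,
    'expansion_ratio': 4,
    'max_seq_len': 256,
    'vocab_size': 50368,
    'attn_config': {'attn_impl': 'torch'},
    'loss_fn': 'torch_crossentropy',
}
tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {'model_max_length': 256})

# `name` selects the registry entry; everything left in `model_cfg` is
# forwarded to the model constructor as keyword arguments.
name = model_cfg.pop('name')
composer_model = build_composer_model(
    name=name,
    tokenizer=tokenizer,
    cfg=model_cfg,
)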
-from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig from omegaconf import OmegaConf as om from rich.traceback import install @@ -31,6 +31,7 @@ ) install() + from llmfoundry.callbacks import AsyncEval from llmfoundry.data.dataloader import build_dataloader from llmfoundry.layers_registry import ffns_with_megablocks @@ -46,7 +47,10 @@ build_tokenizer, ) from llmfoundry.utils.config_utils import ( + TRAIN_CONFIG_KEYS, + TrainConfig, log_config, + make_dataclass_and_log_config, pop_config, process_init_device, update_batch_size_info, @@ -56,38 +60,38 @@ log = logging.getLogger(__name__) -def validate_config(cfg: DictConfig): +def validate_config(train_config: TrainConfig): """Validates compatible model and dataloader selection.""" - loaders = [cfg.train_loader] - if 'eval_loader' in cfg: - eval_loader = cfg.eval_loader - if isinstance(eval_loader, ListConfig): - for loader in eval_loader: - if loader.label is None: - raise ValueError( - 'When specifying multiple evaluation datasets, each one must include the \ + # Validate the rest of the config + loaders = [train_config.train_loader] + if train_config.eval_loaders is not None: + for loader in (train_config.eval_loaders or []): # pyright + if 'label' not in loader or loader['label'] is None: + raise ValueError( + 'When specifying multiple evaluation datasets, each one must include the \ `label` attribute.', - ) - loaders.append(loader) - else: - loaders.append(eval_loader) + ) + loaders.append(loader) + if train_config.eval_loader is not None: + loaders.append(train_config.eval_loader) for loader in loaders: - if loader.name == 'text': - if cfg.model.name == 'hf_t5': + if loader['name'] == 'text': + if train_config.model['name'] == 'hf_t5': raise ValueError( - f'Model type "{cfg.model.name}" is not supported when using the "text " ' +\ + f'Model type "{train_config.model["name"]}" is not supported when using the "text " ' +\ f'dataloader. Only finetuning is supported.') - if 'icl_tasks' in cfg: - if cfg.model.name == 'hf_t5': + if train_config.icl_tasks is not None or train_config.icl_tasks_str is not None: + if train_config.model['name'] == 'hf_t5': raise ValueError( 'ICL evaluation does not currently support Encoder-Decoder models, such as "hf_t5".', ) if ( - cfg.model.get('fc_type', 'torch') != 'te' and 'te' - not in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') and - 'fp8' in cfg.precision + train_config.model.get('fc_type', 'torch') != 'te' and + 'te' not in train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') and + 'fp8' in train_config.precision ): warnings.warn( "fp8 only supported for te.Linear layers. 
Either set `cfg.model.fc_typ='te'` or " @@ -96,47 +100,57 @@ def validate_config(cfg: DictConfig): ) if ( - cfg.model.get('fc_type', 'torch') == 'te' or - 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') + train_config.model.get('fc_type', 'torch') == 'te' or 'te' + in train_config.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') ): - fsdp_config = cfg.get('fsdp_config', None) - act_ckpt = fsdp_config.get('activation_checkpointing', False) + fsdp_config = train_config.fsdp_config + act_ckpt = fsdp_config.get( + 'activation_checkpointing', + False, + ) if fsdp_config else False act_ckpt_reentrant = fsdp_config.get( 'activation_checkpointing_reentrant', False, - ) + ) if fsdp_config else False if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True: warnings.warn( '`te.Linear` layers do not support activation_checkpointing with ' + '`activation_checkpointing_reentrant = True`. ' + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.', ) - cfg.fsdp_config.activation_checkpointing_reentrant = False + assert train_config.fsdp_config is not None # pyright (this is known because fsdp_config is not None) + train_config.fsdp_config['activation_checkpointing_reentrant' + ] = False - if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': + if train_config.model.get('ffn_config', + {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp': warnings.warn( '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.', ) torch._dynamo.config.suppress_errors = True # type: ignore (third-party) - if cfg.model.get('load_in_8bit', False): + if train_config.model.get('load_in_8bit', False): raise ValueError( '`load_in_8bit` is only supported for evaluation rather than training.', ) - if cfg.model.get('ffn_config', - {}).get('ffn_type', 'mptmlp') in ffns_with_megablocks: - moe_world_size = cfg.model.get('ffn_config', - {}).get('moe_world_size', 1) - use_orig_params = cfg.get('fsdp_config', - {}).get('use_orig_params', True) + if train_config.model.get('ffn_config', {}).get( + 'ffn_type', + 'mptmlp', + ) in ffns_with_megablocks: + moe_world_size = train_config.model.get('ffn_config', + {}).get('moe_world_size', 1) + use_orig_params = train_config.fsdp_config.get( + 'use_orig_params', + True, + ) if train_config.fsdp_config is not None else True if moe_world_size > 1 and not use_orig_params: raise ValueError( f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.', ) - attn_config = cfg.model.get('attn_config', None) + attn_config = train_config.model.get('attn_config', None) if attn_config is not None: seq_parallel_world_size = attn_config.get( 'seq_parallel_world_size', @@ -146,6 +160,27 @@ def validate_config(cfg: DictConfig): raise ValueError('Training does not support sequence parallelism.') +def _log_num_params(model: ComposerModel, logged_cfg: Dict[str, Any]): + # Log number of parameters + if hasattr(model, 'n_total_params'): + n_params = model.n_total_params + n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. 
+ else: + n_params = sum(p.numel() for p in model.parameters()) + n_trainable_params = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + if hasattr(model, 'n_active_params'): + n_active_params = model.n_active_params + else: + n_active_params = n_params + logged_cfg.update({ + 'n_params': n_params, + 'n_active_params': n_active_params, + 'n_trainable_params': n_trainable_params, + }) + + def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): """Initialize distributed and test setup with a barrier. @@ -160,18 +195,35 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]): def main(cfg: DictConfig) -> Trainer: - # Run user provided code if specified - code_paths = pop_config( - cfg, - 'code_paths', - must_exist=False, - default_value=[], - convert=True, - ) + code_paths = cfg.get('code_paths', []) # Import any user provided code for code_path in code_paths: import_file(code_path) + logged_cfg, train_cfg = make_dataclass_and_log_config( + cfg, + TrainConfig, + TRAIN_CONFIG_KEYS, + transforms=[update_batch_size_info], + ) + + # Set logging level + if train_cfg.python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', + ) + logging.getLogger('llmfoundry').setLevel( + train_cfg.python_log_level.upper(), + ) # Foundry module + logging.getLogger(__name__).setLevel( + train_cfg.python_log_level.upper(), + ) # Train script + + _initialize_dist_with_barrier(dist_timeout=train_cfg.dist_timeout) + # Filter deprecation warning from torch internal usage warnings.filterwarnings( action='ignore', @@ -181,22 +233,16 @@ def main(cfg: DictConfig) -> Trainer: ) # Check for incompatibilities between the model and data loaders - validate_config(cfg) - - # Resolve all interpolation variables as early as possible - om.resolve(cfg) - - # Create copy of config for logging - logged_cfg: DictConfig = copy.deepcopy(cfg) + validate_config(train_cfg) cuda_alloc_conf = [] # Get max split size mb - max_split_size_mb: Optional[int] = cfg.pop('max_split_size_mb', None) + max_split_size_mb: Optional[int] = train_cfg.max_split_size_mb if max_split_size_mb is not None: cuda_alloc_conf.append(f'max_split_size_mb:{max_split_size_mb}') # Expandable segments - if cfg.pop('expandable_segments', False): + if train_cfg.expandable_segments: cuda_alloc_conf.append('expandable_segments:True') if len(cuda_alloc_conf) > 0: @@ -204,317 +250,48 @@ def main(cfg: DictConfig) -> Trainer: # Set CUDA lazy loading # This can save a bit of memory if not all modules are needed - cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', False) + cuda_load_lazy: bool = train_cfg.cuda_load_lazy if cuda_load_lazy: os.environ['CUDA_MODULE_LOADING'] = 'LAZY' # Set seed first - seed: int = pop_config(cfg, 'seed', must_exist=True) + seed: int = train_cfg.seed reproducibility.seed_all(seed) - # Initialize pytorch distributed training process groups - dist_timeout: Union[int, float] = pop_config( - cfg, - 'dist_timeout', - must_exist=False, - default_value=600.0, - ) - python_log_level: Optional[str] = pop_config( - cfg, - 'python_log_level', - must_exist=False, - default_value='debug', - ) - # Set logging level - if python_log_level is not None: - logging.basicConfig( - # Example of format string - # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here - format= - f'%(asctime)s: 
rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s', - ) - logging.getLogger('llmfoundry').setLevel( - python_log_level.upper(), - ) # Foundry module - logging.getLogger(__name__).setLevel( - python_log_level.upper(), - ) # Train script - - _initialize_dist_with_barrier(dist_timeout=dist_timeout) - - # Get global and device batch size information from distributed/single node setting - cfg = update_batch_size_info(cfg) - logged_cfg.update(cfg, merge=True) - # Mandatory model training configs - model_config: DictConfig = pop_config(cfg, 'model', must_exist=True) - tokenizer_config: Dict[ - str, Any] = pop_config(cfg, 'tokenizer', must_exist=True, convert=True) - optimizer_config: Dict[ - str, Any] = pop_config(cfg, 'optimizer', must_exist=True, convert=True) - scheduler_config: Dict[ - str, Any] = pop_config(cfg, 'scheduler', must_exist=True, convert=True) - train_loader_config: DictConfig = pop_config( - cfg, - 'train_loader', - must_exist=True, - ) + model_config = train_cfg.model + train_loader_config = train_cfg.train_loader # Optional fsdp data, fine-tuning, and eval configs - fsdp_config: Optional[Dict[str, Any]] = pop_config( - cfg, - 'fsdp_config', - must_exist=False, - default_value=None, - convert=True, - ) - eval_loader_config: Optional[ - Union[DictConfig, ListConfig] - ] = pop_config(cfg, 'eval_loader', must_exist=False, default_value=None) - icl_tasks_config: Optional[ - Union[ListConfig, str] - ] = pop_config(cfg, 'icl_tasks', must_exist=False, default_value=None) - if icl_tasks_config is not None: - seq_parallel_replication = train_loader_config.dataset.get( - 'seq_parallel_replication', - None, - ) - if seq_parallel_replication is not None and seq_parallel_replication != 1: - raise ValueError( - 'icl eval tasks are not supported with sequence parallelism', - ) - eval_gauntlet_config: Optional[ - Union[DictConfig, str] - ] = pop_config(cfg, 'eval_gauntlet', must_exist=False, default_value=None) - icl_subset_num_batches: Optional[int] = pop_config( - cfg, - 'icl_subset_num_batches', - must_exist=False, - default_value=None, - ) - icl_seq_len: Optional[int] = pop_config( - cfg, - 'icl_seq_len', - must_exist=False, - default_value=None, - ) - # Optional logging, evaluation and callback configs - logger_configs: Optional[DictConfig] = pop_config( - cfg, - 'loggers', - must_exist=False, - default_value=None, - convert=True, - ) - callback_configs: Optional[DictConfig] = pop_config( - cfg, - 'callbacks', - must_exist=False, - default_value=None, - convert=True, - ) - algorithm_configs: Optional[DictConfig] = pop_config( - cfg, - 'algorithms', - must_exist=False, - default_value=None, - ) + fsdp_config: Optional[Dict[str, Any]] = train_cfg.fsdp_config - # Mandatory hyperparameters for training - device_train_batch_size: int = pop_config( - cfg, - 'device_train_batch_size', - must_exist=True, - ) - device_eval_batch_size: int = pop_config( - cfg, - 'device_eval_batch_size', - must_exist=True, - ) - max_duration: Union[int, - str] = pop_config(cfg, 'max_duration', must_exist=True) - eval_interval: Union[int, str] = pop_config( - cfg, - 'eval_interval', - default_value=1, - must_exist=False, - ) - precision: str = pop_config(cfg, 'precision', must_exist=True) - max_seq_len: int = pop_config(cfg, 'max_seq_len', must_exist=True) + eval_loader_config = train_cfg.eval_loader if train_cfg.eval_loader is not None else train_cfg.eval_loaders + icl_tasks_config = train_cfg.icl_tasks + eval_gauntlet_config = train_cfg.eval_gauntlet # Optional 
parameters will be set to default values if not specified. default_run_name: str = os.environ.get('RUN_NAME', 'llm') - run_name: str = pop_config( - cfg, - 'run_name', - must_exist=False, - default_value=default_run_name, - ) - save_folder: Optional[str] = pop_config( - cfg, - 'save_folder', - must_exist=False, - default_value=None, - ) + run_name: str = train_cfg.run_name if train_cfg.run_name else default_run_name is_state_dict_sharded: bool = ( fsdp_config.get('state_dict_type', 'full') == 'sharded' ) if fsdp_config else False - save_latest_filename: str = pop_config( - cfg, - 'save_latest_filename', - must_exist=False, - default_value='latest-sharded-rank{rank}' - if is_state_dict_sharded else 'latest-rank{rank}.pt', - ) - save_overwrite: bool = pop_config( - cfg, - 'save_overwrite', - must_exist=False, - default_value=False, - ) - save_weights_only: bool = pop_config( - cfg, - 'save_weights_only', - must_exist=False, - default_value=False, - ) - save_filename: str = pop_config( - cfg, - 'save_filename', - must_exist=False, - default_value='ep{epoch}-ba{batch}-rank{rank}.pt', - ) - save_interval: Union[str, int] = pop_config( - cfg, - 'save_interval', - must_exist=False, - default_value='1000ba', - ) - save_num_checkpoints_to_keep: int = pop_config( - cfg, - 'save_num_checkpoints_to_keep', - must_exist=False, - default_value=-1, - ) - progress_bar = pop_config( - cfg, - 'progress_bar', - must_exist=False, - default_value=False, - ) - log_to_console: bool = pop_config( - cfg, - 'log_to_console', - must_exist=False, - default_value=True, - ) - console_log_interval: Union[int, str] = pop_config( - cfg, - 'console_log_interval', - must_exist=False, - default_value='1ba', - ) - device_train_microbatch_size: Union[str, int] = pop_config( - cfg, - 'device_train_microbatch_size', - must_exist=False, - default_value='auto', - ) - eval_subset_num_batches: int = pop_config( - cfg, - 'eval_subset_num_batches', - must_exist=False, - default_value=-1, - ) - eval_first: bool = pop_config( - cfg, - 'eval_first', - must_exist=False, - default_value=False, - ) - load_path: str = pop_config( - cfg, - 'load_path', - must_exist=False, - default_value=None, - ) - load_weights_only: bool = pop_config( - cfg, - 'load_weights_only', - must_exist=False, - default_value=False, - ) - load_strict_model_weights: bool = pop_config( - cfg, - 'load_strict_model_weights', - must_exist=False, - default_value=True, - ) - load_ignore_keys: Optional[List[str]] = pop_config( - cfg, - 'load_ignore_keys', - must_exist=False, - default_value=None, - ) - save_ignore_keys: Optional[List[str]] = pop_config( - cfg, - 'save_ignore_keys', - must_exist=False, - default_value=None, - ) - compile_config: Optional[ - Dict[str, Any] - ] = pop_config(cfg, 'compile_config', must_exist=False, default_value=None) - metadata: Optional[Dict[str, str]] = pop_config( - cfg, - 'metadata', - must_exist=False, - default_value=None, - convert=True, - ) - should_log_config: bool = pop_config( - cfg, - 'log_config', - must_exist=False, - default_value=True, - ) + save_latest_filename: str = train_cfg.save_latest_filename if train_cfg.save_latest_filename else 'latest-sharded-rank{rank}' if is_state_dict_sharded else 'latest-rank{rank}.pt' + save_filename: str = train_cfg.save_filename if train_cfg.save_filename else 'ep{epoch}-ba{batch}-rank{rank}.pt' # Enable autoresume from model checkpoints if possible autoresume_default: bool = False if logged_cfg.get('run_name', None) is not None \ - and save_folder is not None \ - and not save_overwrite \ - and 
not save_weights_only: + and train_cfg.save_folder is not None \ + and not train_cfg.save_overwrite \ + and not train_cfg.save_weights_only: autoresume_default = True - if cfg.get('autoresume') is None and autoresume_default: + if not train_cfg.autoresume and autoresume_default: log.info( 'As run_name, save_folder, and save_latest_filename are set, \ changing autoresume default to True...', ) - autoresume: bool = pop_config( - cfg, - 'autoresume', - must_exist=False, - default_value=autoresume_default, - ) - - # Pop known unused parameters that are used as interpolation variables or - # created by update_batch_size_info. - pop_config(cfg, 'data_local', must_exist=False) - pop_config(cfg, 'data_remote', must_exist=False) - pop_config(cfg, 'global_seed', must_exist=False) - pop_config(cfg, 'global_train_batch_size', must_exist=False) - pop_config(cfg, 'n_gpus', must_exist=False) - pop_config(cfg, 'device_train_grad_accum', must_exist=False) - - # Warn users for unused parameters - for key in cfg: - warnings.warn( - f'Unused parameter {key} found in cfg. Please check your yaml to ensure this parameter is necessary.', - ) - # Warn if fsdp is enabled but user only has 1 GPU if dist.get_world_size() == 1 and fsdp_config is not None: warnings.warn( @@ -528,19 +305,19 @@ def main(cfg: DictConfig) -> Trainer: # Build tokenizer log.info('Building tokenizer...') - tokenizer_name = tokenizer_config['name'] - tokenizer_kwargs = tokenizer_config.get('kwargs', {}) + tokenizer_name = train_cfg.tokenizer['name'] + tokenizer_kwargs = train_cfg.tokenizer.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) # Scheduler - scheduler_name: str = scheduler_config.pop('name') - scheduler = build_scheduler(scheduler_name, scheduler_config) + scheduler_name: str = train_cfg.scheduler.pop('name') + scheduler = build_scheduler(scheduler_name, train_cfg.scheduler) # Loggers loggers = [ build_logger(str(name), logger_cfg) - for name, logger_cfg in logger_configs.items() - ] if logger_configs else [] + for name, logger_cfg in train_cfg.loggers.items() + ] if train_cfg.loggers else [] mosaicml_logger = find_mosaicml_logger(loggers) if mosaicml_logger is None: @@ -549,29 +326,22 @@ def main(cfg: DictConfig) -> Trainer: # mosaicml_logger will be None if run isn't on MosaicML platform loggers.append(mosaicml_logger) - if metadata is not None: + if train_cfg.metadata is not None: # Flatten the metadata for logging logged_cfg.pop('metadata', None) - logged_cfg.update(metadata, merge=True) + logged_cfg.update(train_cfg.metadata, merge=True) if mosaicml_logger is not None: - mosaicml_logger.log_metrics(metadata) + mosaicml_logger.log_metrics(train_cfg.metadata) mosaicml_logger._flush_metadata(force_flush=True) # Profiling profiler: Optional[Profiler] = None - profiler_cfg: Optional[DictConfig] = pop_config( - cfg, - 'profiler', - must_exist=False, - convert=False, - default_value=None, - ) + profiler_cfg = train_cfg.profiler if profiler_cfg: profiler_schedule_cfg: Dict = pop_config( profiler_cfg, 'schedule', must_exist=True, - convert=True, ) profiler_schedule = cyclic_schedule(**profiler_schedule_cfg) # Only support json trace handler @@ -581,7 +351,6 @@ def main(cfg: DictConfig) -> Trainer: 'json_trace_handler', must_exist=False, default_value=None, - convert=True, ) if profiler_trace_cfg: profiler_trace_handlers.append( @@ -593,22 +362,26 @@ def main(cfg: DictConfig) -> Trainer: schedule=profiler_schedule, ) + callback_configs = train_cfg.callbacks or {} + # Callbacks callbacks: List[Callback] = [ 
build_callback( name=str(name), kwargs=callback_cfg, - train_config=om.to_container(logged_cfg), + train_config=logged_cfg, ) for name, callback_cfg in callback_configs.items() - ] if callback_configs else [] + ] use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks) + algorithm_configs = train_cfg.algorithms or {} + # Algorithms algorithms = [ build_algorithm(str(name), algorithm_cfg) for name, algorithm_cfg in algorithm_configs.items() - ] if algorithm_configs else None + ] # Dataloaders log.info('Building train loader...') @@ -616,7 +389,7 @@ def main(cfg: DictConfig) -> Trainer: train_loader = build_dataloader( train_loader_config, tokenizer, - device_train_batch_size, + train_cfg.device_train_batch_size, ) except Exception as e: if mosaicml_logger is not None: @@ -629,23 +402,23 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation if use_async_eval: evaluators = [] - if eval_first: + if train_cfg.eval_first: warnings.warn( 'AsyncEval callback does not support eval_first=True. Ignoring.', ) - eval_first = False + train_cfg.eval_first = False else: log.info('Building eval loader...') - eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + eval_icl_seq_len: int = train_cfg.icl_seq_len if train_cfg.icl_seq_len else train_cfg.max_seq_len evaluators, _, eval_gauntlet_callback = build_evaluators( eval_loader_config, icl_tasks_config, eval_gauntlet_config, tokenizer=tokenizer, - device_eval_batch_size=device_eval_batch_size, + device_eval_batch_size=train_cfg.device_eval_batch_size, icl_seq_len=eval_icl_seq_len, - icl_subset_num_batches=icl_subset_num_batches, + icl_subset_num_batches=train_cfg.icl_subset_num_batches, ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) @@ -656,44 +429,31 @@ def main(cfg: DictConfig) -> Trainer: model_config, train_loader_config, eval_loader_config, - callback_configs, + train_cfg.callbacks, tokenizer_name, - load_path, + train_cfg.load_path, icl_tasks_config, eval_gauntlet_config, ) # Build Model log.info('Initializing model...') + name = model_config.pop('name') + assert isinstance(name, str) + assert isinstance(model_config, dict) model = build_composer_model( - name=model_config.name, - cfg=model_config, + name=name, tokenizer=tokenizer, init_context=init_context, master_weights_dtype=model_config.get('master_weights_dtype', None), + cfg=model_config, ) - # Log number of parameters - if hasattr(model, 'n_total_params'): - n_params = model.n_total_params - n_trainable_params = n_params # TODO: we currently assume all parameters are trainable. 
- else: - n_params = sum(p.numel() for p in model.parameters()) - n_trainable_params = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - if hasattr(model, 'n_active_params'): - n_active_params = model.n_active_params - else: - n_active_params = n_params - logged_cfg.update({ - 'n_params': n_params, - 'n_active_params': n_active_params, - 'n_trainable_params': n_trainable_params, - }) + _log_num_params(model, logged_cfg) # Optimizer - optimizer_name: str = optimizer_config.pop('name') - optimizer = build_optimizer(model, optimizer_name, optimizer_config) + optimizer_name: str = train_cfg.optimizer.pop('name') + optimizer_cfg = train_cfg.optimizer + optimizer = build_optimizer(model, optimizer_name, optimizer_cfg) # Now add the eval metrics try: @@ -712,6 +472,7 @@ def main(cfg: DictConfig) -> Trainer: mosaicml_logger.log_exception(e) raise e + compile_config = train_cfg.compile_config # Build the Trainer log.info('Building trainer...') trainer = Trainer( @@ -722,45 +483,45 @@ def main(cfg: DictConfig) -> Trainer: eval_dataloader=evaluators, optimizers=optimizer, schedulers=scheduler, - max_duration=max_duration, - eval_interval=eval_interval, - eval_subset_num_batches=eval_subset_num_batches, - progress_bar=progress_bar, - log_to_console=log_to_console, - console_log_interval=console_log_interval, + max_duration=train_cfg.max_duration, + eval_interval=train_cfg.eval_interval, + eval_subset_num_batches=train_cfg.eval_subset_num_batches, + progress_bar=train_cfg.progress_bar, + log_to_console=train_cfg.log_to_console, + console_log_interval=train_cfg.console_log_interval, loggers=loggers, callbacks=callbacks, - precision=precision, + precision=train_cfg.precision, algorithms=algorithms, - device_train_microbatch_size=device_train_microbatch_size, + device_train_microbatch_size=train_cfg.device_train_microbatch_size, fsdp_config=fsdp_config, - save_folder=save_folder, + save_folder=train_cfg.save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, - save_interval=save_interval, - save_num_checkpoints_to_keep=save_num_checkpoints_to_keep, - save_overwrite=save_overwrite, - save_weights_only=save_weights_only, - load_path=load_path, - load_weights_only=load_weights_only, - load_strict_model_weights=load_strict_model_weights, - load_ignore_keys=load_ignore_keys, - save_ignore_keys=save_ignore_keys, - autoresume=autoresume, - python_log_level=python_log_level, - dist_timeout=dist_timeout, + save_interval=train_cfg.save_interval, + save_num_checkpoints_to_keep=train_cfg.save_num_checkpoints_to_keep, + save_overwrite=train_cfg.save_overwrite, + save_weights_only=train_cfg.save_weights_only, + load_path=train_cfg.load_path, + load_weights_only=train_cfg.load_weights_only, + load_strict_model_weights=train_cfg.load_strict_model_weights, + load_ignore_keys=train_cfg.load_ignore_keys, + save_ignore_keys=train_cfg.save_ignore_keys, + autoresume=train_cfg.autoresume, + python_log_level=train_cfg.python_log_level, + dist_timeout=train_cfg.dist_timeout, profiler=profiler, compile_config=compile_config, ) - if should_log_config: + if train_cfg.log_config: log.info('Logging config') log_config(logged_cfg) torch.cuda.empty_cache() gc.collect() # Eval first if requested - if eval_first and trainer.state.timestamp.batch.value == 0: + if train_cfg.eval_first and trainer.state.timestamp.batch.value == 0: trainer.eval() log.info('Starting training...') diff --git a/scripts/train/yamls/finetune/1b_local_data_sft.yaml 
b/scripts/train/yamls/finetune/1b_local_data_sft.yaml index 46141ce5ab..c977bd2945 100644 --- a/scripts/train/yamls/finetune/1b_local_data_sft.yaml +++ b/scripts/train/yamls/finetune/1b_local_data_sft.yaml @@ -4,8 +4,11 @@ # This is not the right YAML if you are trying to finetune a HuggingFace pretrained model. # ############################################################################################ -max_seq_len: 2048 -global_seed: 17 +variables: + global_seed: 17 + max_seq_len: 2048 + +max_seq_len: ${variables.max_seq_len} # Run Name run_name: # If left blank, will be read from env var $RUN_NAME @@ -19,7 +22,7 @@ model: n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention n_layers: 24 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -30,7 +33,7 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Local data to load into huggingface datasets dataset: &hf_dataset @@ -45,7 +48,7 @@ train_loader: &train_loader dataset: <<: *hf_dataset split: train - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true shuffle: true @@ -97,7 +100,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 128 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/finetune/7b_dolly_sft.yaml b/scripts/train/yamls/finetune/7b_dolly_sft.yaml index 024362299a..f9edba3716 100644 --- a/scripts/train/yamls/finetune/7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/7b_dolly_sft.yaml @@ -4,11 +4,15 @@ # This is not the right YAML if you are trying to finetune a HuggingFace pretrained model. 
# ############################################################################################ -max_seq_len: 2048 -global_seed: 17 +variables: + global_seed: 17 + max_seq_len: 2048 + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} # Run Name -run_name: # If left blank, will be read from env var $RUN_NAME +run_name: ${variables.run_name} # Model model: @@ -18,7 +22,7 @@ model: n_heads: 32 n_layers: 32 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -29,7 +33,7 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: @@ -37,7 +41,7 @@ train_loader: dataset: hf_name: HuggingFaceH4/databricks_dolly_15k split: train - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true shuffle: true @@ -85,7 +89,7 @@ eval_interval: 1 # this is the only allowed value for no eval global_train_batch_size: 64 # assuming 8 gpus # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml index 9cb53e40fd..24fd4cb126 100644 --- a/scripts/train/yamls/finetune/dbrx-full-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml @@ -1,9 +1,14 @@ -# Note: This requires ~64x80GB GPUs -max_seq_len: 4096 -icl_seq_len: 1024 +variables: + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Note: This requires ~64x80GB GPUs + max_seq_len: 4096 + icl_seq_len: 1024 + +run_name: ${variables.run_name} +max_seq_len: ${variables.max_seq_len} +icl_seq_len: ${variables.icl_seq_len} # Model model: @@ -19,7 +24,7 @@ model: tokenizer: name: databricks/dbrx-instruct kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} trust_remote_code: true # Dataloaders @@ -29,7 +34,7 @@ train_loader: split: train hf_name: mosaicml/dolly_hhrlhf shuffle: true - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} eos_token_id: 0 packing_ratio: auto allow_pad_trimming: false @@ -46,7 +51,7 @@ eval_loader: split: test hf_name: mosaicml/dolly_hhrlhf shuffle: false - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} packing_ratio: null allow_pad_trimming: false decoder_only_format: true diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml index 06e8f1d6f0..87950134bd 100644 --- a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml +++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml @@ -1,9 +1,15 @@ -# Note: This requires ~16x80GB GPUs -max_seq_len: 4096 -icl_seq_len: 1024 +variables: + # Note: This requires ~16x80GB GPUs + icl_seq_len: 1024 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 4096 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +icl_seq_len: ${variables.icl_seq_len} +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -27,7 +33,7 @@ model: tokenizer: name: databricks/dbrx-instruct kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} trust_remote_code: true 
# Dataloaders @@ -37,7 +43,7 @@ train_loader: split: train hf_name: mosaicml/dolly_hhrlhf shuffle: true - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} eos_token_id: 0 packing_ratio: auto allow_pad_trimming: false @@ -54,7 +60,7 @@ eval_loader: split: test hf_name: mosaicml/dolly_hhrlhf shuffle: false - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} packing_ratio: null allow_pad_trimming: false decoder_only_format: true diff --git a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml index 4047256614..95a70acfd7 100644 --- a/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml +++ b/scripts/train/yamls/finetune/gpt2-arc-easy-cpu-streaming-dataset.yaml @@ -1,11 +1,16 @@ -max_seq_len: 512 -global_seed: 17 +variables: + global_seed: 17 -data_local: ./my_data -data_remote: # If blank, files must be present in data_local + max_seq_len: 512 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + data_local: ./my_data + data_remote: # If blank, files must be present in data_local + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -17,7 +22,7 @@ model: tokenizer: name: gpt2 kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: @@ -26,12 +31,12 @@ train_loader: ############ streams: my_data: - remote: ${data_remote} - local: ${data_local} + remote: ${variables.data_remote} + local: ${variables.data_local} split: train ############ shuffle: true - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} decoder_only_format: true drop_last: true num_workers: 8 @@ -63,7 +68,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 8 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/finetune/mpt-30b-instruct.yaml b/scripts/train/yamls/finetune/mpt-30b-instruct.yaml index 226f96230e..3ef41e0aa1 100644 --- a/scripts/train/yamls/finetune/mpt-30b-instruct.yaml +++ b/scripts/train/yamls/finetune/mpt-30b-instruct.yaml @@ -1,9 +1,16 @@ -tokenizer_name: mosaicml/mpt-30b -max_seq_len: 8192 -global_seed: 17 +variables: + tokenizer_name: mosaicml/mpt-30b + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $COMPOSER_RUN_NAME + max_seq_len: 8192 + + # Run Name + run_name: # If left blank, will be read from env var $COMPOSER_RUN_NAME + + icl_max_seq_len: 2048 + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -12,7 +19,7 @@ model: pretrained_model_name_or_path: mosaicml/mpt-30b init_device: mixed config_overrides: - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} attn_config: attn_impl: flash # Note: we still use packing, but turn this off for memory. 
@@ -21,9 +28,9 @@ model: # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: @@ -31,7 +38,7 @@ train_loader: dataset: hf_name: mosaicml/instruct-v3 split: train - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true packing_ratio: 9 @@ -48,7 +55,7 @@ eval_loader: dataset: hf_name: mosaicml/instruct-v3 split: test - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true packing_ratio: 9 @@ -87,7 +94,7 @@ eval_first: true global_train_batch_size: 72 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 1 precision: amp_bf16 @@ -120,7 +127,7 @@ callbacks: # save_interval: 3ep # save_num_checkpoints_to_keep: 1 -icl_max_seq_len: 2048 +icl_max_seq_len: ${variables.icl_max_seq_len} # YOU MUST ADD YOUR OWN DATASET URIs # this section can be removed if you do not want to track these metrics @@ -131,7 +138,7 @@ icl_tasks: num_fewshot: - 0 batch_size: 4 - max_seq_len: ${icl_max_seq_len} + max_seq_len: ${variables.icl_max_seq_len} icl_task_type: multiple_choice metric_names: - InContextLearningMultipleChoiceAccuracy @@ -144,7 +151,7 @@ icl_tasks: num_fewshot: - 0 batch_size: 4 - max_seq_len: ${icl_max_seq_len} + max_seq_len: ${variables.icl_max_seq_len} icl_task_type: multiple_choice metric_names: - InContextLearningMultipleChoiceAccuracy @@ -157,7 +164,7 @@ icl_tasks: num_fewshot: - 0 batch_size: 4 - max_seq_len: ${icl_max_seq_len} + max_seq_len: ${variables.icl_max_seq_len} icl_task_type: multiple_choice metric_names: - InContextLearningMultipleChoiceAccuracy diff --git a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml index a789c4b491..3634fd259e 100644 --- a/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml @@ -1,15 +1,20 @@ -max_seq_len: 2048 -global_seed: 17 +variables: + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 2048 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} model: name: hf_causal_lm pretrained: true pretrained_model_name_or_path: mosaicml/mpt-7b config_overrides: - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} attn_config: attn_impl: flash # Set this to `true` if using `train_loader.dataset.packing_ratio` below @@ -19,7 +24,7 @@ model: tokenizer: name: mosaicml/mpt-7b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders @@ -28,7 +33,7 @@ train_loader: dataset: hf_name: mosaicml/dolly_hhrlhf split: train - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with @@ -51,7 +56,7 @@ eval_loader: dataset: hf_name: mosaicml/dolly_hhrlhf split: test - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: true # packing_ratio: @@ -91,7 +96,7 @@ eval_first: true global_train_batch_size: 48 # somewhere in the 6-8 * numgpus range seems good # System -seed: ${global_seed} +seed: 
${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/finetune/mpt-7b_domain_adapt.yaml b/scripts/train/yamls/finetune/mpt-7b_domain_adapt.yaml index 49a70e97f2..9357ef7771 100644 --- a/scripts/train/yamls/finetune/mpt-7b_domain_adapt.yaml +++ b/scripts/train/yamls/finetune/mpt-7b_domain_adapt.yaml @@ -1,7 +1,12 @@ -data_local: ./my-adaptation-data -data_remote: # If blank, files must be present in data_local -max_seq_len: 4096 -global_seed: 17 +variables: + global_seed: 17 + + data_local: ./my-adaptation-data + data_remote: # If blank, files must be present in data_local + + max_seq_len: 4096 + +max_seq_len: ${variables.max_seq_len} # Run Name run_name: # If left blank, will be read from env var $COMPOSER_RUN_NAME @@ -12,7 +17,7 @@ model: pretrained: true pretrained_model_name_or_path: mosaicml/mpt-7b config_overrides: - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} attn_config: attn_impl: flash attn_uses_sequence_id: false @@ -21,31 +26,31 @@ model: tokenizer: name: mosaicml/mpt-7b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train_small shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val_small shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -76,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 1024 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/finetune/t5-small_dolly_sft.yaml b/scripts/train/yamls/finetune/t5-small_dolly_sft.yaml index 2264e359e0..d394018cfc 100644 --- a/scripts/train/yamls/finetune/t5-small_dolly_sft.yaml +++ b/scripts/train/yamls/finetune/t5-small_dolly_sft.yaml @@ -1,19 +1,24 @@ -max_seq_len: 1024 -global_seed: 17 -model_name: t5-small +variables: + global_seed: 17 + model_name: t5-small -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 1024 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: name: hf_t5 - pretrained_model_name_or_path: ${model_name} + pretrained_model_name_or_path: ${variables.model_name} pretrained: true # Tokenizer tokenizer: - name: ${model_name} + name: ${variables.model_name} # Dataloaders train_loader: @@ -21,7 +26,7 @@ train_loader: dataset: hf_name: HuggingFaceH4/databricks_dolly_15k split: train - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} allow_pad_trimming: false decoder_only_format: false shuffle: true @@ -62,7 +67,7 @@ eval_interval: 1 # this is the only allowed value for no eval global_train_batch_size: 64 # assuming 8 gpus # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 
8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml index 2791acc935..5f02ba47e6 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml @@ -1,14 +1,19 @@ # Pretrain a gpt-neo-125m style model # this is NOT a finetuning run -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -tokenizer_name: gpt2 -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + tokenizer_name: gpt2 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 2048 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -17,37 +22,37 @@ model: config_overrides: # WARNING: if setting `pretrained: true`, `max_position_embeddings` must match the # `max_position_embeddings` used during pre-training - max_position_embeddings: ${max_seq_len} + max_position_embeddings: ${variables.max_seq_len} pretrained: false # false: only use the architecture; true: initialize with pretrained weights # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -78,7 +83,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml index b2d71ad762..fe9828b50a 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml @@ -1,14 +1,19 @@ # Pretrain a gpt-neo-125m style model # this is NOT a finetuning run -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -tokenizer_name: EleutherAI/gpt-neo-125M -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + tokenizer_name: EleutherAI/gpt-neo-125M + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 2048 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -17,37 +22,37 @@ model: config_overrides: # WARNING: if setting `pretrained: true`, `max_position_embeddings` must match the # `max_position_embeddings` used during pre-training - 
max_position_embeddings: ${max_seq_len} + max_position_embeddings: ${variables.max_seq_len} pretrained: false # false: only use the architecture; true: initialize with pretrained weights # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -78,7 +83,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/gpt2-small.yaml b/scripts/train/yamls/pretrain/gpt2-small.yaml index 52d0f8cb73..458f6869da 100644 --- a/scripts/train/yamls/pretrain/gpt2-small.yaml +++ b/scripts/train/yamls/pretrain/gpt2-small.yaml @@ -1,14 +1,19 @@ # Pretrain a gpt2 style model # this is NOT a finetuning run -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -tokenizer_name: gpt2 -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + tokenizer_name: gpt2 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + max_seq_len: 2048 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -17,37 +22,37 @@ model: config_overrides: # WARNING: if setting `pretrained: true`, `max_position_embeddings` must match the # `max_position_embeddings` used during pre-training - n_positions: ${max_seq_len} + n_positions: ${variables.max_seq_len} pretrained: false # false: only use the architecture; true: initialize with pretrained weights # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -78,7 +83,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} 
device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-125m.yaml b/scripts/train/yamls/pretrain/mpt-125m.yaml index 78dc789e7d..644dfc26c1 100644 --- a/scripts/train/yamls/pretrain/mpt-125m.yaml +++ b/scripts/train/yamls/pretrain/mpt-125m.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 12 n_layers: 12 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 16 device_train_microbatch_size: 16 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-13b.yaml b/scripts/train/yamls/pretrain/mpt-13b.yaml index 782c01f1f0..41002bb45d 100644 --- a/scripts/train/yamls/pretrain/mpt-13b.yaml +++ b/scripts/train/yamls/pretrain/mpt-13b.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 40 n_layers: 40 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: 
${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 1024 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-1b.yaml b/scripts/train/yamls/pretrain/mpt-1b.yaml index effa60c59e..93b2a58a42 100644 --- a/scripts/train/yamls/pretrain/mpt-1b.yaml +++ b/scripts/train/yamls/pretrain/mpt-1b.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 16 # Modified 24->16 so that d_head == 128 to satisfy FlashAttention n_layers: 24 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 512 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-30b.yaml b/scripts/train/yamls/pretrain/mpt-30b.yaml index 6b82407c63..3627c36dd0 100644 --- a/scripts/train/yamls/pretrain/mpt-30b.yaml +++ b/scripts/train/yamls/pretrain/mpt-30b.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 56 n_layers: 48 
expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 2048 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-350m.yaml b/scripts/train/yamls/pretrain/mpt-350m.yaml index 63bc6169a1..ebe8da715f 100644 --- a/scripts/train/yamls/pretrain/mpt-350m.yaml +++ b/scripts/train/yamls/pretrain/mpt-350m.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 16 n_layers: 24 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-3b.yaml b/scripts/train/yamls/pretrain/mpt-3b.yaml index 74d422398d..615f59ee3f 100644 --- a/scripts/train/yamls/pretrain/mpt-3b.yaml +++ b/scripts/train/yamls/pretrain/mpt-3b.yaml @@ -1,10 +1,14 @@ -data_local: 
./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 20 # Modified 32->20 so that d_head == 128 to statisfy FlashAttention n_layers: 32 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 512 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-70b.yaml b/scripts/train/yamls/pretrain/mpt-70b.yaml index 8e6856ceb8..55450a8bfc 100644 --- a/scripts/train/yamls/pretrain/mpt-70b.yaml +++ b/scripts/train/yamls/pretrain/mpt-70b.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 64 n_layers: 80 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + 
max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 2048 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-760m.yaml b/scripts/train/yamls/pretrain/mpt-760m.yaml index f11f199036..5c1f0bdbdc 100644 --- a/scripts/train/yamls/pretrain/mpt-760m.yaml +++ b/scripts/train/yamls/pretrain/mpt-760m.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 12 # Modified 16->12 so that d_head == 128 to statisfy FlashAttention n_layers: 24 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-7b.yaml b/scripts/train/yamls/pretrain/mpt-7b.yaml index 831383168f..b97f3f2c9e 100644 --- a/scripts/train/yamls/pretrain/mpt-7b.yaml +++ b/scripts/train/yamls/pretrain/mpt-7b.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 2048 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 2048 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 32 n_layers: 32 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: flash @@ -23,30 +27,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: 
${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -77,7 +81,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 1024 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 8 device_train_microbatch_size: 8 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/mpt-small-cpu.yaml b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml index 1f50c68a74..b579723002 100644 --- a/scripts/train/yamls/pretrain/mpt-small-cpu.yaml +++ b/scripts/train/yamls/pretrain/mpt-small-cpu.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 128 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 128 + global_seed: 17 -# Run Name -run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: mpt_causal_lm_cpu # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 4 n_layers: 4 expansion_ratio: 5 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: torch @@ -24,30 +28,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 2 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 2 @@ -79,7 +83,7 @@ global_train_batch_size: 256 autoresume: false # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 16 device_train_microbatch_size: 16 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/opt-3b.yaml b/scripts/train/yamls/pretrain/opt-3b.yaml index 65b73257c2..31b7bf255b 100644 --- a/scripts/train/yamls/pretrain/opt-3b.yaml +++ b/scripts/train/yamls/pretrain/opt-3b.yaml @@ -1,11 +1,15 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -tokenizer_name: facebook/opt-2.7b -max_seq_len: 256 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be 
present in data_local + tokenizer_name: facebook/opt-2.7b + max_seq_len: 256 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -15,32 +19,32 @@ model: # Tokenizer tokenizer: - name: ${tokenizer_name} + name: ${variables.tokenizer_name} kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -71,7 +75,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 4 device_train_microbatch_size: 4 # device_train_microbatch_size: auto diff --git a/scripts/train/yamls/pretrain/testing-moe.yaml b/scripts/train/yamls/pretrain/testing-moe.yaml index eea2b999b7..e61e3e451e 100644 --- a/scripts/train/yamls/pretrain/testing-moe.yaml +++ b/scripts/train/yamls/pretrain/testing-moe.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 128 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 128 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -24,7 +28,7 @@ model: n_heads: 2 n_layers: 2 expansion_ratio: 1 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: torch @@ -34,30 +38,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -88,7 +92,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 16 device_train_microbatch_size: 16 # device_train_microbatch_size: auto diff --git 
a/scripts/train/yamls/pretrain/testing.yaml b/scripts/train/yamls/pretrain/testing.yaml index 01ebecafe2..2271be5d6d 100644 --- a/scripts/train/yamls/pretrain/testing.yaml +++ b/scripts/train/yamls/pretrain/testing.yaml @@ -1,10 +1,14 @@ -data_local: ./my-copy-c4 -data_remote: # If blank, files must be present in data_local -max_seq_len: 128 -global_seed: 17 +variables: + data_local: ./my-copy-c4 + data_remote: # If blank, files must be present in data_local + max_seq_len: 128 + global_seed: 17 -# Run Name -run_name: # If left blank, will be read from env var $RUN_NAME + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + +max_seq_len: ${variables.max_seq_len} +run_name: ${variables.run_name} # Model model: @@ -14,7 +18,7 @@ model: n_heads: 2 n_layers: 2 expansion_ratio: 4 - max_seq_len: ${max_seq_len} + max_seq_len: ${variables.max_seq_len} vocab_size: 50368 attn_config: attn_impl: torch @@ -24,30 +28,30 @@ model: tokenizer: name: EleutherAI/gpt-neox-20b kwargs: - model_max_length: ${max_seq_len} + model_max_length: ${variables.max_seq_len} # Dataloaders train_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: train shuffle: true - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: true num_workers: 8 eval_loader: name: text dataset: - local: ${data_local} - remote: ${data_remote} + local: ${variables.data_local} + remote: ${variables.data_remote} split: val shuffle: false - max_seq_len: ${max_seq_len} - shuffle_seed: ${global_seed} + max_seq_len: ${variables.max_seq_len} + shuffle_seed: ${variables.global_seed} drop_last: false num_workers: 8 @@ -78,7 +82,7 @@ eval_subset_num_batches: -1 global_train_batch_size: 256 # System -seed: ${global_seed} +seed: ${variables.global_seed} device_eval_batch_size: 16 device_train_microbatch_size: 16 # device_train_microbatch_size: auto diff --git a/tests/a_scripts/eval/test_eval.py b/tests/a_scripts/eval/test_eval.py index 2c7f81a8b2..a56778538c 100644 --- a/tests/a_scripts/eval/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -13,12 +13,9 @@ from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils.config_utils import to_dict_container from scripts.eval.eval import main # noqa: E402 -from tests.data_utils import ( - create_arxiv_dataset, - create_c4_dataset_xxsmall, - gpt_tiny_cfg, -) +from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg @pytest.fixture(autouse=True) @@ -42,6 +39,7 @@ def eval_cfg(foundry_dir: str) -> Union[om.ListConfig, om.DictConfig]: @pytest.fixture() def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): + eval_cfg = copy.deepcopy(eval_cfg) # copy config before modifying model_cfg = eval_cfg.models[0] # set device to cpu device = 'cpu' @@ -52,10 +50,11 @@ def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): model_cfg.tokenizer.get('kwargs', {}), ) # build model + name = model_cfg.model.pop('name') model = build_composer_model( - name=model_cfg.model.name, - cfg=model_cfg.model, + name=name, tokenizer=tokenizer, + cfg=to_dict_container(model_cfg.model), ) # create mocked save checkpoint @@ -73,6 +72,7 @@ def test_icl_eval( capfd: Any, mock_saved_model_path: Any, ): + eval_cfg = copy.deepcopy(eval_cfg) eval_cfg.models[0].load_path = mock_saved_model_path assert 
isinstance(eval_cfg, om.DictConfig) main(eval_cfg) @@ -124,8 +124,6 @@ def test_loader_eval( first_eval_loader.label = 'c4' # Create second eval dataloader using the arxiv dataset. second_eval_loader = copy.deepcopy(first_eval_loader) - arxiv_dataset_name = create_arxiv_dataset(tmp_path) - second_eval_loader.data_local = arxiv_dataset_name second_eval_loader.label = 'arxiv' test_cfg.eval_loader = om.OmegaConf.create([ first_eval_loader, diff --git a/tests/a_scripts/eval/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py index 030fc434bf..98b15743b3 100644 --- a/tests/a_scripts/eval/test_eval_inputs.py +++ b/tests/a_scripts/eval/test_eval_inputs.py @@ -40,6 +40,8 @@ def test_mispelled_mandatory_params_fail(self, cfg: DictConfig) -> None: with pytest.raises(( omegaconf.errors.ConfigKeyError, omegaconf.errors.InterpolationKeyError, + omegaconf.errors.MissingMandatoryValue, + TypeError, )): cfg[p + '-mispelled'] = cfg.pop(p) main(cfg) diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index 2d4a79b9b7..9274976796 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -31,7 +31,7 @@ build_optimizer, build_tokenizer, ) -from llmfoundry.utils.config_utils import process_init_device +from llmfoundry.utils.config_utils import process_init_device, to_dict_container from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset @@ -829,7 +829,6 @@ def test_huggingface_conversion_callback( ) assert model_cfg is not None assert tokenizer_name is not None - model_cfg = om.create(model_cfg) if peft_config is not None: model_cfg['peft_config'] = peft_config @@ -851,15 +850,16 @@ def test_huggingface_conversion_callback( ) train_dataloader = build_finetuning_dataloader( - dataloader_cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **dataloader_cfg, ) + name = model_cfg.pop('name') original_model = build_composer_model( - model_cfg['name'], - model_cfg, - tokenizer, + name, + tokenizer=tokenizer, + cfg=model_cfg, ) optimizer_name = optimizer_config.pop('name') optimizer = build_optimizer( @@ -973,10 +973,11 @@ def test_convert_and_generate( om_cfg.tokenizer.name, use_auth_token=model == 'llama2', ) + name = om_cfg.model.pop('name') original_model = build_composer_model( - name=om_cfg['model'].name, - cfg=om_cfg['model'], + name=name, tokenizer=tokenizer, + cfg=to_dict_container(om_cfg['model']), ) trainer = Trainer( model=original_model, @@ -1067,10 +1068,11 @@ def test_convert_and_generate_meta( tokenizer = transformers.AutoTokenizer.from_pretrained( om_cfg.tokenizer.name, ) + name = om_cfg.model.pop('name') original_model = build_composer_model( - name=om_cfg['model'].name, - cfg=om_cfg['model'], + name=name, tokenizer=tokenizer, + cfg=to_dict_container(om_cfg['model']), ) trainer = Trainer( model=original_model, @@ -1226,7 +1228,6 @@ def test_mptmoe_huggingface_conversion_callback( tokenizer_name = 'EleutherAI/gpt-neox-20b' assert model_cfg is not None assert tokenizer_name is not None - model_cfg = om.create(model_cfg) fsdp_config = { 'sharding_strategy': sharding_strategy, @@ -1273,9 +1274,9 @@ def test_mptmoe_huggingface_conversion_callback( ) train_dataloader = build_finetuning_dataloader( - dataloader_cfg, - tokenizer, - device_batch_size, + **dataloader_cfg, + tokenizer=tokenizer, + 
device_batch_size=device_batch_size, ) optimizer_config = { @@ -1288,11 +1289,12 @@ def test_mptmoe_huggingface_conversion_callback( optimizer_name = optimizer_config.pop('name') init_context = process_init_device(model_cfg, fsdp_config) + name = model_cfg.pop('name') original_model = build_composer_model( - name=model_cfg.name, - cfg=model_cfg, + name=name, tokenizer=tokenizer, init_context=init_context, + cfg=model_cfg, ) optimizer = build_optimizer( diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index f721e0499d..2be1d5139d 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,12 +11,13 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from scripts.train.train import main, validate_config # noqa: E402 -from tests.data_utils import ( - create_arxiv_dataset, - create_c4_dataset_xxsmall, - gpt_tiny_cfg, +from llmfoundry.utils.config_utils import ( + make_dataclass_and_log_config, + update_batch_size_info, ) +from scripts.train.train import TrainConfig # noqa: E402 +from scripts.train.train import TRAIN_CONFIG_KEYS, main, validate_config +from tests.data_utils import create_c4_dataset_xxsmall, gpt_tiny_cfg from tests.fixtures.autouse import REPO_DIR @@ -118,8 +119,6 @@ def test_train_multi_eval(tmp_path: pathlib.Path): first_eval_loader.label = 'c4' # Create second eval dataloader using the arxiv dataset. second_eval_loader = copy.deepcopy(first_eval_loader) - arxiv_dataset_name = create_arxiv_dataset(tmp_path) - second_eval_loader.data_local = arxiv_dataset_name second_eval_loader.label = 'arxiv' test_cfg.eval_loader = om.create([first_eval_loader, second_eval_loader]) test_cfg.eval_subset_num_batches = 1 # -1 to evaluate on all batches @@ -181,7 +180,13 @@ def test_validate_config(): match= 'MoEs with expert parallelism (.*) require `use_orig_params=True`.', ): - validate_config(test_cfg) + _, cfg_obj = make_dataclass_and_log_config( + test_cfg, + TrainConfig, + TRAIN_CONFIG_KEYS, + transforms=[update_batch_size_info], + ) + validate_config(cfg_obj) def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path): diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index ad36630def..5a3b21dc3b 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -63,7 +63,7 @@ def cfg(self, foundry_dir: str) -> DictConfig: def test_misspelled_mandatory_params_fail(self, cfg: DictConfig) -> None: """Check that mandatory misspelled inputs fail to train.""" cfg.trai_loader = cfg.pop('train_loader') - with pytest.raises(omegaconf.errors.ConfigAttributeError): + with pytest.raises((omegaconf.errors.MissingMandatoryValue, TypeError)): main(cfg) def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: @@ -76,14 +76,16 @@ def test_missing_mandatory_parameters_fail(self, cfg: DictConfig) -> None: 'scheduler', 'max_duration', 'eval_interval', - 'precision', 'max_seq_len', ] for param in mandatory_params: orig_param = cfg.pop(param) - with pytest.raises( - (omegaconf.errors.ConfigAttributeError, NameError), - ): + with pytest.raises(( + TypeError, + NameError, + omegaconf.errors.InterpolationKeyError, + omegaconf.errors.MissingMandatoryValue, + )): main(cfg) cfg[param] = orig_param diff --git a/tests/callbacks/test_eval_gauntlet_callback.py b/tests/callbacks/test_eval_gauntlet_callback.py index 8d4df43d63..9c80127af5 100644 --- 
a/tests/callbacks/test_eval_gauntlet_callback.py +++ b/tests/callbacks/test_eval_gauntlet_callback.py @@ -13,6 +13,7 @@ from llmfoundry.eval.metrics.nlp import InContextLearningLMAccuracy from llmfoundry.utils.builders import build_icl_data_and_gauntlet +from llmfoundry.utils.config_utils import to_dict_container @pytest.fixture(autouse=True) @@ -73,8 +74,9 @@ def test_gauntlet_callback(averages: Optional[dict]): icl_task_type: language_modeling """, ) - assert isinstance(icl_task_config, - om.ListConfig) or isinstance(icl_task_config, str) + icl_task_config_list: List[om.DictConfig + ] = list(icl_task_config) # type: ignore + assert all(isinstance(c, om.DictConfig) for c in icl_task_config_list) eval_gauntlet_config = om.OmegaConf.create( """ @@ -94,22 +96,16 @@ def test_gauntlet_callback(averages: Optional[dict]): random_baseline: 0.0 """, ) - assert isinstance(eval_gauntlet_config, - om.DictConfig) or isinstance(eval_gauntlet_config, str) + assert isinstance(eval_gauntlet_config, om.DictConfig) if averages is not None: eval_gauntlet_config.averages = averages tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') # test loading functionality - _, _, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_task_config, - eval_gauntlet_config, - tokenizer, - 4, - 1024, - 1, - ) + _, _, eval_gauntlet_callback = build_icl_data_and_gauntlet([ + to_dict_container(c) for c in icl_task_config_list + ], to_dict_container(eval_gauntlet_config), tokenizer, 4, 1024, 1) assert eval_gauntlet_callback is not None state = MockState(eval_gauntlet_callback.logger_keys) logger = MockLogger(state) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 27ac52e425..0da518d2e7 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -40,6 +40,7 @@ ) from llmfoundry.data.utils import get_tokens_per_batch_func from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.config_utils import to_dict_container # yapf: disable from llmfoundry.utils.exceptions import ( ConsecutiveRepeatedChatRolesError, @@ -245,7 +246,7 @@ def test_correct_padding( test_cfg = get_config( conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml', ) - test_cfg.data_local = data_local + test_cfg.variables.data_local = data_local test_cfg.eval_loader.dataset.split = split test_cfg.dataset = om.create({ 'num_canonical_nodes': 1, @@ -258,10 +259,13 @@ def test_correct_padding( ) # Dataloaders + test_cfg.eval_loader.pop('name') + assert isinstance(test_cfg, DictConfig) + test_cfg = to_dict_container(test_cfg) eval_loader = build_text_dataloader( - test_cfg.eval_loader, - tokenizer, - batch_size, + **test_cfg['eval_loader'], + tokenizer=tokenizer, + device_batch_size=batch_size, ).dataloader batch = next(iter(eval_loader)) @@ -347,9 +351,9 @@ def test_invalid_jsonl_data(): with pytest.raises(MisconfiguredHfDatasetError): build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + **cfg, + tokenizer=tokenizer, + device_batch_size=device_batch_size, ).dataloader @@ -410,9 +414,9 @@ def test_finetuning_dataloader( expected_keys += ['decoder_attention_mask', 'decoder_input_ids'] loader = build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, ).dataloader batch_ix = 0 for batch in loader: @@ -463,7 +467,11 @@ def test_finetuning_dataloader_safe_load( tokenizer = build_tokenizer('gpt2', {}) with expectation: - _ = build_finetuning_dataloader(cfg, tokenizer, 1) + _ = 
build_finetuning_dataloader( + tokenizer=tokenizer, + device_batch_size=1, + **cfg, + ) # If no raised errors, we should expect downloaded files with only safe file types. if expectation == does_not_raise(): @@ -532,7 +540,11 @@ def test_finetuning_dataloader_small_data( ) with error_context: - _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) + _ = build_finetuning_dataloader( + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, + ) if dist.get_global_rank() == 0: shutil.rmtree(tiny_dataset_folder_path) @@ -577,7 +589,11 @@ def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_kwargs={'model_max_length': max_seq_len}, ) - _ = build_finetuning_dataloader(cfg, tokenizer, 4) + _ = build_finetuning_dataloader( + tokenizer=tokenizer, + device_batch_size=4, + **cfg, + ) def mock_get_file(path: str, destination: str, overwrite: bool = False): @@ -625,7 +641,11 @@ def test_finetuning_dataloader_custom_split_remote(split: str): 'llmfoundry.data.finetuning.dataloader.get_file', wraps=mock_get_file, ) as f: - _ = build_finetuning_dataloader(cfg, tokenizer, 4) + _ = build_finetuning_dataloader( + tokenizer=tokenizer, + device_batch_size=4, + **cfg, + ) for call in f.call_args_list: path_arg = call.kwargs['path'] dest_arg = call.kwargs['destination'] @@ -698,7 +718,11 @@ def test_finetuning_dataloader_streaming( cfg = om.create(cfg) - dataloader = build_finetuning_dataloader(cfg, tokenizer, 2).dataloader + dataloader = build_finetuning_dataloader( + tokenizer=tokenizer, + device_batch_size=2, + **cfg, + ).dataloader expected_keys = ['input_ids', 'labels'] for batch in dataloader: @@ -909,9 +933,9 @@ def test_malformed_data( with error_context: dl = build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, ).dataloader if not any(invalid_prompt_response_params): @@ -1028,9 +1052,9 @@ def test_malformed_conversation_data( with error_context: build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, ).dataloader @@ -1083,9 +1107,9 @@ def pad_preprocessing_function( # type: ignore device_batch_size = 1 dataloader = build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, ).dataloader # We should be able to iterate through this dataset without crashing @@ -1231,7 +1255,11 @@ def test_token_counting_func_dataloader_setting( lambda *args, **kwargs: [], ) - dl = build_finetuning_dataloader(cfg, gptt, batch_size) + dl = build_finetuning_dataloader( + tokenizer=gptt, + device_batch_size=batch_size, + **cfg, + ) elif dataloader_type == 'finetuning-streaming': cfg = DictConfig({ 'name': 'finetuning', @@ -1252,9 +1280,13 @@ def test_token_counting_func_dataloader_setting( lambda *args, **kwargs: [], ) - dl = build_finetuning_dataloader(cfg, gptt, batch_size) + dl = build_finetuning_dataloader( + tokenizer=gptt, + device_batch_size=batch_size, + **cfg, + ) elif dataloader_type == 'text': - cfg = DictConfig({ + cfg = { 'name': 'text', 'dataset': { 'local': 'dummy-path', @@ -1265,7 +1297,7 @@ def test_token_counting_func_dataloader_setting( 'shuffle_seed': 0, }, **common_args, - }) + } ds_mock = MagicMock() ds_mock.tokenizer = gptt monkeypatch.setattr( @@ -1273,12 +1305,15 @@ def test_token_counting_func_dataloader_setting( lambda *args, **kwargs: ds_mock, ) - dl = build_text_dataloader(cfg, gptt, 
batch_size) + cfg.pop('name') + dl = build_text_dataloader( + **cfg, + tokenizer=gptt, + device_batch_size=batch_size, + ) else: raise NotImplementedError() - cfg = om.create(cfg) - batch_collated = dl.dataloader.collate_fn(batch_tokenized) # type: ignore actual_token_count = dl.get_num_tokens_in_batch(batch_collated) @@ -1286,12 +1321,12 @@ def test_token_counting_func_dataloader_setting( def test_build_unknown_dataloader(): - cfg = DictConfig({ + cfg = { 'name': 'unknown', - }) + } tokenizer = MagicMock() with pytest.raises(catalogue.RegistryError): - _ = build_dataloader(cfg, tokenizer, 2) + _ = build_dataloader(cfg=cfg, tokenizer=tokenizer, device_batch_size=2) invalid_conversation_params_sharegpt = [ @@ -1389,7 +1424,7 @@ def test_sharegpt_format( with error_context: build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, + tokenizer=tokenizer, + device_batch_size=device_batch_size, + **cfg, ).dataloader diff --git a/tests/data/test_icl_datasets.py b/tests/data/test_icl_datasets.py index ce9fa7a493..5254c8e862 100644 --- a/tests/data/test_icl_datasets.py +++ b/tests/data/test_icl_datasets.py @@ -8,6 +8,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils.config_utils import to_list_container def load_icl_config(conf_path: str = 'tests/data/test_tasks.yaml'): @@ -23,7 +24,7 @@ def run_test( ): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators( - task_cfg.icl_tasks, + to_list_container(task_cfg.icl_tasks), tokenizer, 1024, 8, diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index db95bcad3d..0e02b4d98b 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -121,9 +121,9 @@ def test_auto_packing(profile_packing: Mock): profile_packing.return_value = [(1, .9, 0), (2, .8, 0), (3, .7, .5)] packing_ratio = auto_packing_ratio( - dataloader_cfg=DictConfig({'dataset': { + dataloader_cfg={'dataset': { 'max_seq_len': 2048, - }}), + }}, tokenizer=None, device_batch_size=1, ) # Dummy values, profiling results are already set. @@ -148,9 +148,9 @@ def test_dist_auto_packing(profile_packing: Mock): (3, .7, .5)] # should pick 2 packing_ratio = auto_packing_ratio( - dataloader_cfg=DictConfig({'dataset': { + dataloader_cfg={'dataset': { 'max_seq_len': 2048, - }}), + }}, tokenizer=None, device_batch_size=1, ) # Dummy values, profiling results are already set. 
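The hunks above and below migrate these tests from DictConfig-based configs to plain dictionaries whose keys are passed straight through as keyword arguments. As a rough sketch of the resulting call pattern (the dataset values below are illustrative placeholders, not the exact test fixture):

# Hedged sketch of the kwarg-style dataloader construction these tests move to.
# Values are placeholders; only the call shape is the point.
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.utils.builders import build_tokenizer

tokenizer = build_tokenizer('gpt2', {})

cfg = {
    'name': 'finetuning',
    'dataset': {
        'hf_name': 'tatsu-lab/alpaca',
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 0,
}

# Config keys become keyword arguments directly; no DictConfig wrapper or
# positional cfg argument is needed after this refactor.
dataloader = build_finetuning_dataloader(
    **cfg,
    tokenizer=tokenizer,
    device_batch_size=6,
).dataloader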
@@ -196,8 +196,8 @@ def test_auto_packing_with_streaming_dataloader(tmp_path: Path): }) loader = build_finetuning_dataloader( - cfg, - tokenizer, + **cfg, + tokenizer=tokenizer, device_batch_size=6, ).dataloader @@ -217,7 +217,7 @@ def test_packing_with_dataloader(packing_ratio: Any): """Tests that packing works with a dataloader.""" reproducibility.seed_all(17) tokenizer = build_tokenizer('gpt2', {}) - cfg = DictConfig({ + cfg = { 'name': 'finetuning', 'dataset': { 'hf_name': 'tatsu-lab/alpaca', @@ -236,11 +236,11 @@ def test_packing_with_dataloader(packing_ratio: Any): 'prefetch_factor': None, 'persistent_workers': False, 'timeout': 0, - }) + } loader = build_finetuning_dataloader( - cfg, - tokenizer, + **cfg, + tokenizer=tokenizer, device_batch_size=6, ).dataloader diff --git a/tests/data_utils.py b/tests/data_utils.py index 30f49efb2d..9653d8579a 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -299,13 +299,13 @@ def gpt_tiny_cfg(dataset_name: str, device: str): test_cfg = om.load(f) assert isinstance(test_cfg, DictConfig) - test_cfg.data_local = dataset_name + test_cfg.variables.data_local = dataset_name test_cfg.global_train_batch_size = 8 test_cfg.device_eval_batch_size = 4 test_cfg.device_train_microbatch_size = 4 test_cfg.max_duration = '4ba' test_cfg.eval_interval = '4ba' - test_cfg.run_name = 'gpt-mini-integration-test' + test_cfg.variables.run_name = 'gpt-mini-integration-test' if device == 'cpu': test_cfg.model.init_device = 'cpu' diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py index cd85bd2603..87751956c5 100644 --- a/tests/fixtures/data.py +++ b/tests/fixtures/data.py @@ -53,9 +53,9 @@ def tiny_ft_dataloader( }) dataloader = build_finetuning_dataloader( - dataloader_cfg, - mpt_tokenizer, - device_batch_size, + **dataloader_cfg, + tokenizer=mpt_tokenizer, + device_batch_size=device_batch_size, ).dataloader assert isinstance(dataloader, DataLoader) diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py index 50ad4497d5..83b0924a5d 100644 --- a/tests/fixtures/models.py +++ b/tests/fixtures/models.py @@ -2,10 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import Any, Callable +from typing import Any, Callable, Dict import pytest -from omegaconf import DictConfig from pytest import fixture from transformers import PreTrainedTokenizerBase @@ -14,9 +13,10 @@ from llmfoundry.utils.builders import build_composer_model, build_tokenizer -def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase): +def _build_model(config: Dict[str, Any], tokenizer: PreTrainedTokenizerBase): + name = config.pop('name') model = build_composer_model( - name=config.name, + name=name, cfg=config, tokenizer=tokenizer, ) @@ -34,13 +34,13 @@ def build_tiny_mpt( ) -> Callable[..., ComposerMPTCausalLM]: def build(**kwargs: Any) -> ComposerMPTCausalLM: - config = DictConfig({ + config = { 'name': 'mpt_causal_lm', 'd_model': 128, 'n_heads': 4, 'n_layers': 2, 'expansion_ratio': 2, - }) + } config.update(kwargs) model = _build_model(config, mpt_tokenizer) assert isinstance(model, ComposerMPTCausalLM) @@ -62,12 +62,12 @@ def build(**kwargs: Any) -> ComposerHFCausalLM: 'expansion_ratio': 2, } config_overrides.update(kwargs) - config = DictConfig({ + config = { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'mosaicml/mpt-7b', 'pretrained': False, 'config_overrides': config_overrides, - }) + } model = _build_model(config, mpt_tokenizer) assert isinstance(model, ComposerHFCausalLM) return model diff --git 
a/tests/models/hf/test_fsdp_weight_tying.py b/tests/models/hf/test_fsdp_weight_tying.py index 4b76996ba1..69ced673a1 100644 --- a/tests/models/hf/test_fsdp_weight_tying.py +++ b/tests/models/hf/test_fsdp_weight_tying.py @@ -7,7 +7,6 @@ import pytest from composer import Trainer from composer.models.huggingface import maybe_get_underlying_model -from omegaconf import OmegaConf as om from llmfoundry.utils.builders import build_composer_model, build_tokenizer @@ -54,7 +53,6 @@ def test_fsdp_weight_tying( assert model_cfg is not None assert tokenizer_name is not None - model_cfg = om.create(model_cfg) if peft_config is not None: model_cfg['peft_config'] = peft_config @@ -74,14 +72,15 @@ def test_fsdp_weight_tying( tokenizer_kwargs={'model_max_length': 32}, ) + name = model_cfg.pop('name') original_model = build_composer_model( - name=model_cfg['name'], + name=name, cfg=model_cfg, tokenizer=tokenizer, ) underlying_model = maybe_get_underlying_model(original_model.model) - lm_head = underlying_model.lm_head if peft_config is None else underlying_model.lm_head + lm_head = underlying_model.lm_head embedding_layer = underlying_model.model.embed_tokens if peft_config is None else underlying_model.model.embed_tokens lm_head_id = id(lm_head.weight) diff --git a/tests/models/hf/test_hf_config.py b/tests/models/hf/test_hf_config.py index 191bce48f7..ae5839ff92 100644 --- a/tests/models/hf/test_hf_config.py +++ b/tests/models/hf/test_hf_config.py @@ -10,13 +10,13 @@ import pytest import torch -from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import AutoModelForCausalLM, PretrainedConfig from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils.config_utils import to_dict_container def test_remote_code_false_mpt( @@ -49,9 +49,10 @@ def test_remote_code_false_mpt( ValueError, match='trust_remote_code must be set to True for MPT models.', ): + name = test_cfg.model.pop('name') _ = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) @@ -153,9 +154,10 @@ def test_hf_config_override( tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) + name = test_cfg.model.pop('name') model = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) @@ -169,7 +171,7 @@ def test_hf_config_override( # load hf causal lm model with config_overrides hf_model_config = deepcopy(test_cfg) - model_cfg = DictConfig({ + model_cfg = om.create({ 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': save_path, 'pretrained': False, @@ -177,9 +179,10 @@ def test_hf_config_override( }) hf_model_config.model = model_cfg + name = hf_model_config.model.pop('name') hf_model = build_composer_model( - name=hf_model_config.model.name, - cfg=hf_model_config.model, + name=name, + cfg=to_dict_container(hf_model_config.model), tokenizer=tokenizer, ) @@ -212,10 +215,10 @@ def test_rope_scaling_override(): 'pretrained': False, 'init_device': 'cpu', } - model_cfg = om.create(model_cfg) + name = model_cfg.pop('name') model = build_composer_model( - name=model_cfg.name, + name=name, cfg=model_cfg, tokenizer=None, # type: ignore ) @@ -241,10 +244,10 @@ def test_nested_override(): 'pretrained': False, 
'init_device': 'meta', } - model_cfg = om.create(model_cfg) + name = model_cfg.pop('name') model = build_composer_model( - name=model_cfg.name, + name=name, cfg=model_cfg, tokenizer=None, # type: ignore ) diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py index 683a0ba0cd..522fc5db57 100644 --- a/tests/models/hf/test_hf_peft_wrapping.py +++ b/tests/models/hf/test_hf_peft_wrapping.py @@ -9,7 +9,6 @@ import torch import transformers from composer import Trainer -from omegaconf import OmegaConf as om from peft import LoraConfig, get_peft_model from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp @@ -76,7 +75,6 @@ def test_lora_mixed_init( assert model_cfg is not None assert tokenizer_name is not None - model_cfg = om.create(model_cfg) model_cfg['peft_config'] = peft_config fsdp_config = { @@ -95,8 +93,9 @@ def test_lora_mixed_init( tokenizer_kwargs={'model_max_length': 32}, ) + name = model_cfg.pop('name') original_model = build_composer_model( - name=model_cfg['name'], + name=name, cfg=model_cfg, tokenizer=tokenizer, ) diff --git a/tests/models/hf/test_hf_t5.py b/tests/models/hf/test_hf_t5.py index fb8689e665..47443f2410 100644 --- a/tests/models/hf/test_hf_t5.py +++ b/tests/models/hf/test_hf_t5.py @@ -23,4 +23,4 @@ def test_experimental_hf_t5(): tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base') with pytest.warns(ExperimentalWarning): - _ = ComposerHFT5(cfg, tokenizer) + _ = ComposerHFT5(**cfg, tokenizer=tokenizer) diff --git a/tests/models/hf/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py index 042a18bf76..66f04e0c4a 100644 --- a/tests/models/hf/test_hf_v_mpt.py +++ b/tests/models/hf/test_hf_v_mpt.py @@ -9,6 +9,7 @@ from omegaconf import OmegaConf as om from llmfoundry.utils.builders import build_composer_model, build_tokenizer +from llmfoundry.utils.config_utils import to_dict_container @pytest.mark.gpu @@ -69,9 +70,11 @@ def test_compare_hf_v_mpt( tokenizer_name=tokenizer_name, tokenizer_kwargs=tokenizer_kwargs, ) + name = hf_cfg.model.pop('name') + hf_cfg.model.pop('device') hf_model = build_composer_model( - name=hf_cfg.model.name, - cfg=hf_cfg.model, + name=name, + cfg=to_dict_container(hf_cfg.model), tokenizer=tokenizer, ).to(device) hf_n_params = sum(p.numel() for p in hf_model.parameters()) @@ -121,9 +124,13 @@ def test_compare_hf_v_mpt( print('Initializing model...') print(model_cfg) + if 'name' in model_cfg: + name = model_cfg.pop('name') + if 'device' in model_cfg: + model_cfg.pop('device') model = build_composer_model( - name=model_cfg.name, - cfg=model_cfg, + name=name, + cfg=to_dict_container(model_cfg), tokenizer=tokenizer, ).to(device) n_params = sum(p.numel() for p in model.parameters()) diff --git a/tests/models/inference_api_wrapper/test_fmapi.py b/tests/models/inference_api_wrapper/test_fmapi.py index 72c41c2ebe..af26823aae 100644 --- a/tests/models/inference_api_wrapper/test_fmapi.py +++ b/tests/models/inference_api_wrapper/test_fmapi.py @@ -14,6 +14,7 @@ ) from llmfoundry.models.inference_api_wrapper.fmapi import FMAPIEvalInterface from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils.config_utils import to_list_container def load_icl_config(): @@ -89,7 +90,7 @@ def mock_create(**kwargs: Dict[str, str]): return MockCompletion(' ') -def test_casual_fmapi_wrapper(tmp_path: str): +def test_causal_fmapi_wrapper(tmp_path: str): # patch block_until_ready with patch.object(FMAPIEvalInterface, 'block_until_ready') as mock: @@ -110,7 +111,7 @@ def 
test_casual_fmapi_wrapper(tmp_path: str): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators( - task_cfg.icl_tasks, + to_list_container(task_cfg.icl_tasks), tokenizer, 1024, 2, @@ -153,7 +154,7 @@ def test_chat_fmapi_wrapper(tmp_path: str): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators( - task_cfg.icl_tasks, + to_list_container(task_cfg.icl_tasks), tokenizer, 1024, 2, diff --git a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py index acc3cb9622..f35e5cd750 100644 --- a/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py +++ b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py @@ -14,6 +14,7 @@ ) from llmfoundry.tokenizers import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_icl_evaluators +from llmfoundry.utils.config_utils import to_list_container @pytest.fixture(scope='module') @@ -112,7 +113,7 @@ def test_openai_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators( - task_cfg.icl_tasks, + to_list_container(task_cfg.icl_tasks), tokenizer, 1024, 2, @@ -152,7 +153,7 @@ def test_chat_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators( - task_cfg.icl_tasks, + to_list_container(task_cfg.icl_tasks), tokenizer, 1024, 2, diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py index 88113cf55b..3dc3e5dda1 100644 --- a/tests/models/layers/test_huggingface_flash.py +++ b/tests/models/layers/test_huggingface_flash.py @@ -5,7 +5,6 @@ import pytest from composer.core.precision import get_precision_context -from omegaconf import OmegaConf as om from llmfoundry.models.hf.hf_fsdp import rgetattr from llmfoundry.models.layers.attention import is_flash_v2_installed @@ -45,8 +44,6 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): if use_flash_attention_2: model_cfg['use_flash_attention_2'] = True - model_cfg = om.create(model_cfg) - tokenizer = build_tokenizer( tokenizer_name=tokenizer_name, tokenizer_kwargs={'model_max_length': 10}, @@ -60,8 +57,9 @@ def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str): ) and use_flash_attention_2 else contextlib.nullcontext() with error_context: + name = model_cfg.pop('name') model = build_composer_model( - name=model_cfg['name'], + name=name, cfg=model_cfg, tokenizer=tokenizer, ) diff --git a/tests/models/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py index ab5f2705b4..a41574538a 100644 --- a/tests/models/test_fsdp_act_checkpoint.py +++ b/tests/models/test_fsdp_act_checkpoint.py @@ -6,7 +6,6 @@ import pytest from composer import Trainer from composer.utils import get_device -from omegaconf import OmegaConf as om from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ CheckpointWrapper @@ -47,7 +46,6 @@ def test_fsdp_act_checkpoint( }, 'activation_checkpointing_target': activation_checkpointing_target, } - model_cfg = om.create(model_cfg) fsdp_config = { 'activation_checkpointing': activation_checkpointing, @@ -55,7 +53,7 @@ def test_fsdp_act_checkpoint( 'activation_cpu_offload': False, } - model = ComposerMPTCausalLM(model_cfg) + model = ComposerMPTCausalLM(**model_cfg) model = device.module_to_device(model) trainer = Trainer( diff --git a/tests/models/test_model.py 
b/tests/models/test_model.py index 243e45b671..8f074dd270 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -42,6 +42,7 @@ from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils.config_utils import to_dict_container def get_config( @@ -54,9 +55,12 @@ def get_config( return cast(DictConfig, test_cfg) -def _load_tokenizer_cfg(cfg: DictConfig) -> Dict: - config = om.to_container(cfg, resolve=True) - assert isinstance(config, Dict) +def _load_tokenizer_cfg(cfg: Union[Dict[str, Any], DictConfig]) -> Dict: + if isinstance(cfg, DictConfig): + config = to_dict_container(cfg) + else: + assert isinstance(cfg, dict) + config = cfg return config @@ -103,9 +107,10 @@ def _get_objs( tokenizer_cfg.get('kwargs', {}), ) + name = test_cfg.model.pop('name') model = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) @@ -354,9 +359,10 @@ def test_full_forward_and_backward_gpt2_small(batch_size: int = 2): tokenizer_cfg.get('kwargs', {}), ) + name = neo_cfg.model.pop('name') model = build_composer_model( - name=neo_cfg.model.name, - cfg=neo_cfg.model, + name=name, + cfg=to_dict_container(neo_cfg.model), tokenizer=tokenizer, ).to(device) @@ -412,9 +418,10 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): tokenizer_cfg.get('kwargs', {}), ) + name = t5_cfg.model.pop('name') model = build_composer_model( - name=t5_cfg.model.name, - cfg=t5_cfg.model, + name=name, + cfg=to_dict_container(t5_cfg.model), tokenizer=tokenizer, ).to(device) @@ -511,9 +518,10 @@ def test_determinism( tokenizer_cfg.get('kwargs', {}), ) + name = test_cfg.model.pop('name') model_1 = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) model_2 = copy.deepcopy(model_1) @@ -590,9 +598,10 @@ def test_loss_fn(): tokenizer_cfg.get('kwargs', {}), ) + name = test_cfg.model.pop('name') model_1 = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) model_2 = copy.deepcopy(model_1) @@ -693,9 +702,10 @@ def test_loss_reduction(loss_fn_config: str): tokenizer_cfg.get('kwargs', {}), ) + name = test_cfg.model.pop('name') model_1 = build_composer_model( - name=test_cfg.model.name, - cfg=test_cfg.model, + name=name, + cfg=to_dict_container(test_cfg.model), tokenizer=tokenizer, ) model_2 = copy.deepcopy(model_1) @@ -799,15 +809,14 @@ def test_opt_wrapping(peft_config: Optional[dict[str, str]]): if peft_config is not None: conf['model']['peft_config'] = peft_config - config = DictConfig(conf) - - tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(conf['tokenizer']) tokenizer = build_tokenizer( - config.tokenizer.name, + conf['tokenizer']['name'], tokenizer_cfg.get('kwargs', {}), ) - model = ComposerHFCausalLM(config.model, tokenizer) + conf['model'].pop('name') + model = ComposerHFCausalLM(**conf['model'], tokenizer=tokenizer) # check that all the modules we except are blocked from FSDP wrapping underlying_model = maybe_get_underlying_model(model.model) @@ -840,7 +849,8 @@ def test_lora_id(): tokenizer_cfg.get('kwargs', {}), ) - model = ComposerHFCausalLM(config.model, tokenizer) + config.model.pop('name') + model = 
ComposerHFCausalLM(**config.model, tokenizer=tokenizer) assert isinstance(model.model, peft.PeftModelForCausalLM) @@ -948,7 +958,7 @@ def test_mpt_creation( assert len(mpt.transformer.blocks) == 2 d_model = hf_config.d_model - if ffn_hidden_size is None: + if ffn_hidden_size is None: # type: ignore (sometimes it may not be none) ffn_hidden_size = int(hf_config.d_model * hf_config.expansion_ratio) for block in mpt.transformer.blocks: assert isinstance(block, MPTBlock) diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index f64925e6dd..dfcb5b327c 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -13,7 +13,6 @@ from composer.callbacks import Generate from composer.core import Evaluator from composer.loggers import WandBLogger -from omegaconf import DictConfig, ListConfig from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer @@ -286,14 +285,15 @@ def test_build_evaluators_empty(): def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4') - eval_loader_cfg = DictConfig({ + eval_loader_cfg = { 'name': 'text', 'dataset': { + 'streams': None, # mocked, not needed }, 'drop_last': False, 'num_workers': 8, - }) + } monkeypatch.setattr( 'llmfoundry.data.text_data.StreamingTextDataset', lambda *args, @@ -307,7 +307,7 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): assert eval_loaders[0].dataloader is not None assert eval_loaders[0].metric_names == [] - multi_eval_loader_cfg = ListConfig([ + multi_eval_loader_cfg = [ { 'name': 'text', 'label': 'test1', @@ -326,7 +326,7 @@ def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): 'drop_last': False, 'num_workers': 8, }, - ]) + ] monkeypatch.setattr( 'llmfoundry.data.text_data.StreamingTextDataset', lambda *args, diff --git a/tests/utils/test_mlflow_logging.py b/tests/utils/test_mlflow_logging.py index 205c985e97..04a600d44c 100644 --- a/tests/utils/test_mlflow_logging.py +++ b/tests/utils/test_mlflow_logging.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock, patch import pytest -from omegaconf import OmegaConf from llmfoundry.utils.config_utils import ( _log_dataset_uri, @@ -18,7 +17,7 @@ def create_config(**kwargs: Any): """Helper function to create OmegaConf configurations.""" - return OmegaConf.create(kwargs) + return kwargs def test_parse_source_dataset_delta_table(): @@ -108,7 +107,7 @@ def test_log_dataset_uri(): def test_multiple_eval_datasets(): # Setup a configuration with multiple evaluation datasets - cfg = OmegaConf.create({ + cfg = { 'train_loader': { 'dataset': { 'hf_name': 'huggingface/train_dataset', @@ -123,7 +122,7 @@ def test_multiple_eval_datasets(): 'hf_name': 'huggingface/eval_dataset2', }, }], - }) + } expected_data_paths = [('hf', 'huggingface/train_dataset', 'train'), ('hf', 'huggingface/eval_dataset1', 'eval'),