diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index e02bf03693..788a8943b1 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -204,7 +204,7 @@ def _save_checkpoint(self, state: State, logger: Logger): state_dict[k] = v.to(dtype=self.dtype) if dist.get_global_rank() == 0: - log.debug('Saving Hugging Face checkpoint to disk') + log.debug('Saving Hugging Face checkpoint in global rank 0') copied_config = copy.deepcopy(original_model.config) if copied_config.model_type == 'mpt': diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index 7d497b4efd..8ccf7f25e9 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -477,13 +477,13 @@ def build_text_denoising_dataloader( remote=cfg.dataset.get('remote'), split=cfg.dataset.get('split'), shuffle=cfg.dataset.get('shuffle', False), - predownload=cfg.dataset.get('predownload', 100_000), + predownload=cfg.dataset.get('predownload', None), keep_zip=cfg.dataset.get('keep_zip', False), download_retry=cfg.dataset.get('download_retry', 2), download_timeout=cfg.dataset.get('download_timeout', 60), - validate_hash=cfg.dataset.get('validate_hash'), + validate_hash=cfg.dataset.get('validate_hash', None), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), - num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128), + num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, ) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 44d6d345f5..b19cab841f 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -136,13 +136,13 @@ def build_finetuning_dataloader(cfg: DictConfig, epoch_size=cfg.dataset.get('epoch_size', None), predownload=cfg.dataset.get('predownload', None), cache_limit=cfg.dataset.get('cache_limit', None), - partition_algo=cfg.dataset.get('partition_algo', 'orig'), + partition_algo=cfg.dataset.get('partition_algo', 'relaxed'), num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, shuffle=cfg.dataset.get('shuffle', False), - shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'), + shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), - shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18), + shuffle_block_size=cfg.dataset.get('shuffle_block_size', None), sampling_method=cfg.dataset.get('sampling_method', 'balanced'), sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 6ba6ad96c8..bc712a7504 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -88,12 +88,12 @@ class StreamingFinetuningDataset(StreamingDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to `False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all + epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples. Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. 
Defaults to ``None``. predownload (int, optional): Target number of samples ahead to download the shards of while - iterating. Defaults to ``100_000``. + iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None @@ -101,15 +101,17 @@ class StreamingFinetuningDataset(StreamingDataset): bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None. partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. Defaults to ``None``, which is interpreted as the number of nodes of the - initial run. + resumption. If ``None``, this is interpreted as 64 times the number of physical + nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the + number of physical nodes of the initial run otherwise. Defaults to ``None``. batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to ``None``. shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + shuffle_block_size (int): Unit of shuffle. If ``None``, its value is calculated as + ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to ``None``. sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``. sampling_granularity (int): When picking samples for a stream's final partial repeat, @@ -129,16 +131,16 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[int] = None, + epoch_size: Optional[Union[int, str]] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'orig', + partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, - shuffle_algo: str = 'py1b', + shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, + shuffle_block_size: Optional[int] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 93af2f63ed..51fd6b38dc 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -46,12 +46,12 @@ class StreamingTextDataset(StreamingDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to `False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all + epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples.
Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Defaults to ``None``. predownload (int, optional): Target number of samples ahead to download the shards of while - iterating. Defaults to ``100_000``. + iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None @@ -59,15 +59,19 @@ class StreamingTextDataset(StreamingDataset): bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None. partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. Defaults to ``None``, which is interpreted as the number of nodes of the - initial run. + resumption. If ``None``, this is interpreted as 64 times the number of physical + nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the + number of physical nodes of the initial run otherwise. Defaults to ``None``. batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to ``None``. shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split + into blocks of this size, and samples within each block are shuffled. If ``None``, its + value is calculated as ``max(4_000_000 // num_canonical_nodes, 1 << 18)``. Defaults to + ``None``. sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat, @@ -89,16 +93,16 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: int = 100_000, + epoch_size: Optional[Union[int, str]] = None, + predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'orig', + partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, - shuffle_algo: str = 'py1b', + shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, + shuffle_block_size: Optional[int] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index c4ca68d733..c0a1e65248 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -59,6 +59,7 @@ def __init__( use_cache: bool = False, init_config: Dict = init_config_defaults, fc_type: str = 'torch', + tie_word_embeddings: bool = True, verbose: Optional[int] = None, **kwargs: Any, ): @@ -128,6 +129,7 @@ def __init__( --- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. + tie_word_embeddings (bool): Whether to tie the input embedding and output layers. """ self.d_model = d_model self.n_heads = n_heads @@ -164,7 +166,11 @@ def __init__( warnings.warn( f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' ) - super().__init__(**kwargs) + # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__ + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self._validate_config() diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0cb3ebd56c..274c1b76e5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -45,7 +45,9 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias +from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, + attn_bias_shape, + build_attn_bias) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY @@ -231,10 +233,11 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') - def get_input_embeddings(self) -> nn.Embedding: + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.wte - def set_input_embeddings(self, value: nn.Embedding) -> None: + def set_input_embeddings( + self, value: Union[SharedEmbedding, nn.Embedding]) -> None: self.wte = value @torch.no_grad() @@ -574,14 +577,20 @@ class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config: MPTConfig): super().__init__(config) - if not config.tie_word_embeddings: - raise ValueError( - 'MPTForCausalLM only supports tied word embeddings') - log.info(f'Instantiating an MPTForCausalLM model from {__file__}') self.transformer: MPTModel = MPTModel(config) + 
self.lm_head = None + if not config.tie_word_embeddings: + self.lm_head = nn.Linear( + config.d_model, + config.vocab_size, + bias=False, + device=config.init_device, + ) + self.lm_head._fsdp_wrap = True + for child in self.transformer.children(): if isinstance(child, torch.nn.ModuleList): continue @@ -602,19 +611,38 @@ def __init__(self, config: MPTConfig): ) self.logit_scale = logit_scale - def get_input_embeddings(self) -> nn.Embedding: - return self.transformer.wte + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: + return self.transformer.get_input_embeddings() def set_input_embeddings( self, value: Union[SharedEmbedding, nn.Embedding]) -> None: - self.transformer.wte = value + self.transformer.set_input_embeddings(value) - def get_output_embeddings(self) -> nn.Embedding: - return self.transformer.wte + def get_output_embeddings( + self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]: + if self.lm_head is not None: + return self.lm_head + return self.transformer.get_input_embeddings() def set_output_embeddings( - self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None: - self.transformer.wte = new_embeddings + self, new_embeddings: Union[SharedEmbedding, nn.Embedding, + nn.Linear]) -> None: + if self.lm_head is not None: + self.lm_head = new_embeddings + else: + if not isinstance(new_embeddings, (SharedEmbedding, nn.Embedding)): + raise ValueError( + 'new_embeddings must be an instance of SharedEmbedding ' + + f'or nn.Embedding, but got {type(new_embeddings)}.') + warnings.warn( + 'Using `set_output_embeddings` to set the embedding layer of ' + + 'MPTForCausalLM with tied weights. Given weights are tied, ' + + 'using `set_input_embeddings` is recommended over using ' + + '`set_output_embeddings`.') + self.transformer.set_input_embeddings(new_embeddings) + + def tie_weights(self) -> None: + self.lm_head = None def set_decoder(self, decoder: MPTModel) -> None: self.transformer = decoder @@ -658,12 +686,14 @@ def forward( use_cache=use_cache, ) - # move outputs to same device as weights for token embedding - # needed to support HF `device_map` - logits = self.transformer.wte( - outputs.last_hidden_state.to(self.transformer.wte.weight.device), - True, - ) + if self.lm_head is not None: + logits = self.lm_head(outputs.last_hidden_state) + else: + # move outputs to same device as weights for token embedding + # needed to support HF `device_map` + out = outputs.last_hidden_state + out = out.to(self.transformer.wte.weight.device) + logits = self.transformer.wte(out, True) if self.logit_scale is not None: if self.logit_scale == 0: @@ -705,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', + None) or ['MPTBlock'] + + if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: + if len(act_ckpt_list) > 1: + log.info( + 'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).' 
+ ) + return isinstance(module, MPTBlock) + + mod_types = () + for mod_name in act_ckpt_list: + if mod_name.lower() == 'mptblock': + mod_types += (MPTBlock,) + elif mod_name in ATTN_CLASS_REGISTRY: + mod_types += (ATTN_CLASS_REGISTRY[mod_name],) + elif mod_name in FFN_CLASS_REGISTRY: + mod_types += (FFN_CLASS_REGISTRY[mod_name],) + elif mod_name in NORM_CLASS_REGISTRY: + mod_types += (NORM_CLASS_REGISTRY[mod_name],) + else: + msg = ', '.join( + list(ATTN_CLASS_REGISTRY.keys()) + + list(FFN_CLASS_REGISTRY.keys()) + + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + raise ValueError( + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' + ) + return isinstance(module, mod_types) def prepare_inputs_for_generation( self, @@ -859,7 +917,11 @@ def flops_per_batch(self, batch: Mapping) -> int: # assume the backward pass is approximately 2x the forward pass bs, msl = batch['input_ids'].shape[0:2] - params_flops_per_token = 2 * self.n_active_params + params = self.n_active_params + if not self.model.transformer.config.tie_word_embeddings: + # embedding layers are lookup tables, therefore are not counted in the FLOP computation + params -= self.model.transformer.wte.weight.numel() + params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 * (self.model.config.d_model * (msl**2))) diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 0627cec4cd..35e77eab6c 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -19,7 +19,8 @@ import numpy as np import sentencepiece as spm -from transformers import AutoTokenizer, PreTrainedTokenizer +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) log = logging.getLogger(__name__) @@ -35,8 +36,9 @@ def _get_weight_data_type(data_type: str): # TODO: move this functionality to composer once the bug fixes are upstreamed def get_hf_tokenizer_from_composer_state_dict( - state_dict: Dict[str, Any], - tokenizer_save_dir: Optional[str] = None + state_dict: Dict[str, Any], + trust_remote_code: bool, + tokenizer_save_dir: Optional[str] = None, ) -> Optional[PreTrainedTokenizer]: if 'state' not in state_dict: raise RuntimeError( @@ -85,7 +87,8 @@ def get_hf_tokenizer_from_composer_state_dict( with open(tokenizer_file_path, 'wb') as _tmp_file: _tmp_file.write(s.serialized_model_proto()) - hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir) + hf_tokenizer = load_tokenizer(tokenizer_save_dir, + trust_remote_code=trust_remote_code) # remove 'name_or_path' hf_tokenizer.name_or_path = '' @@ -94,6 +97,20 @@ def get_hf_tokenizer_from_composer_state_dict( return hf_tokenizer +def load_tokenizer( + tokenizer_save_dir: str, trust_remote_code: bool +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + try: + return AutoTokenizer.from_pretrained( + tokenizer_save_dir, trust_remote_code=trust_remote_code) + except ValueError as e: + raise ValueError( + f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. ' + + + 'If accessing a tokenizer defined outside of the transformers module,' + + ' please use --trust_remote_code.') + + def _write_zero_bias(weight_name: str, weight_file_path: str, bias_shape: Union[Tuple[int, ...], int]) -> None: """Write zeros for bias when converting MPT to FasterTransformer weights. 
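Note on the new `activation_checkpointing_target` option added to `activation_checkpointing_fn` in `modeling_mpt.py` above: module names are resolved against `MPTBlock` plus the attention, FFN, and norm class registries, so individual sub-modules can be checkpointed instead of whole blocks. A minimal usage sketch, assuming a toy GPU config along the lines of the one exercised in `tests/test_fsdp_act_checkpoint.py` later in this diff (all sizes are illustrative):

```python
from composer import Trainer
from omegaconf import OmegaConf as om

from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM

model_cfg = om.create({
    'name': 'mpt_causal_lm',
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 2,
    'expansion_ratio': 1,
    'max_seq_len': 16,
    'vocab_size': 50368,
    'attn_config': {
        'attn_type': 'grouped_query_attention',
        'kv_n_heads': 2,
    },
    # Checkpoint only the attention sub-modules rather than whole MPTBlocks.
    'activation_checkpointing_target': ['grouped_query_attention'],
})

fsdp_config = {
    'activation_checkpointing': True,
    'activation_checkpointing_reentrant': False,
    'activation_cpu_offload': False,
}

model = ComposerMPTCausalLM(model_cfg)
trainer = Trainer(model=model, device='gpu', fsdp_config=fsdp_config)
# With this config, FSDP wraps each block's `.attn` in a CheckpointWrapper;
# listing 'mptblock' (or omitting the field) checkpoints whole blocks instead.
```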
diff --git a/scripts/inference/benchmarking/yamls/1b.yaml b/scripts/inference/benchmarking/yamls/1b.yaml index f94aa3d806..d1cfb3c913 100644 --- a/scripts/inference/benchmarking/yamls/1b.yaml +++ b/scripts/inference/benchmarking/yamls/1b.yaml @@ -12,7 +12,6 @@ tokenizer: model: name: mpt_causal_lm init_device: cpu - tokenizer_name: ${tokenizer_name} d_model: 2048 n_heads: 16 # Modified 24->16 so that d_head == 128 to statisfy FlashAttention n_layers: 24 diff --git a/scripts/inference/benchmarking/yamls/7b.yaml b/scripts/inference/benchmarking/yamls/7b.yaml index 55e9ae8413..f57ed2657f 100644 --- a/scripts/inference/benchmarking/yamls/7b.yaml +++ b/scripts/inference/benchmarking/yamls/7b.yaml @@ -12,7 +12,6 @@ tokenizer: model: name: mpt_causal_lm init_device: cpu - tokenizer_name: ${tokenizer_name} d_model: 4096 n_heads: 32 n_layers: 32 diff --git a/scripts/inference/convert_composer_mpt_to_ft.py b/scripts/inference/convert_composer_mpt_to_ft.py index 79275030b3..f59eb6005a 100644 --- a/scripts/inference/convert_composer_mpt_to_ft.py +++ b/scripts/inference/convert_composer_mpt_to_ft.py @@ -67,6 +67,7 @@ def write_ft_checkpoint_from_composer_checkpoint( checkpoint_path: Union[Path, str], infer_gpu_num: int, save_dir: str, + trust_remote_code: bool, output_precision: str = 'fp32', local_checkpoint_save_location: Optional[Union[Path, str]] = None) -> None: @@ -79,6 +80,7 @@ def write_ft_checkpoint_from_composer_checkpoint( checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend supported by Composer. infer_gpu_num (int): The number of gpus you are planning to use for inference. + trust_remote_code (bool): Whether or not to use code outside of the transformers module. save_dir (str): Path of the directory to save the checkpoint in FT format. output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``. local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally. @@ -125,7 +127,7 @@ def write_ft_checkpoint_from_composer_checkpoint( print('#' * 30) print('Extracting HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict) + composer_state_dict, trust_remote_code) if hf_tokenizer is None: print('Warning! No HF Tokenizer found!') @@ -206,6 +208,10 @@ def parse_args() -> Namespace: 'Data type of weights in the FasterTransformer output model. 
Input checkpoint weights will be converted to this dtype.', choices=['fp32', 'fp16'], default='fp32') + parser.add_argument( + '--trust_remote_code', + action='store_true', + help='Whether or not to use code outside of transformers module.') return parser.parse_args() @@ -229,4 +235,5 @@ def parse_args() -> Namespace: infer_gpu_num=args.infer_gpu_num, save_dir=save_dir, output_precision=args.output_precision, - local_checkpoint_save_location=args.local_checkpoint_save_location) + local_checkpoint_save_location=args.local_checkpoint_save_location, + trust_remote_code=args.trust_remote_code) diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 5625a3b046..1b43762473 100644 --- a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -16,6 +16,7 @@ from llmfoundry import MPTConfig, MPTForCausalLM from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict +from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility @@ -23,6 +24,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path: Union[Path, str], output_path: Union[Path, str], + trust_remote_code: bool, output_precision: str = 'fp32', local_checkpoint_save_location: Optional[Union[Path, str]] = None ) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]: @@ -63,6 +65,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. output_path (Union[Path, str]): Path to the folder to write the output to. + trust_remote_code (bool): Whether or not to use code outside of the transformers module. output_precision (str, optional): The precision of the output weights saved to `pytorch_model.bin`. Can be one of ``fp32``, ``fp16``, or ``bf16``. local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally. If the input ``checkpoint_path`` is already a local path, this will be a symlink. 
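For context on the `trust_remote_code` plumbing in these conversion scripts: tokenizer loading now goes through the `load_tokenizer` helper added in `checkpoint_conversion_helpers.py` above. A rough sketch of the intended call, with a hypothetical save directory:

```python
from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer

# Hypothetical path where the tokenizer files extracted from the Composer
# checkpoint were written.
tokenizer = load_tokenizer(
    '/tmp/extracted-tokenizer',
    trust_remote_code=True,  # corresponds to the new --trust_remote_code flag
)
# With trust_remote_code=False, a tokenizer whose class lives outside the
# `transformers` package raises a ValueError pointing at --trust_remote_code.
```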
@@ -110,7 +113,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( print('#' * 30) print('Saving HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict) + composer_state_dict, trust_remote_code) if hf_tokenizer is not None: hf_tokenizer.save_pretrained(output_path) print(hf_tokenizer) @@ -157,6 +160,10 @@ def parse_args() -> Namespace: default='fp32') parser.add_argument('--hf_repo_for_upload', type=str, default=None) parser.add_argument('--test_uploaded_model', action='store_true') + parser.add_argument( + '--trust_remote_code', + action='store_true', + help='Whether or not to use code outside of transformers module.') return parser.parse_args() @@ -179,6 +186,7 @@ def convert_composer_to_hf(args: Namespace) -> None: config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path=args.composer_path, output_path=local_folder_path, + trust_remote_code=args.trust_remote_code, output_precision=args.output_precision, local_checkpoint_save_location=args.local_checkpoint_save_location) @@ -206,7 +214,9 @@ def convert_composer_to_hf(args: Namespace) -> None: loaded_hf_model.save_pretrained(local_folder_path) print(f'Loading tokenizer from {local_folder_path}') - tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path) + + tokenizer = load_tokenizer(local_folder_path, + trust_remote_code=args.trust_remote_code) tokenizer.save_pretrained(local_folder_path) # Only need to edit files for MPT because it has custom code diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml index cfb447e2e4..12914e14bc 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml index fc1e3b0b7f..3da239c717 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/gpt2-small.yaml b/scripts/train/yamls/pretrain/gpt2-small.yaml index dde59d55b1..d40cff6e9e 100644 --- a/scripts/train/yamls/pretrain/gpt2-small.yaml +++ b/scripts/train/yamls/pretrain/gpt2-small.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/opt-3b.yaml b/scripts/train/yamls/pretrain/opt-3b.yaml index 
3ac281f0ea..4423784b54 100644 --- a/scripts/train/yamls/pretrain/opt-3b.yaml +++ b/scripts/train/yamls/pretrain/opt-3b.yaml @@ -27,7 +27,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -40,7 +39,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/setup.py b/setup.py index ba383a4d7f..9967e49146 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.34.1,<4.35', - 'mosaicml-streaming>=0.6,<0.7', + 'mosaicml-streaming>=0.7.1,<0.8', 'torch>=1.13.1,<2.1.1', 'datasets>=2.14.5,<2.15', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py new file mode 100644 index 0000000000..1a46fcbccd --- /dev/null +++ b/tests/test_fsdp_act_checkpoint.py @@ -0,0 +1,73 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from composer import Trainer +from composer.utils import get_device +from omegaconf import OmegaConf as om +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + CheckpointWrapper + +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM + + +@pytest.mark.world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize('activation_checkpointing', [True, False]) +@pytest.mark.parametrize( + 'activation_checkpointing_target', + [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']]) +def test_fsdp_act_checkpoint(activation_checkpointing: bool, + activation_checkpointing_target: list): + device = get_device('gpu') + model_cfg = { + 'name': 'mpt_causal_lm', + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 2, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + 'attn_config': { + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 2, + }, + 'activation_checkpointing_target': activation_checkpointing_target + } + model_cfg = om.create(model_cfg) + + fsdp_config = { + 'activation_checkpointing': activation_checkpointing, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': False, + } + + model = ComposerMPTCausalLM(model_cfg) + model = device.module_to_device(model) + + trainer = Trainer( + model=model, + device='gpu', + fsdp_config=fsdp_config, + ) + + assert trainer.state.fsdp_enabled + if not activation_checkpointing: + assert not isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0], CheckpointWrapper) + elif (not activation_checkpointing_target + ) or activation_checkpointing_target == [ + 'mptblock', 'grouped_query_attention' + ]: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0]._fsdp_wrapped_module, CheckpointWrapper) + elif activation_checkpointing_target == ['grouped_query_attention']: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. 
+ blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + else: + raise ValueError( + f'Unknown activation_checkpointing_target: {activation_checkpointing_target}' + ) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..af94126225 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -248,20 +248,21 @@ def test_callback_inits_with_defaults(): @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) +@pytest.mark.parametrize( + 'model,tie_word_embeddings', + [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)], +) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)]) @patch('os.cpu_count', MagicMock(return_value=None)) -def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, - fsdp_state_dict_type: Optional[str], - log_to_mlflow: bool, - hf_save_interval: str, - save_interval: str, max_duration: str, - expected_hf_checkpoints: int, - expected_normal_checkpoints: int): +def test_huggingface_conversion_callback( + model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool, + fsdp_state_dict_type: Optional[str], log_to_mlflow: bool, + hf_save_interval: str, save_interval: str, max_duration: str, + expected_hf_checkpoints: int, expected_normal_checkpoints: int): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, 'attn_impl': 'torch', }, 'loss_fn': 'torch_crossentropy', + 'tie_word_embeddings': tie_word_embeddings, } tokenizer_name = 'EleutherAI/gpt-neox-20b' elif model == 'neo': + assert tie_word_embeddings is None model_cfg = { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M', @@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, } tokenizer_name = 'EleutherAI/gpt-neo-125M' elif model == 'llama2': + assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.' @@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, delete_transformers_cache() -@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) -def test_convert_and_generate(model: str, tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model,tie_word_embeddings', + [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)], +) +def test_convert_and_generate(model: str, tie_word_embeddings: bool, + tmp_path: pathlib.Path): delete_transformers_cache() om_cfg = None if model == 'mpt': om_cfg = get_config( conf_path='scripts/train/yamls/pretrain/testing.yaml') + om_cfg['tie_word_embeddings'] = tie_word_embeddings elif model == 'neo': + assert tie_word_embeddings is None om_cfg = get_config( conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') om_cfg['model']['config_overrides']['hidden_size'] = 36 elif model == 'llama2': + assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.' 
@@ -530,6 +541,7 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) @@ -561,11 +573,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path): @pytest.mark.gpu -def test_convert_and_generate_triton(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_convert_and_generate_triton(tie_word_embeddings: str, + tmp_path: pathlib.Path): delete_transformers_cache() cfg = get_config() cfg['model']['init_device'] = 'cpu' + cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( 'EleutherAI/gpt-neox-20b') model = ComposerMPTCausalLM(cfg['model'], tokenizer) @@ -577,6 +592,7 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) @@ -600,7 +616,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path): delete_transformers_cache() -def test_convert_and_generate_meta(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_convert_and_generate_meta(tie_word_embeddings: str, + tmp_path: pathlib.Path): delete_transformers_cache() from composer.utils import dist @@ -610,6 +628,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path): om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') om_cfg['model']['init_device'] = 'cpu' + om_cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( om_cfg.tokenizer.name) original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name]( @@ -631,6 +650,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) diff --git a/tests/test_model.py b/tests/test_model.py index 41b62f0ccf..3308c65fd3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -466,7 +466,8 @@ def test_opt_wrapping(): @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys()) @pytest.mark.parametrize('no_bias', [False, True]) -def test_mpt_creation(norm_type: str, no_bias: bool): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): # Test that the config constructs the model as expected. 
hf_config = MPTConfig( init_device='cpu', @@ -482,6 +483,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool): }, norm_type=norm_type, no_bias=no_bias, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) @@ -493,6 +495,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool): assert mpt.transformer.wte.weight.shape == torch.Size( [hf_config.vocab_size, hf_config.d_model]) + if not tie_word_embeddings: + assert mpt.lm_head is not None + assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape assert mpt.transformer.wpe.weight.shape == torch.Size( [hf_config.max_seq_len, hf_config.d_model]) assert mpt.transformer.emb_drop.p == 0.1 @@ -544,8 +549,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_padding(attention_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, tie_word_embeddings: bool): # Test that different placement of padding does not affect the output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -580,6 +586,7 @@ def test_forward_with_padding(attention_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.eval() @@ -736,10 +743,13 @@ def test_advanced_mask_building(attention_impl: str): assert torch.equal(attn_bias, expected_attn_bias) -@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), - ('flash', 'gpu'), - ('triton', 'gpu'), - ('torch', 'gpu')]) +@pytest.mark.parametrize('attention_impl,device,precision', [ + ('torch', 'cpu', 'fp32'), + ('flash', 'gpu', 'amp_bf16'), + ('triton', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'fp32'), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -766,7 +776,9 @@ def test_advanced_mask_building(attention_impl: str): 'factor': 1.0, }, }]) -def test_generate(attention_impl: str, device: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_generate(attention_impl: str, device: str, precision: str, + pos_emb_config: dict, tie_word_embeddings: bool): # Test that generate works, and produces the same output with or without # padding in the input. 
if not torch.cuda.is_available() and device == 'gpu': @@ -780,6 +792,8 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') + if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False: + pytest.skip(f'This test configuration has precision / sampling issues.') composer_device = get_device(device) @@ -796,10 +810,11 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): 'attn_impl': attention_impl, **pos_emb_config, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) - mpt.eval() mpt = composer_device.module_to_device(mpt) + mpt.eval() # padding on the left of the input left_padding_input_ids = torch.tensor( @@ -830,8 +845,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): batched_attention_mask = composer_device.tensor_to_device( batched_attention_mask) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context(precision): # check that a batch with different amounts of padding doesn't crash # and produces the right output shape batched_generation = mpt.generate(input_ids=batched_input_ids, @@ -861,8 +875,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('use_cache', [False, True]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, - use_cache: bool): + use_cache: bool, tie_word_embeddings: bool): if not torch.cuda.is_available(): pytest.skip(f'This test requires CUDA to be available.') if not torch.cuda.device_count() >= world_size: @@ -882,6 +897,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, 'attn_impl': 'torch', }, use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.save_pretrained(save_path) @@ -994,8 +1010,10 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_cache_and_padding(attn_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, + tie_word_embeddings: bool): # Tests that the result is the same with or without padding when using kv caching if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1028,6 +1046,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) @@ -1133,7 +1152,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'factor': 1.0, }, }]) -def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict, + tie_word_embeddings: bool): # Test that model forward with and without the key-value cache produces the # same output. 
if not torch.cuda.is_available() and device == 'gpu': @@ -1168,6 +1189,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1237,7 +1259,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): torch.testing.assert_close( second_output.logits, full_output.logits[:, -1, :].unsqueeze(1), - atol=1e-2, + atol=1.1e-2, rtol=1e-2, ) @@ -1274,8 +1296,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_past_kv(attn_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, tie_word_embeddings: bool): if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' @@ -1307,6 +1330,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1325,7 +1349,8 @@ def test_generate_with_past_kv(attn_impl: str, device: str, with mock.patch.object(MPTForCausalLM, 'forward', autospec=True) as forward_mocked: forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), + logits=composer_device.tensor_to_device( + torch.randn((1, 3, hf_config.vocab_size))), past_key_values=[(torch.randn(1, 3, hf_config.d_model), torch.randn(1, 3, hf_config.d_model)) for _ in range(hf_config.n_layers)]) @@ -1386,9 +1411,11 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], - pos_emb_config: dict): + pos_emb_config: dict, + tie_word_embeddings: bool): if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
@@ -1417,6 +1444,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, **pos_emb_config, }, use_cache=True, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1467,7 +1495,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'factor': 1.0, }, }]) -def test_model_to(attention_impl: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_model_to(attention_impl: str, pos_emb_config: dict, + tie_word_embeddings: bool): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( @@ -1498,6 +1528,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = mpt.bfloat16() @@ -1600,9 +1631,11 @@ def test_alibi_vs_hf(): }]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( attn_impl: str, device: str, pos_emb_config: dict, - output_attentions: bool, output_hidden_states: bool): + output_attentions: bool, output_hidden_states: bool, + tie_word_embeddings: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1639,6 +1672,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py index c52b765480..413e39bf8c 100644 --- a/tests/test_mpt_gen.py +++ b/tests/test_mpt_gen.py @@ -55,9 +55,11 @@ def forward( @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['triton', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) @patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM', new=MockMPTForCausalLM) def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, + tie_word_embeddings: bool, build_tiny_mpt: Callable[..., ComposerMPTCausalLM], mpt_tokenizer: PreTrainedTokenizerBase): @@ -67,11 +69,14 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, """ device = get_device('gpu') - model = build_tiny_mpt(attn_config={ - 'attn_impl': attn_impl, - 'attn_uses_sequence_id': False, - 'alibi': use_alibi - },) + model = build_tiny_mpt( + tie_word_embeddings=tie_word_embeddings, + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + 'alibi': use_alibi + }, + ) model = device.module_to_device(model) model.eval() @@ -88,13 +93,25 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, @pytest.mark.gpu -def test_mpt_generate_callback(build_tiny_mpt: Callable[..., +@pytest.mark.parametrize('attn_impl', ['triton', 'torch']) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, + tie_word_embeddings: bool, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], tiny_ft_dataloader: DataLoader): device = get_device('gpu') # build mpt model - model = build_tiny_mpt() + model = 
build_tiny_mpt( + tie_word_embeddings=tie_word_embeddings, + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + 'alibi': use_alibi + }, + ) model = device.module_to_device(model) # generate callback diff --git a/tests/test_onnx.py b/tests/test_onnx.py index d0e01746eb..becd3c773f 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -3,6 +3,7 @@ import pathlib +import pytest import torch from transformers import AutoModelForCausalLM @@ -25,7 +26,8 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int): return batch -def test_onnx_export(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_onnx_export(tie_word_embeddings: bool, tmp_path: pathlib.Path): from transformers.models.auto.configuration_auto import CONFIG_MAPPING CONFIG_MAPPING._extra_content['mpt'] = MPTConfig AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM) @@ -48,6 +50,7 @@ def test_onnx_export(tmp_path: pathlib.Path): use_cache=True, vocab_size=vocab_size, norm_type='layernorm', + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.eval()
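Taken together, the `tie_word_embeddings` changes make untied input/output embeddings opt-in for MPT. A minimal sketch of the resulting behavior, mirroring the new assertions in `test_mpt_creation` (toy sizes, CPU init assumed):

```python
from llmfoundry import MPTConfig, MPTForCausalLM

hf_config = MPTConfig(
    init_device='cpu',
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_seq_len=16,
    attn_config={'attn_impl': 'torch'},
    tie_word_embeddings=False,  # new flag; defaults to True
)
mpt = MPTForCausalLM(hf_config)

# Untied: a separate lm_head is created and reported as the output embedding.
assert mpt.lm_head is not None
assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape
assert mpt.get_output_embeddings() is mpt.lm_head

# Tied (the default): lm_head stays None and the shared wte embedding is
# reused as the output projection in forward().
```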