From d11ba8209bcbd9d1afefa9a468caecdca979c137 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 13 Nov 2023 10:40:11 -0800 Subject: [PATCH 01/11] Make TiktokenTokenizerWrapper compatible with convert_composer_to_hf.py (#730) --- .../utils/checkpoint_conversion_helpers.py | 25 ++++++++++++++++--- .../inference/convert_composer_mpt_to_ft.py | 11 ++++++-- scripts/inference/convert_composer_to_hf.py | 14 +++++++++-- tests/test_hf_conversion_script.py | 3 +++ 4 files changed, 45 insertions(+), 8 deletions(-) diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index 0627cec4cd..35e77eab6c 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -19,7 +19,8 @@ import numpy as np import sentencepiece as spm -from transformers import AutoTokenizer, PreTrainedTokenizer +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) log = logging.getLogger(__name__) @@ -35,8 +36,9 @@ def _get_weight_data_type(data_type: str): # TODO: move this functionality to composer once the bug fixes are upstreamed def get_hf_tokenizer_from_composer_state_dict( - state_dict: Dict[str, Any], - tokenizer_save_dir: Optional[str] = None + state_dict: Dict[str, Any], + trust_remote_code: bool, + tokenizer_save_dir: Optional[str] = None, ) -> Optional[PreTrainedTokenizer]: if 'state' not in state_dict: raise RuntimeError( @@ -85,7 +87,8 @@ def get_hf_tokenizer_from_composer_state_dict( with open(tokenizer_file_path, 'wb') as _tmp_file: _tmp_file.write(s.serialized_model_proto()) - hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir) + hf_tokenizer = load_tokenizer(tokenizer_save_dir, + trust_remote_code=trust_remote_code) # remove 'name_or_path' hf_tokenizer.name_or_path = '' @@ -94,6 +97,20 @@ def get_hf_tokenizer_from_composer_state_dict( return hf_tokenizer +def load_tokenizer( + tokenizer_save_dir: str, trust_remote_code: bool +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + try: + return AutoTokenizer.from_pretrained( + tokenizer_save_dir, trust_remote_code=trust_remote_code) + except ValueError as e: + raise ValueError( + f'Got error while loading tokenizer with trust_remote_code={trust_remote_code}: {e}. ' + + + 'If accessing a tokenizer defined outside of the transformers module,' + + ' please use --trust_remote_code.') + + def _write_zero_bias(weight_name: str, weight_file_path: str, bias_shape: Union[Tuple[int, ...], int]) -> None: """Write zeros for bias when converting MPT to FasterTransformer weights. diff --git a/scripts/inference/convert_composer_mpt_to_ft.py b/scripts/inference/convert_composer_mpt_to_ft.py index 79275030b3..f59eb6005a 100644 --- a/scripts/inference/convert_composer_mpt_to_ft.py +++ b/scripts/inference/convert_composer_mpt_to_ft.py @@ -67,6 +67,7 @@ def write_ft_checkpoint_from_composer_checkpoint( checkpoint_path: Union[Path, str], infer_gpu_num: int, save_dir: str, + trust_remote_code: bool, output_precision: str = 'fp32', local_checkpoint_save_location: Optional[Union[Path, str]] = None) -> None: @@ -79,6 +80,7 @@ def write_ft_checkpoint_from_composer_checkpoint( checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend supported by Composer. infer_gpu_num (int): The number of gpus you are planning to use for inference. + trust_remote_code (bool): Whether or not to use code outside of the transformers module. 
save_dir (str): Path of the directory to save the checkpoint in FT format. output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``. local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally. @@ -125,7 +127,7 @@ def write_ft_checkpoint_from_composer_checkpoint( print('#' * 30) print('Extracting HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict) + composer_state_dict, trust_remote_code) if hf_tokenizer is None: print('Warning! No HF Tokenizer found!') @@ -206,6 +208,10 @@ def parse_args() -> Namespace: 'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.', choices=['fp32', 'fp16'], default='fp32') + parser.add_argument( + '--trust_remote_code', + action='store_true', + help='Whether or not to use code outside of transformers module.') return parser.parse_args() @@ -229,4 +235,5 @@ def parse_args() -> Namespace: infer_gpu_num=args.infer_gpu_num, save_dir=save_dir, output_precision=args.output_precision, - local_checkpoint_save_location=args.local_checkpoint_save_location) + local_checkpoint_save_location=args.local_checkpoint_save_location, + trust_remote_code=args.trust_remote_code) diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 5625a3b046..1b43762473 100644 --- a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -16,6 +16,7 @@ from llmfoundry import MPTConfig, MPTForCausalLM from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict +from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility @@ -23,6 +24,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path: Union[Path, str], output_path: Union[Path, str], + trust_remote_code: bool, output_precision: str = 'fp32', local_checkpoint_save_location: Optional[Union[Path, str]] = None ) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]: @@ -63,6 +65,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend supported by :meth:`composer.utils.maybe_create_object_store_from_uri`. output_path (Union[Path, str]): Path to the folder to write the output to. + trust_remote_code (bool): Whether or not to use code outside of the transformers module. output_precision (str, optional): The precision of the output weights saved to `pytorch_model.bin`. Can be one of ``fp32``, ``fp16``, or ``bf16``. local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally. If the input ``checkpoint_path`` is already a local path, this will be a symlink. 
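# Usage sketch (illustrative only; the directory path is a placeholder): the
# `load_tokenizer` helper added above in
# llmfoundry/utils/checkpoint_conversion_helpers.py wraps
# AutoTokenizer.from_pretrained and raises a clearer error when a tokenizer
# defined outside the transformers module (e.g. TiktokenTokenizerWrapper) is
# loaded without trust_remote_code, mirroring the new --trust_remote_code flag.
from llmfoundry.utils.checkpoint_conversion_helpers import load_tokenizer

hf_tokenizer = load_tokenizer('/tmp/converted_hf_checkpoint',
                              trust_remote_code=True)
hf_tokenizer.save_pretrained('/tmp/converted_hf_checkpoint')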
@@ -110,7 +113,7 @@ def write_huggingface_pretrained_from_composer_checkpoint( print('#' * 30) print('Saving HF Tokenizer...') hf_tokenizer = get_hf_tokenizer_from_composer_state_dict( - composer_state_dict) + composer_state_dict, trust_remote_code) if hf_tokenizer is not None: hf_tokenizer.save_pretrained(output_path) print(hf_tokenizer) @@ -157,6 +160,10 @@ def parse_args() -> Namespace: default='fp32') parser.add_argument('--hf_repo_for_upload', type=str, default=None) parser.add_argument('--test_uploaded_model', action='store_true') + parser.add_argument( + '--trust_remote_code', + action='store_true', + help='Whether or not to use code outside of transformers module.') return parser.parse_args() @@ -179,6 +186,7 @@ def convert_composer_to_hf(args: Namespace) -> None: config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path=args.composer_path, output_path=local_folder_path, + trust_remote_code=args.trust_remote_code, output_precision=args.output_precision, local_checkpoint_save_location=args.local_checkpoint_save_location) @@ -206,7 +214,9 @@ def convert_composer_to_hf(args: Namespace) -> None: loaded_hf_model.save_pretrained(local_folder_path) print(f'Loading tokenizer from {local_folder_path}') - tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path) + + tokenizer = load_tokenizer(local_folder_path, + trust_remote_code=args.trust_remote_code) tokenizer.save_pretrained(local_folder_path) # Only need to edit files for MPT because it has custom code diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..6d5a282993 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -530,6 +530,7 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) @@ -577,6 +578,7 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) @@ -631,6 +633,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path): output_precision='fp32', local_checkpoint_save_location=None, hf_repo_for_upload=None, + trust_remote_code=False, test_uploaded_model=False) convert_composer_to_hf(args) From 789917883f58578df34a62a4895341728098d2be Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Mon, 13 Nov 2023 13:25:44 -0800 Subject: [PATCH 02/11] Enable `tie_word_embeddings` config setting to enable / disable weight tied embeddings (#728) * enable disabling embed weight tying * fix bug * updt with descriptive var names * fix hf config * move comment with code * bug fix * add _tie_weights method * undo mcli yaml change * refactor * add tests * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Sasha Doubov * pr comments * updt tests to guard against numerical issues --------- Co-authored-by: Sasha Doubov --- llmfoundry/models/mpt/configuration_mpt.py | 8 ++- llmfoundry/models/mpt/modeling_mpt.py | 72 ++++++++++++++++------ tests/test_hf_conversion_script.py | 41 ++++++++---- tests/test_model.py | 72 ++++++++++++++++------ tests/test_mpt_gen.py | 31 +++++++--- tests/test_onnx.py | 5 +- 6 files changed, 169 insertions(+), 60 deletions(-) diff --git 
a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index c4ca68d733..c0a1e65248 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -59,6 +59,7 @@ def __init__( use_cache: bool = False, init_config: Dict = init_config_defaults, fc_type: str = 'torch', + tie_word_embeddings: bool = True, verbose: Optional[int] = None, **kwargs: Any, ): @@ -128,6 +129,7 @@ def __init__( --- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. + tie_word_embeddings (bool): Whether to tie the input embedding and output layers. """ self.d_model = d_model self.n_heads = n_heads @@ -164,7 +166,11 @@ def __init__( warnings.warn( f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' ) - super().__init__(**kwargs) + # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__ + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self._validate_config() diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 0cb3ebd56c..10c042d27c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -231,10 +231,11 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') - def get_input_embeddings(self) -> nn.Embedding: + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: return self.wte - def set_input_embeddings(self, value: nn.Embedding) -> None: + def set_input_embeddings( + self, value: Union[SharedEmbedding, nn.Embedding]) -> None: self.wte = value @torch.no_grad() @@ -574,14 +575,20 @@ class MPTForCausalLM(MPTPreTrainedModel): def __init__(self, config: MPTConfig): super().__init__(config) - if not config.tie_word_embeddings: - raise ValueError( - 'MPTForCausalLM only supports tied word embeddings') - log.info(f'Instantiating an MPTForCausalLM model from {__file__}') self.transformer: MPTModel = MPTModel(config) + self.lm_head = None + if not config.tie_word_embeddings: + self.lm_head = nn.Linear( + config.d_model, + config.vocab_size, + bias=False, + device=config.init_device, + ) + self.lm_head._fsdp_wrap = True + for child in self.transformer.children(): if isinstance(child, torch.nn.ModuleList): continue @@ -602,19 +609,38 @@ def __init__(self, config: MPTConfig): ) self.logit_scale = logit_scale - def get_input_embeddings(self) -> nn.Embedding: - return self.transformer.wte + def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]: + return self.transformer.get_input_embeddings() def set_input_embeddings( self, value: Union[SharedEmbedding, nn.Embedding]) -> None: - self.transformer.wte = value + self.transformer.set_input_embeddings(value) - def get_output_embeddings(self) -> nn.Embedding: - return self.transformer.wte + def get_output_embeddings( + self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]: + if self.lm_head is not None: + return self.lm_head + return self.transformer.get_input_embeddings() def set_output_embeddings( - self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None: - self.transformer.wte = new_embeddings + self, new_embeddings: Union[SharedEmbedding, nn.Embedding, + nn.Linear]) -> None: + if self.lm_head is not None: + self.lm_head = new_embeddings + else: + if not isinstance(new_embeddings, 
(SharedEmbedding, nn.Embedding)): + raise ValueError( + 'new_embeddings must be an instance of SharedEmbedding ' + + f'or nn.Embedding, but got {type(new_embeddings)}.') + warnings.warn( + 'Using `set_output_embeddings` to set the embedding layer of ' + + 'MPTForCausalLM with tied weights. Given weights are tied, ' + + 'using `set_input_embeddings` is recommended over using ' + + '`set_output_embeddings`.') + self.transformer.set_input_embeddings(new_embeddings) + + def tie_weights(self) -> None: + self.lm_head = None def set_decoder(self, decoder: MPTModel) -> None: self.transformer = decoder @@ -658,12 +684,14 @@ def forward( use_cache=use_cache, ) - # move outputs to same device as weights for token embedding - # needed to support HF `device_map` - logits = self.transformer.wte( - outputs.last_hidden_state.to(self.transformer.wte.weight.device), - True, - ) + if self.lm_head is not None: + logits = self.lm_head(outputs.last_hidden_state) + else: + # move outputs to same device as weights for token embedding + # needed to support HF `device_map` + out = outputs.last_hidden_state + out = out.to(self.transformer.wte.weight.device) + logits = self.transformer.wte(out, True) if self.logit_scale is not None: if self.logit_scale == 0: @@ -859,7 +887,11 @@ def flops_per_batch(self, batch: Mapping) -> int: # assume the backward pass is approximately 2x the forward pass bs, msl = batch['input_ids'].shape[0:2] - params_flops_per_token = 2 * self.n_active_params + params = self.n_active_params + if not self.model.transformer.config.tie_word_embeddings: + # embedding layers are lookup tables, therefore are not counted in the FLOP computation + params -= self.model.transformer.wte.weight.numel() + params_flops_per_token = 2 * params params_flops_per_seq = params_flops_per_token * msl attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 * (self.model.config.d_model * (msl**2))) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 6d5a282993..af94126225 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -248,20 +248,21 @@ def test_callback_inits_with_defaults(): @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) +@pytest.mark.parametrize( + 'model,tie_word_embeddings', + [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)], +) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize('log_to_mlflow', [True, False]) @pytest.mark.parametrize( 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)]) @patch('os.cpu_count', MagicMock(return_value=None)) -def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, - fsdp_state_dict_type: Optional[str], - log_to_mlflow: bool, - hf_save_interval: str, - save_interval: str, max_duration: str, - expected_hf_checkpoints: int, - expected_normal_checkpoints: int): +def test_huggingface_conversion_callback( + model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool, + fsdp_state_dict_type: Optional[str], log_to_mlflow: bool, + hf_save_interval: str, save_interval: str, max_duration: str, + expected_hf_checkpoints: int, expected_normal_checkpoints: int): delete_transformers_cache() dist.initialize_dist(get_device('gpu')) @@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, 'attn_impl': 'torch', }, 'loss_fn': 
'torch_crossentropy', + 'tie_word_embeddings': tie_word_embeddings, } tokenizer_name = 'EleutherAI/gpt-neox-20b' elif model == 'neo': + assert tie_word_embeddings is None model_cfg = { 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M', @@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, } tokenizer_name = 'EleutherAI/gpt-neo-125M' elif model == 'llama2': + assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.' @@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, delete_transformers_cache() -@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) -def test_convert_and_generate(model: str, tmp_path: pathlib.Path): +@pytest.mark.parametrize( + 'model,tie_word_embeddings', + [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)], +) +def test_convert_and_generate(model: str, tie_word_embeddings: bool, + tmp_path: pathlib.Path): delete_transformers_cache() om_cfg = None if model == 'mpt': om_cfg = get_config( conf_path='scripts/train/yamls/pretrain/testing.yaml') + om_cfg['tie_word_embeddings'] = tie_word_embeddings elif model == 'neo': + assert tie_word_embeddings is None om_cfg = get_config( conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') om_cfg['model']['config_overrides']['hidden_size'] = 36 elif model == 'llama2': + assert tie_word_embeddings is None if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: pytest.skip( 'The CI cluster does not have access to the Llama models, so skip this test.' @@ -562,11 +573,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path): @pytest.mark.gpu -def test_convert_and_generate_triton(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_convert_and_generate_triton(tie_word_embeddings: str, + tmp_path: pathlib.Path): delete_transformers_cache() cfg = get_config() cfg['model']['init_device'] = 'cpu' + cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( 'EleutherAI/gpt-neox-20b') model = ComposerMPTCausalLM(cfg['model'], tokenizer) @@ -602,7 +616,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path): delete_transformers_cache() -def test_convert_and_generate_meta(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_convert_and_generate_meta(tie_word_embeddings: str, + tmp_path: pathlib.Path): delete_transformers_cache() from composer.utils import dist @@ -612,6 +628,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path): om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') om_cfg['model']['init_device'] = 'cpu' + om_cfg['tie_word_embeddings'] = tie_word_embeddings tokenizer = transformers.AutoTokenizer.from_pretrained( om_cfg.tokenizer.name) original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name]( diff --git a/tests/test_model.py b/tests/test_model.py index 41b62f0ccf..3308c65fd3 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -466,7 +466,8 @@ def test_opt_wrapping(): @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys()) @pytest.mark.parametrize('no_bias', [False, True]) -def test_mpt_creation(norm_type: str, no_bias: bool): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): # 
Test that the config constructs the model as expected. hf_config = MPTConfig( init_device='cpu', @@ -482,6 +483,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool): }, norm_type=norm_type, no_bias=no_bias, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) @@ -493,6 +495,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool): assert mpt.transformer.wte.weight.shape == torch.Size( [hf_config.vocab_size, hf_config.d_model]) + if not tie_word_embeddings: + assert mpt.lm_head is not None + assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape assert mpt.transformer.wpe.weight.shape == torch.Size( [hf_config.max_seq_len, hf_config.d_model]) assert mpt.transformer.emb_drop.p == 0.1 @@ -544,8 +549,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_padding(attention_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, tie_word_embeddings: bool): # Test that different placement of padding does not affect the output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -580,6 +586,7 @@ def test_forward_with_padding(attention_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.eval() @@ -736,10 +743,13 @@ def test_advanced_mask_building(attention_impl: str): assert torch.equal(attn_bias, expected_attn_bias) -@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'), - ('flash', 'gpu'), - ('triton', 'gpu'), - ('torch', 'gpu')]) +@pytest.mark.parametrize('attention_impl,device,precision', [ + ('torch', 'cpu', 'fp32'), + ('flash', 'gpu', 'amp_bf16'), + ('triton', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'amp_bf16'), + ('torch', 'gpu', 'fp32'), +]) @pytest.mark.parametrize('pos_emb_config', [{ 'alibi': False, 'rope': False @@ -766,7 +776,9 @@ def test_advanced_mask_building(attention_impl: str): 'factor': 1.0, }, }]) -def test_generate(attention_impl: str, device: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_generate(attention_impl: str, device: str, precision: str, + pos_emb_config: dict, tie_word_embeddings: bool): # Test that generate works, and produces the same output with or without # padding in the input. 
if not torch.cuda.is_available() and device == 'gpu': @@ -780,6 +792,8 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): device != 'gpu' or not is_flash_v2_installed()): pytest.skip( f'dail implementation of rope requires gpu and flash attention 2.') + if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False: + pytest.skip(f'This test configuration has precision / sampling issues.') composer_device = get_device(device) @@ -796,10 +810,11 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): 'attn_impl': attention_impl, **pos_emb_config, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) - mpt.eval() mpt = composer_device.module_to_device(mpt) + mpt.eval() # padding on the left of the input left_padding_input_ids = torch.tensor( @@ -830,8 +845,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): batched_attention_mask = composer_device.tensor_to_device( batched_attention_mask) - with get_precision_context('amp_bf16' if composer_device.name == - 'gpu' else 'fp32'): + with get_precision_context(precision): # check that a batch with different amounts of padding doesn't crash # and produces the right output shape batched_generation = mpt.generate(input_ids=batched_input_ids, @@ -861,8 +875,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict): @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('use_cache', [False, True]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, - use_cache: bool): + use_cache: bool, tie_word_embeddings: bool): if not torch.cuda.is_available(): pytest.skip(f'This test requires CUDA to be available.') if not torch.cuda.device_count() >= world_size: @@ -882,6 +897,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, 'attn_impl': 'torch', }, use_cache=use_cache, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.save_pretrained(save_path) @@ -994,8 +1010,10 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_cache_and_padding(attn_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, + tie_word_embeddings: bool): # Tests that the result is the same with or without padding when using kv caching if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1028,6 +1046,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) @@ -1133,7 +1152,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str, 'factor': 1.0, }, }]) -def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict, + tie_word_embeddings: bool): # Test that model forward with and without the key-value cache produces the # same output. 
if not torch.cuda.is_available() and device == 'gpu': @@ -1168,6 +1189,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1237,7 +1259,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): torch.testing.assert_close( second_output.logits, full_output.logits[:, -1, :].unsqueeze(1), - atol=1e-2, + atol=1.1e-2, rtol=1e-2, ) @@ -1274,8 +1296,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generate_with_past_kv(attn_impl: str, device: str, - pos_emb_config: dict): + pos_emb_config: dict, tie_word_embeddings: bool): if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' @@ -1307,6 +1330,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1325,7 +1349,8 @@ def test_generate_with_past_kv(attn_impl: str, device: str, with mock.patch.object(MPTForCausalLM, 'forward', autospec=True) as forward_mocked: forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), + logits=composer_device.tensor_to_device( + torch.randn((1, 3, hf_config.vocab_size))), past_key_values=[(torch.randn(1, 3, hf_config.d_model), torch.randn(1, 3, hf_config.d_model)) for _ in range(hf_config.n_layers)]) @@ -1386,9 +1411,11 @@ def test_generate_with_past_kv(attn_impl: str, device: str, 'factor': 1.0, }, }]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_generation_kwargs_dont_crash(attn_impl: str, device: str, generation_kwargs: Dict[str, Any], - pos_emb_config: dict): + pos_emb_config: dict, + tie_word_embeddings: bool): if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' 
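# Behavior sketch (illustrative only; the tiny hyperparameters are placeholders
# chosen to keep the example small): with the new `tie_word_embeddings` flag set
# to False, MPTForCausalLM builds a separate `lm_head` projection instead of
# reusing the input embedding, which is what the parametrized tests above
# exercise.
from llmfoundry import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=16,
    vocab_size=50368,
    attn_config={'attn_impl': 'torch'},
    tie_word_embeddings=False,
)
model = MPTForCausalLM(config)
assert model.lm_head is not None  # untied: dedicated output projection
assert model.get_output_embeddings() is model.lm_head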
@@ -1417,6 +1444,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, **pos_emb_config, }, use_cache=True, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) @@ -1467,7 +1495,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str, 'factor': 1.0, }, }]) -def test_model_to(attention_impl: str, pos_emb_config: dict): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_model_to(attention_impl: str, pos_emb_config: dict, + tie_word_embeddings: bool): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( @@ -1498,6 +1528,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict): 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = mpt.bfloat16() @@ -1600,9 +1631,11 @@ def test_alibi_vs_hf(): }]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( attn_impl: str, device: str, pos_emb_config: dict, - output_attentions: bool, output_hidden_states: bool): + output_attentions: bool, output_hidden_states: bool, + tie_word_embeddings: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1639,6 +1672,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( 'name': 'baseline_', 'init_std': 0.02, }, + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt = composer_device.module_to_device(mpt) diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py index c52b765480..413e39bf8c 100644 --- a/tests/test_mpt_gen.py +++ b/tests/test_mpt_gen.py @@ -55,9 +55,11 @@ def forward( @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['triton', 'torch']) @pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) @patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM', new=MockMPTForCausalLM) def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, + tie_word_embeddings: bool, build_tiny_mpt: Callable[..., ComposerMPTCausalLM], mpt_tokenizer: PreTrainedTokenizerBase): @@ -67,11 +69,14 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, """ device = get_device('gpu') - model = build_tiny_mpt(attn_config={ - 'attn_impl': attn_impl, - 'attn_uses_sequence_id': False, - 'alibi': use_alibi - },) + model = build_tiny_mpt( + tie_word_embeddings=tie_word_embeddings, + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + 'alibi': use_alibi + }, + ) model = device.module_to_device(model) model.eval() @@ -88,13 +93,25 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool, @pytest.mark.gpu -def test_mpt_generate_callback(build_tiny_mpt: Callable[..., +@pytest.mark.parametrize('attn_impl', ['triton', 'torch']) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_mpt_generate_callback(attn_impl: str, use_alibi: bool, + tie_word_embeddings: bool, + build_tiny_mpt: Callable[..., ComposerMPTCausalLM], tiny_ft_dataloader: DataLoader): device = get_device('gpu') # build mpt model - model = build_tiny_mpt() + model = 
build_tiny_mpt( + tie_word_embeddings=tie_word_embeddings, + attn_config={ + 'attn_impl': attn_impl, + 'attn_uses_sequence_id': False, + 'alibi': use_alibi + }, + ) model = device.module_to_device(model) # generate callback diff --git a/tests/test_onnx.py b/tests/test_onnx.py index d0e01746eb..becd3c773f 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -3,6 +3,7 @@ import pathlib +import pytest import torch from transformers import AutoModelForCausalLM @@ -25,7 +26,8 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int): return batch -def test_onnx_export(tmp_path: pathlib.Path): +@pytest.mark.parametrize('tie_word_embeddings', [True, False]) +def test_onnx_export(tie_word_embeddings: bool, tmp_path: pathlib.Path): from transformers.models.auto.configuration_auto import CONFIG_MAPPING CONFIG_MAPPING._extra_content['mpt'] = MPTConfig AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM) @@ -48,6 +50,7 @@ def test_onnx_export(tmp_path: pathlib.Path): use_cache=True, vocab_size=vocab_size, norm_type='layernorm', + tie_word_embeddings=tie_word_embeddings, ) mpt = MPTForCausalLM(hf_config) mpt.eval() From 8ba697cec6560fa8adaddc779b6d3ed2ff4adb36 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Mon, 13 Nov 2023 14:13:23 -0800 Subject: [PATCH 03/11] add act checkpoint at sub layer level (#720) * add act checkpoint at sub layer level * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Mihir Patel * address comments * addess coments * add log info * fix pyright * refactor * better log info and error msg * add test * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Mihir Patel * remove unneeded comments --------- Co-authored-by: Mihir Patel Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/models/mpt/modeling_mpt.py | 34 ++++++++++++- tests/test_fsdp_act_checkpoint.py | 73 +++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 tests/test_fsdp_act_checkpoint.py diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 10c042d27c..274c1b76e5 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -45,7 +45,9 @@ from transformers.models.llama.modeling_llama import \ LlamaRotaryEmbedding as HFRotaryEmbedding -from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias +from llmfoundry.models.layers.attention import (ATTN_CLASS_REGISTRY, + attn_bias_shape, + build_attn_bias) from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY @@ -733,7 +735,35 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: # Activation Checkpointing def activation_checkpointing_fn(self, module: nn.Module) -> bool: - return isinstance(module, MPTBlock) + act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', + None) or ['MPTBlock'] + + if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: + if len(act_ckpt_list) > 1: + log.info( + 'Activation checkpointing MPTBlock only (ignoring other sub-block modules specified in activation_checkpointing_target).' 
+ ) + return isinstance(module, MPTBlock) + + mod_types = () + for mod_name in act_ckpt_list: + if mod_name.lower() == 'mptblock': + mod_types += (MPTBlock,) + elif mod_name in ATTN_CLASS_REGISTRY: + mod_types += (ATTN_CLASS_REGISTRY[mod_name],) + elif mod_name in FFN_CLASS_REGISTRY: + mod_types += (FFN_CLASS_REGISTRY[mod_name],) + elif mod_name in NORM_CLASS_REGISTRY: + mod_types += (NORM_CLASS_REGISTRY[mod_name],) + else: + msg = ', '.join( + list(ATTN_CLASS_REGISTRY.keys()) + + list(FFN_CLASS_REGISTRY.keys()) + + list(NORM_CLASS_REGISTRY.keys()) + ['MPTBlock']) + raise ValueError( + f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.' + ) + return isinstance(module, mod_types) def prepare_inputs_for_generation( self, diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py new file mode 100644 index 0000000000..1a46fcbccd --- /dev/null +++ b/tests/test_fsdp_act_checkpoint.py @@ -0,0 +1,73 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from composer import Trainer +from composer.utils import get_device +from omegaconf import OmegaConf as om +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \ + CheckpointWrapper + +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM + + +@pytest.mark.world_size(2) +@pytest.mark.gpu +@pytest.mark.parametrize('activation_checkpointing', [True, False]) +@pytest.mark.parametrize( + 'activation_checkpointing_target', + [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']]) +def test_fsdp_act_checkpoint(activation_checkpointing: bool, + activation_checkpointing_target: list): + device = get_device('gpu') + model_cfg = { + 'name': 'mpt_causal_lm', + 'd_model': 128, + 'n_heads': 4, + 'n_layers': 2, + 'expansion_ratio': 1, + 'max_seq_len': 16, + 'vocab_size': 50368, + 'attn_config': { + 'attn_type': 'grouped_query_attention', + 'kv_n_heads': 2, + }, + 'activation_checkpointing_target': activation_checkpointing_target + } + model_cfg = om.create(model_cfg) + + fsdp_config = { + 'activation_checkpointing': activation_checkpointing, + 'activation_checkpointing_reentrant': False, + 'activation_cpu_offload': False, + } + + model = ComposerMPTCausalLM(model_cfg) + model = device.module_to_device(model) + + trainer = Trainer( + model=model, + device='gpu', + fsdp_config=fsdp_config, + ) + + assert trainer.state.fsdp_enabled + if not activation_checkpointing: + assert not isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0], CheckpointWrapper) + elif (not activation_checkpointing_target + ) or activation_checkpointing_target == [ + 'mptblock', 'grouped_query_attention' + ]: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. + blocks[0]._fsdp_wrapped_module, CheckpointWrapper) + elif activation_checkpointing_target == ['grouped_query_attention']: + assert isinstance( + trainer.state.model.model._fsdp_wrapped_module.transformer. 
+ blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) + else: + raise ValueError( + f'Unknown activation_checkpointing_target: {activation_checkpointing_target}' + ) From d1960f2ca842397bcb39d1bd13139b363c21641e Mon Sep 17 00:00:00 2001 From: snarayan21 Date: Mon, 13 Nov 2023 14:35:41 -0800 Subject: [PATCH 04/11] Better defaults for StreamingDataset subclasses (#723) --- llmfoundry/data/denoising.py | 6 +++--- llmfoundry/data/finetuning/dataloader.py | 6 +++--- llmfoundry/data/finetuning/tasks.py | 22 +++++++++++--------- llmfoundry/data/text_data.py | 26 ++++++++++++++---------- setup.py | 2 +- 5 files changed, 34 insertions(+), 28 deletions(-) diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index 7d497b4efd..8ccf7f25e9 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -477,13 +477,13 @@ def build_text_denoising_dataloader( remote=cfg.dataset.get('remote'), split=cfg.dataset.get('split'), shuffle=cfg.dataset.get('shuffle', False), - predownload=cfg.dataset.get('predownload', 100_000), + predownload=cfg.dataset.get('predownload', None), keep_zip=cfg.dataset.get('keep_zip', False), download_retry=cfg.dataset.get('download_retry', 2), download_timeout=cfg.dataset.get('download_timeout', 60), - validate_hash=cfg.dataset.get('validate_hash'), + validate_hash=cfg.dataset.get('validate_hash', None), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), - num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', 128), + num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, ) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 44d6d345f5..b19cab841f 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -136,13 +136,13 @@ def build_finetuning_dataloader(cfg: DictConfig, epoch_size=cfg.dataset.get('epoch_size', None), predownload=cfg.dataset.get('predownload', None), cache_limit=cfg.dataset.get('cache_limit', None), - partition_algo=cfg.dataset.get('partition_algo', 'orig'), + partition_algo=cfg.dataset.get('partition_algo', 'relaxed'), num_canonical_nodes=cfg.dataset.get('num_canonical_nodes', None), batch_size=device_batch_size, shuffle=cfg.dataset.get('shuffle', False), - shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1b'), + shuffle_algo=cfg.dataset.get('shuffle_algo', 'py1e'), shuffle_seed=cfg.dataset.get('shuffle_seed', 9176), - shuffle_block_size=cfg.dataset.get('shuffle_block_size', 1 << 18), + shuffle_block_size=cfg.dataset.get('shuffle_block_size', None), sampling_method=cfg.dataset.get('sampling_method', 'balanced'), sampling_granularity=cfg.dataset.get('sampling_granularity', 1), batching_method=cfg.dataset.get('batching_method', 'random'), diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 6ba6ad96c8..bc712a7504 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -88,12 +88,12 @@ class StreamingFinetuningDataset(StreamingDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to `False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all + epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples. 
Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Defaults to ``None``. predownload (int, optional): Target number of samples ahead to download the shards of while - iterating. Defaults to ``100_000``. + iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None @@ -101,15 +101,17 @@ class StreamingFinetuningDataset(StreamingDataset): bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None. partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. Defaults to ``None``, which is interpreted as the number of nodes of the - initial run. + resumption. If ``None``, this is interpreted as 64 times the number of physical + nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the + number of physical nodes of the initial run otherwise. Defaults to ``None``. batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to ``None``. shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + shuffle_block_size (int): Unit of shuffle. If ``None``, its value is calculated as + ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to ``None``. sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``. sampling_granularity (int): When picking samples for a stream's final partial repeat, @@ -129,16 +131,16 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[int] = None, + epoch_size: Optional[Union[int, str]] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'orig', + partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, - shuffle_algo: str = 'py1b', + shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, + shuffle_block_size: Optional[int] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 93af2f63ed..51fd6b38dc 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -46,12 +46,12 @@ class StreamingTextDataset(StreamingDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to `False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all + epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. 
If ``None``, takes its value from the total number of underlying samples. Provide this field if you are weighting streams relatively to target a larger or smaller epoch size. Defaults to ``None``. predownload (int, optional): Target number of samples ahead to download the shards of while - iterating. Defaults to ``100_000``. + iterating. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. cache_limit (Union[int, str], optional) - Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to None @@ -59,15 +59,19 @@ class StreamingTextDataset(StreamingDataset): bytes (e.g., 100b, 64kb, 77mb, and so on). Defaults to None. partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. Defaults to ``None``, which is interpreted as the number of nodes of the - initial run. + resumption. If ``None``, this is interpreted as 64 times the number of physical + nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the + number of physical nodes of the initial run otherwise. Defaults to ``None``. batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to ``None``. shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1b``. + shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``. shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split + into blocks of this size, and samples within each block are shuffled. If ``None``, its + value is calculated as ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to + ``None``. sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``. Defaults to ``balanced``. 
sampling_granularity (int): When picking samples for a stream's final partial repeat, @@ -89,16 +93,16 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: int = 100_000, + epoch_size: Optional[Union[int, str]] = None, + predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, - partition_algo: str = 'orig', + partition_algo: str = 'relaxed', num_canonical_nodes: Optional[int] = None, batch_size: Optional[int] = None, shuffle: bool = False, - shuffle_algo: str = 'py1b', + shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, + shuffle_block_size: Optional[int] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, batching_method: str = 'random', diff --git a/setup.py b/setup.py index 81178686d2..05d1d1bbbe 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.34.1,<4.35', - 'mosaicml-streaming>=0.6,<0.7', + 'mosaicml-streaming>=0.7.1,<0.8', 'torch>=1.13.1,<2.1.1', 'datasets>=2.14.5,<2.15', 'fsspec==2023.6.0', # newer version results in a bug in datasets that duplicates data From 753feff96801a8959d22477b7857422076f6b4dc Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Mon, 13 Nov 2023 22:34:18 -0800 Subject: [PATCH 05/11] Rename log message (#734) --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index e02bf03693..788a8943b1 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -204,7 +204,7 @@ def _save_checkpoint(self, state: State, logger: Logger): state_dict[k] = v.to(dtype=self.dtype) if dist.get_global_rank() == 0: - log.debug('Saving Hugging Face checkpoint to disk') + log.debug('Saving Hugging Face checkpoint in global rank 0') copied_config = copy.deepcopy(original_model.config) if copied_config.model_type == 'mpt': From 45113ebf4ef2ad3714c1a9b51d9cca79bcafb921 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 14 Nov 2023 10:48:56 -0800 Subject: [PATCH 06/11] remove tokenizer_name field (#735) --- scripts/inference/benchmarking/yamls/1b.yaml | 1 - scripts/inference/benchmarking/yamls/7b.yaml | 1 - scripts/train/yamls/pretrain/gpt-neo-125m.yaml | 2 -- scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml | 2 -- scripts/train/yamls/pretrain/gpt2-small.yaml | 2 -- scripts/train/yamls/pretrain/opt-3b.yaml | 2 -- 6 files changed, 10 deletions(-) diff --git a/scripts/inference/benchmarking/yamls/1b.yaml b/scripts/inference/benchmarking/yamls/1b.yaml index f94aa3d806..d1cfb3c913 100644 --- a/scripts/inference/benchmarking/yamls/1b.yaml +++ b/scripts/inference/benchmarking/yamls/1b.yaml @@ -12,7 +12,6 @@ tokenizer: model: name: mpt_causal_lm init_device: cpu - tokenizer_name: ${tokenizer_name} d_model: 2048 n_heads: 16 # Modified 24->16 so that d_head == 128 to statisfy FlashAttention n_layers: 24 diff --git a/scripts/inference/benchmarking/yamls/7b.yaml b/scripts/inference/benchmarking/yamls/7b.yaml index 55e9ae8413..f57ed2657f 100644 --- a/scripts/inference/benchmarking/yamls/7b.yaml +++ b/scripts/inference/benchmarking/yamls/7b.yaml @@ -12,7 +12,6 @@ tokenizer: model: name: mpt_causal_lm 
init_device: cpu - tokenizer_name: ${tokenizer_name} d_model: 4096 n_heads: 32 n_layers: 32 diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml index cfb447e2e4..12914e14bc 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml index fc1e3b0b7f..3da239c717 100644 --- a/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml +++ b/scripts/train/yamls/pretrain/gpt-neo-125m_eval.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/gpt2-small.yaml b/scripts/train/yamls/pretrain/gpt2-small.yaml index dde59d55b1..d40cff6e9e 100644 --- a/scripts/train/yamls/pretrain/gpt2-small.yaml +++ b/scripts/train/yamls/pretrain/gpt2-small.yaml @@ -34,7 +34,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -47,7 +46,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false diff --git a/scripts/train/yamls/pretrain/opt-3b.yaml b/scripts/train/yamls/pretrain/opt-3b.yaml index 3ac281f0ea..4423784b54 100644 --- a/scripts/train/yamls/pretrain/opt-3b.yaml +++ b/scripts/train/yamls/pretrain/opt-3b.yaml @@ -27,7 +27,6 @@ train_loader: remote: ${data_remote} split: train shuffle: true - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: true @@ -40,7 +39,6 @@ eval_loader: remote: ${data_remote} split: val shuffle: false - tokenizer_name: ${tokenizer_name} max_seq_len: ${max_seq_len} shuffle_seed: ${global_seed} drop_last: false From f114dad550d82c82fc763262fb73be62a21ba810 Mon Sep 17 00:00:00 2001 From: Sasha Doubov Date: Wed, 15 Nov 2023 08:48:39 -0800 Subject: [PATCH 07/11] Fix pairwise attention comparison in test (#737) --- tests/test_flash_triton_torch.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 3f2c229d6d..1ede36c0b5 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -74,7 +74,7 @@ def test_attn_impl(attn_impl_0: str, cfg = om.create({ 'attn_impl': 'flash', - 'd_model': 128, + 'd_model': 64, 'n_heads': 4, 'attn_pdrop': 0, 'clip_qkv': clip_qkv, @@ -88,6 +88,7 @@ def test_attn_impl(attn_impl_0: str, cfg.attn_impl = attn_impl_0 attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + cfg.attn_impl = attn_impl_1 attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) 
     attn1.load_state_dict(attn0.state_dict())
@@ -182,7 +183,15 @@ def gen_bias(attn_impl: str):
         assert p.grad is not None
         assert tp.grad is not None
         assert allclose_helper(p, tp)
-        assert allclose_helper(p.grad, tp.grad)
+
+        using_hf_rope = pos_emb_config['rope'] and pos_emb_config[
+            'rope_impl'] == 'hf'
+
+        # special case that (likely) fails due to numerics
+        if clip_qkv and qk_ln and using_hf_rope and attn_type == 'grouped_query_attention':
+            assert allclose_helper(p.grad, tp.grad, atol=2.e-2, rtol=2.e-2)
+        else:
+            assert allclose_helper(p.grad, tp.grad)
 
     assert x0.grad is not None
     assert x1.grad is not None

From db279d092befc38f8219c0a3bffb1542681c034a Mon Sep 17 00:00:00 2001
From: Wenfei Yan <87323464+wenfeiy-db@users.noreply.github.com>
Date: Wed, 15 Nov 2023 11:03:16 -0800
Subject: [PATCH 08/11] Fix passed metadata to mlflow logging (#713)

---
 llmfoundry/callbacks/hf_checkpointer.py | 14 +++++------
 llmfoundry/utils/builders.py            |  5 +++-
 tests/test_builders.py                  | 32 ++++++++++++++++++++++++
 tests/test_hf_conversion_script.py      | 33 ++++++++++++++++++++-----
 4 files changed, 70 insertions(+), 14 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 788a8943b1..c79537c781 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -74,12 +74,13 @@ def __init__(
         if self.mlflow_registered_model_name is not None:
             # Both the metadata and the task are needed in order for mlflow
             # and databricks optimized model serving to work
-            if 'metadata' not in mlflow_logging_config:
-                mlflow_logging_config['metadata'] = {
-                    'task': 'llm/v1/completions'
-                }
-            if 'task' not in mlflow_logging_config:
-                mlflow_logging_config['task'] = 'text-generation'
+            default_metadata = {'task': 'llm/v1/completions'}
+            passed_metadata = mlflow_logging_config.get('metadata', {})
+            mlflow_logging_config['metadata'] = {
+                **default_metadata,
+                **passed_metadata
+            }
+            mlflow_logging_config.setdefault('task', 'text-generation')
         self.mlflow_logging_config = mlflow_logging_config
 
         self.huggingface_folder_name_fstr = os.path.join(
@@ -93,7 +94,6 @@ def __init__(
         self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
-
         self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
             save_folder, loggers=[])
         if self.remote_ud is not None:
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 142e714b55..dedf6f5434 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet(
     return icl_evaluators, logger_keys, eval_gauntlet_cb
 
 
-def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
+def build_callback(name: str, kwargs: Union[DictConfig, Dict[str,
+                                                              Any]]) -> Callback:
     if name == 'lr_monitor':
         return LRMonitor()
     elif name == 'memory_monitor':
@@ -117,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
     elif name == 'early_stopper':
         return EarlyStopper(**kwargs)
     elif name == 'hf_checkpointer':
+        if isinstance(kwargs, DictConfig):
+            kwargs = om.to_object(kwargs)  # pyright: ignore
         return HuggingFaceCheckpointer(**kwargs)
     else:
         raise ValueError(f'Not sure how to build callback: {name}')
diff --git a/tests/test_builders.py b/tests/test_builders.py
index 0d24d2154f..237e27b52b 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -6,8 +6,10 @@
 
 import pytest
 from composer.callbacks import Generate
+from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import build_callback, build_tokenizer
 
@@ -78,3 +80,33 @@ def test_build_generate_callback_unspecified_interval():
             'foo': 'bar',
             'something': 'else',
         })
+
+
+def test_build_hf_checkpointer_callback():
+    with mock.patch.object(HuggingFaceCheckpointer,
+                           '__init__') as mock_hf_checkpointer:
+        mock_hf_checkpointer.return_value = None
+        save_folder = 'path_to_save_folder'
+        save_interval = 1
+        mlflow_logging_config_dict = {
+            'metadata': {
+                'databricks_model_family': 'MptForCausalLM',
+                'databricks_model_size_parameters': '7b',
+                'databricks_model_source': 'mosaic-fine-tuning',
+                'task': 'llm/v1/completions'
+            }
+        }
+        build_callback(name='hf_checkpointer',
+                       kwargs=om.create({
+                           'save_folder': save_folder,
+                           'save_interval': save_interval,
+                           'mlflow_logging_config': mlflow_logging_config_dict
+                       }))
+
+        assert mock_hf_checkpointer.call_count == 1
+        _, _, kwargs = mock_hf_checkpointer.mock_calls[0]
+        assert kwargs['save_folder'] == save_folder
+        assert kwargs['save_interval'] == save_interval
+        assert isinstance(kwargs['mlflow_logging_config'], dict)
+        assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict)
+        assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index af94126225..dcb743b536 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -5,7 +5,7 @@
 import os
 import pathlib
 import sys
-from unittest.mock import MagicMock, patch
+from unittest.mock import ANY, MagicMock, patch
 
 from composer import Trainer
 from composer.loggers import MLFlowLogger
@@ -242,9 +242,22 @@ def get_config(
     return cast(DictConfig, test_cfg)
 
 
-def test_callback_inits_with_defaults():
+def test_callback_inits():
+    # test with defaults
     _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba')
 
+    # test default metadata when mlflow registered name is given
+    hf_checkpointer = HuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+        mlflow_registered_model_name='test_model_name')
+    assert hf_checkpointer.mlflow_logging_config == {
+        'task': 'text-generation',
+        'metadata': {
+            'task': 'llm/v1/completions'
+        }
+    }
+
 
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
@@ -425,10 +438,18 @@ def test_huggingface_conversion_callback(
     trainer.fit()
 
     if dist.get_global_rank() == 0:
-        assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
-                                                            else 0)
-        assert mlflow_logger_mock.register_model.call_count == (
-            1 if log_to_mlflow else 0)
+        if log_to_mlflow:
+            assert mlflow_logger_mock.save_model.call_count == 1
+            mlflow_logger_mock.save_model.assert_called_with(
+                flavor='transformers',
+                transformers_model=ANY,
+                path=ANY,
+                task='text-generation',
+                metadata={'task': 'llm/v1/completions'})
+            assert mlflow_logger_mock.register_model.call_count == 1
+        else:
+            assert mlflow_logger_mock.save_model.call_count == 0
+            assert mlflow_logger_mock.register_model.call_count == 0
     else:
         assert mlflow_logger_mock.log_model.call_count == 0
         assert mlflow_logger_mock.register_model.call_count == 0

From e7962187b4397a22a0a63625b1af955a9a2424df Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Wed, 15 Nov 2023 19:09:13 -0500
Subject: [PATCH 09/11] fix script (#741)

---
 scripts/inference/hf_generate.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py
index 96592ca477..45ddc6b63e 100644
--- a/scripts/inference/hf_generate.py
+++ b/scripts/inference/hf_generate.py
@@ -217,6 +217,7 @@ def main(args: Namespace) -> None:
         if device is not None:
             print(f'Placing model on {device=}...')
             model.to(device)
+        model.to(model_dtype)
     except Exception as e:
         raise RuntimeError(
             'Unable to load HF model. ' +

From e730995c4f0dfdaf9d9d547783739eec48880edb Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 15 Nov 2023 20:53:18 -0800
Subject: [PATCH 10/11] Bump to composer 0.17 (#736)

---
 setup.py              |  6 ++---
 tests/test_mpt_gen.py | 55 ++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index 05d1d1bbbe..afdfce8d48 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@
 ]
 
 install_requires = [
-    'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
+    'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17,<0.18',
     'accelerate>=0.20,<0.21',  # for HF inference `device_map`
     'transformers>=4.34.1,<4.35',
     'mosaicml-streaming>=0.7.1,<0.8',
@@ -84,11 +84,11 @@
 ]
 
 extra_deps['databricks'] = [
-    'mosaicml[databricks]',
+    'mosaicml[databricks]>=0.17,<0.18',
 ]
 
 extra_deps['tensorboard'] = [
-    'mosaicml[tensorboard]>=0.16.1,<0.17',
+    'mosaicml[tensorboard]>=0.17,<0.18',
 ]
 
 extra_deps['gpu'] = [
diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py
index 413e39bf8c..9f022ef487 100644
--- a/tests/test_mpt_gen.py
+++ b/tests/test_mpt_gen.py
@@ -95,9 +95,7 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
 @pytest.mark.gpu
 @pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
 @pytest.mark.parametrize('use_alibi', [True, False])
-@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
-                               tie_word_embeddings: bool,
                                build_tiny_mpt: Callable[...,
                                                         ComposerMPTCausalLM],
                                tiny_ft_dataloader: DataLoader):
@@ -105,7 +103,7 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
 
     # build mpt model
     model = build_tiny_mpt(
-        tie_word_embeddings=tie_word_embeddings,
+        tie_word_embeddings=True,
         attn_config={
             'attn_impl': attn_impl,
             'attn_uses_sequence_id': False,
@@ -143,3 +141,54 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
 
     generate.generate.assert_called_once()
     trainer.logger.log_table.assert_called_once()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('use_alibi', [True, False])
+def test_mpt_generate_callback_not_tied(
+        use_alibi: bool, attn_impl: str,
+        build_tiny_mpt: Callable[..., ComposerMPTCausalLM],
+        tiny_ft_dataloader: DataLoader):
+    device = get_device('gpu')
+
+    # build mpt model
+    model = build_tiny_mpt(
+        tie_word_embeddings=False,
+        attn_config={
+            'attn_impl': attn_impl,
+            'attn_uses_sequence_id': False,
+            'alibi': use_alibi,
+        },
+    )
+    model = device.module_to_device(model)
+
+    # generate callback
+    prompts = [
+        'The best banana bread recipe is',
+        '2+2=',
+        'how much wood could a woodchuck chuck',
+    ]
+    gen_interval = 1
+    generate = ComposerGenerate(
+        prompts,
+        interval=f'{gen_interval}ba',
+        max_new_tokens=5,
+        batch_size=len(prompts),
+        use_cache=True,
+    )
+    generate.generate = Mock(wraps=generate.generate, autospec=True)
+
+    # build trainer
+    trainer = Trainer(
+        model=model,
+        train_dataloader=tiny_ft_dataloader,
+        device=device,
+        max_duration=f'{gen_interval}ba',
+        callbacks=[generate],
+    )
+    trainer.logger.log_table = Mock()
+    trainer.fit()
+
+    generate.generate.assert_called_once()
+    trainer.logger.log_table.assert_called_once()

From 25bb63f128e55477f1da2cf45d7c4118453b9206 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 17 Nov 2023 11:25:20 -0800
Subject: [PATCH 11/11] Patch os cpu count to avoid extra multiprocessing inside pytest which sometimes hangs (#745)

---
 tests/fixtures/data.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
index 39032146b6..16dd01347d 100644
--- a/tests/fixtures/data.py
+++ b/tests/fixtures/data.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from pathlib import Path
+from unittest.mock import MagicMock, patch
 
 from composer.utils import dist
 from omegaconf import DictConfig
@@ -25,6 +26,7 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
 
 
 @fixture
+@patch('os.cpu_count', MagicMock(return_value=None))
 def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
                        mpt_tokenizer: PreTrainedTokenizerBase,
                        max_seq_len: int = 128,