diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 87f7ba5f00..4f4581b177 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -443,7 +443,7 @@ def forward(
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
             )
-            if use_cache:
+            if presents is not None:
                 presents += (present,)

             if output_attentions:
diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py
index 8bbbb6ef48..73b98ffb32 100644
--- a/tests/test_hf_mpt_gen.py
+++ b/tests/test_hf_mpt_gen.py
@@ -2,22 +2,21 @@
 # SPDX-License-Identifier: Apache-2.0

 from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import patch

 import pytest
+import torch
 from composer.core.precision import get_precision_context
-from composer.utils import get_device, reproducibility
+from composer.utils import dist, get_device, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from composer.utils import dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers import AutoModelForCausalLM

 from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.mpt.modeling_mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
-import torch
-from transformers import AutoModelForCausalLM
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from unittest.mock import patch
-from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM, MPTConfig


 @pytest.mark.gpu
 @pytest.mark.parametrize('device', ['cpu', 'gpu'])
@@ -79,11 +78,13 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):

 def test_init_hfhub_mpt_cpu():
     test_init_hfhub_mpt(device='cpu', attn_impl='torch')

+EOS_TOKEN_ID = 0
+

 class MockMPTForCausalLM(MPTForCausalLM):
-    """Class that overrides the forward of MPTForCausalLM.
-    """
+    """Class that overrides the forward of MPTForCausalLM."""
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -98,7 +99,10 @@ def forward(
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
-        result = super().forward(input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache, inputs_embeds)
+        result = super().forward(input_ids, past_key_values, attention_mask,
+                                 prefix_mask, sequence_id, labels, return_dict,
+                                 output_attentions, output_hidden_states,
+                                 use_cache, inputs_embeds)
         # Modify the logits to select the next token.
         if dist.get_global_rank() == 0:
             # Rank 0 hits EOS immediately.
@@ -108,17 +112,20 @@ def forward(
             result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
         return result

+
 def mock_from_config(config: MPTConfig, **_):
     config_dict = config.to_dict()
     config = MPTConfig.from_dict(config_dict)
     return MockMPTForCausalLM._from_config(config)

+
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @patch.object(AutoModelForCausalLM, 'from_config', new=mock_from_config)
 def test_mpt_generate_multi_gpu():
-    """Tests mpt generation with mutiple gpus and
-    generations of different lengths.
+    """Tests mpt generation with multiple gpus.
+
+    Each rank generates a sequence of a different length.
     """
     composer_device = get_device('gpu')
     dist.initialize_dist(composer_device)
@@ -144,20 +151,18 @@ def test_mpt_generate_multi_gpu():

     # build tokenizer
     tokenizer_name = test_cfg.tokenizer.name
-    tokenizer = build_tokenizer(tokenizer_name, {'max_seq_len': 15}) 
-    
+    tokenizer = build_tokenizer(tokenizer_name, {'max_seq_len': 15})
+
     # build model
     model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
-                                                         tokenizer) 
+                                                         tokenizer)
     model = composer_device.module_to_device(model)
     model.model = FSDP(model.model)

-    _ = model.generate(
-        composer_device.tensor_to_device(
-            tokenizer('hello', return_tensors='pt')['input_ids']),
-        max_new_tokens=10,
-        eos_token_id=EOS_TOKEN_ID,
-        use_cache=True,
-        synced_gpus=True
-    )
+    _ = model.generate(composer_device.tensor_to_device(
+        tokenizer('hello', return_tensors='pt')['input_ids']),
+                       max_new_tokens=10,
+                       eos_token_id=EOS_TOKEN_ID,
+                       use_cache=True,
+                       synced_gpus=True)