Fix code quality
irenedea committed Oct 7, 2023
1 parent 34b75e5 commit 2ae1c94
Showing 2 changed files with 29 additions and 24 deletions.
llmfoundry/models/mpt/modeling_mpt.py (2 changes: 1 addition & 1 deletion)
@@ -443,7 +443,7 @@ def forward(
                 is_causal=self.is_causal,
                 output_attentions=bool(output_attentions),
             )
-            if use_cache:
+            if presents is not None:
                 presents += (present,)
 
             if output_attentions:
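A note on the change above: in MPT's forward, `presents` is presumably built up only when caching is requested, so the old `if use_cache:` guard and the new `is not None` check agree at runtime; the new form additionally narrows the Optional type for static checkers, which fits the commit's code-quality theme. A minimal sketch of the pattern, with a hypothetical helper name (this is not llm-foundry code):

from typing import Any, List, Optional, Tuple

def collect_presents(layer_presents: List[Any],
                     use_cache: bool) -> Optional[Tuple[Any, ...]]:
    # `presents` is a tuple only when caching is enabled, otherwise None.
    presents: Optional[Tuple[Any, ...]] = () if use_cache else None
    for present in layer_presents:
        # Equivalent to `if use_cache:` at runtime, but `is not None`
        # also narrows Optional[...] to a tuple for type checkers.
        if presents is not None:
            presents += (present,)
    return presents

assert collect_presents(['kv0', 'kv1'], use_cache=True) == ('kv0', 'kv1')
assert collect_presents(['kv0', 'kv1'], use_cache=False) is None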
tests/test_hf_mpt_gen.py (51 changes: 28 additions & 23 deletions)
@@ -2,22 +2,21 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 import pytest
+import torch
 from composer.core.precision import get_precision_context
-from composer.utils import get_device, reproducibility
+from composer.utils import dist, get_device, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from composer.utils import dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from transformers import AutoModelForCausalLM
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.mpt.modeling_mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
-import torch
-
-from transformers import AutoModelForCausalLM
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from unittest.mock import patch
-from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM, MPTConfig
 
+
 @pytest.mark.gpu
 @pytest.mark.parametrize('device', ['cpu', 'gpu'])
@@ -79,11 +78,13 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):
 def test_init_hfhub_mpt_cpu():
     test_init_hfhub_mpt(device='cpu', attn_impl='torch')
 
+
 EOS_TOKEN_ID = 0
 
+
 class MockMPTForCausalLM(MPTForCausalLM):
-    """Class that overrides the forward of MPTForCausalLM.
-    """
+    """Class that overrides the forward of MPTForCausalLM."""
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -98,7 +99,10 @@ def forward(
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
     ):
-        result = super().forward(input_ids, past_key_values, attention_mask, prefix_mask, sequence_id, labels, return_dict, output_attentions, output_hidden_states, use_cache, inputs_embeds)
+        result = super().forward(input_ids, past_key_values, attention_mask,
+                                 prefix_mask, sequence_id, labels, return_dict,
+                                 output_attentions, output_hidden_states,
+                                 use_cache, inputs_embeds)
         # Modify the logits to select the next token.
         if dist.get_global_rank() == 0:
             # Rank 0 hits EOS immediately.
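The mock's forward tampers with the logits so greedy decoding behaves differently per rank: judging by the comment above and the -torch.inf assignment in the next hunk, rank 0 gets EOS pinned as the forced choice while the other ranks can never produce it, yielding generations of different lengths. A standalone toy version of that pinning pattern (helper name hypothetical):

import torch

def pin_eos(logits: torch.Tensor, eos_id: int,
            finish_now: bool) -> torch.Tensor:
    # Make `eos_id` either the certain or an impossible argmax choice.
    logits[:, :, eos_id] = torch.inf if finish_now else -torch.inf
    return logits

logits = torch.zeros(1, 1, 8)  # (batch, seq, vocab)
assert pin_eos(logits.clone(), 0, finish_now=True).argmax(-1).item() == 0
assert pin_eos(logits.clone(), 0, finish_now=False).argmax(-1).item() != 0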
@@ -108,17 +112,20 @@
             result.logits[:, :, EOS_TOKEN_ID] = -torch.inf
         return result
 
+
 def mock_from_config(config: MPTConfig, **_):
     config_dict = config.to_dict()
     config = MPTConfig.from_dict(config_dict)
     return MockMPTForCausalLM._from_config(config)
 
+
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @patch.object(AutoModelForCausalLM, 'from_config', new=mock_from_config)
 def test_mpt_generate_multi_gpu():
-    """Tests mpt generation with mutiple gpus and
-    generations of different lengths.
+    """Tests mpt generation with mutiple gpus.
+
+    and generations of different lengths.
     """
     composer_device = get_device('gpu')
     dist.initialize_dist(composer_device)
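In the hunk above, mock_from_config round-trips the config through to_dict/from_dict and instantiates the mock subclass, and the @patch.object decorator swaps AutoModelForCausalLM.from_config for that factory, so the test's model-construction path transparently yields MockMPTForCausalLM. A self-contained sketch of the same unittest.mock patch.object pattern, with hypothetical class names:

from unittest.mock import patch

class Factory:
    @classmethod
    def from_config(cls, config: dict):
        return Product(config)

class Product:
    def __init__(self, config: dict):
        self.config = config

class MockProduct(Product):
    pass

def mock_from_config(config: dict, **_):
    return MockProduct(dict(config))  # round-trip/copy the config

# While the patch is active, every caller of Factory.from_config gets the mock.
with patch.object(Factory, 'from_config', new=mock_from_config):
    assert isinstance(Factory.from_config({'d_model': 128}), MockProduct)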
@@ -144,20 +151,18 @@
 
     # build tokenizer
     tokenizer_name = test_cfg.tokenizer.name
-    tokenizer = build_tokenizer(tokenizer_name, {'max_seq_len': 15})
+    tokenizer = build_tokenizer(tokenizer_name, {'max_seq_len': 15})
 
     # build model
     model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
-                                                          tokenizer)
+                                                         tokenizer)
     model = composer_device.module_to_device(model)
 
     model.model = FSDP(model.model)
 
-    _ = model.generate(
-        composer_device.tensor_to_device(
-            tokenizer('hello', return_tensors='pt')['input_ids']),
-        max_new_tokens=10,
-        eos_token_id=EOS_TOKEN_ID,
-        use_cache=True,
-        synced_gpus=True
-    )
+    _ = model.generate(composer_device.tensor_to_device(
+        tokenizer('hello', return_tensors='pt')['input_ids']),
+                       max_new_tokens=10,
+                       eos_token_id=EOS_TOKEN_ID,
+                       use_cache=True,
+                       synced_gpus=True)
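For context on synced_gpus=True in the call above: under FSDP, every forward pass involves collective communication across ranks, so once rank 0 hits its forced EOS after one token it must keep stepping in lockstep until the slower rank finishes, rather than returning early and leaving its peer blocked in a collective. A single-process toy model of that contract (an assumption-level illustration, not Hugging Face's actual generate implementation):

def synced_decode(stop_step_by_rank: dict, max_new_tokens: int) -> dict:
    outputs = {rank: [] for rank in stop_step_by_rank}
    for step in range(max_new_tokens):
        if all(step >= stop for stop in stop_step_by_rank.values()):
            break  # all ranks finished: safe to exit together
        for rank, stop in stop_step_by_rank.items():
            # Every rank takes part in each step (standing in for the
            # collective ops FSDP runs); finished ranks feed a dummy token.
            token = '<pad>' if step >= stop else f'tok{step}'
            outputs[rank].append(token)
    return outputs

# Rank 0 stops after 1 token (EOS forced); rank 1 generates all 10.
result = synced_decode({0: 1, 1: 10}, max_new_tokens=10)
assert len(result[0]) == len(result[1]) == 10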
