Skip to content

Commit

Permalink
yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
samhavens committed Dec 1, 2023
1 parent 1765822 commit ea57c8c
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def gen_random_batch(batch_size: int,
high=test_cfg.model.vocab_size,
size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
if inp == 'inputs_embeds':
batch['inputs_embeds'] = torch.randn(batch_size, test_cfg.max_seq_len,
test_cfg.model.d_model).to(
test_cfg.device)
batch['inputs_embeds'] = torch.randn(
batch_size, test_cfg.max_seq_len,
test_cfg.model.d_model).to(test_cfg.device)

batch['labels'] = torch.randint(low=0,
high=test_cfg.model.vocab_size,
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
assert not torch.equal(original_params, updated_params)


@pytest.mark.parametrize('inputs', [[], ['input_ids','inputs_embeds']])
@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']])
def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
test_cfg, model, _ = get_objs(
conf_path='scripts/train/yamls/pretrain/testing.yaml')
Expand Down Expand Up @@ -907,25 +907,25 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
# check that both/neither ids and embeds do not error
# note that we need to set the BOS token ID for generating from neither
_ = mpt.generate(input_ids=no_padding_input_ids,
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=False)
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=False)
_ = mpt.generate(input_ids=no_padding_input_ids,
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=True)
inputs_embeds=inputs_embeds,
attention_mask=no_padding_attention_mask,
max_new_tokens=5,
use_cache=True)
_ = mpt.generate(input_ids=None,
inputs_embeds=None,
max_new_tokens=5,
use_cache=False,
bos_token_id=50256)
inputs_embeds=None,
max_new_tokens=5,
use_cache=False,
bos_token_id=50256)
_ = mpt.generate(input_ids=None,
inputs_embeds=None,
max_new_tokens=5,
use_cache=True,
bos_token_id=50256)
inputs_embeds=None,
max_new_tokens=5,
use_cache=True,
bos_token_id=50256)


@pytest.mark.gpu
Expand Down

0 comments on commit ea57c8c

Please sign in to comment.