From ea57c8c064792f2bd0a088dd9038df19b5484e23 Mon Sep 17 00:00:00 2001
From: Sam Havens
Date: Thu, 30 Nov 2023 16:38:35 -0800
Subject: [PATCH] yapf

---
 tests/test_model.py | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/tests/test_model.py b/tests/test_model.py
index 66a3ff0b6f..acb2074ae9 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -110,9 +110,9 @@ def gen_random_batch(batch_size: int,
                 high=test_cfg.model.vocab_size,
                 size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
         if inp == 'inputs_embeds':
-            batch['inputs_embeds'] = torch.randn(batch_size, test_cfg.max_seq_len,
-                                                 test_cfg.model.d_model).to(
-                                                     test_cfg.device)
+            batch['inputs_embeds'] = torch.randn(
+                batch_size, test_cfg.max_seq_len,
+                test_cfg.model.d_model).to(test_cfg.device)
 
     batch['labels'] = torch.randint(low=0,
                                     high=test_cfg.model.vocab_size,
@@ -179,7 +179,7 @@ def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
     assert not torch.equal(original_params, updated_params)
 
 
-@pytest.mark.parametrize('inputs', [[], ['input_ids','inputs_embeds']])
+@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']])
 def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
     test_cfg, model, _ = get_objs(
         conf_path='scripts/train/yamls/pretrain/testing.yaml')
@@ -907,25 +907,25 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
     # check that both/neither ids and embeds do not error
     # note that we need to set the BOS token ID for generating from neither
     _ = mpt.generate(input_ids=no_padding_input_ids,
-                        inputs_embeds=inputs_embeds,
-                        attention_mask=no_padding_attention_mask,
-                        max_new_tokens=5,
-                        use_cache=False)
+                     inputs_embeds=inputs_embeds,
+                     attention_mask=no_padding_attention_mask,
+                     max_new_tokens=5,
+                     use_cache=False)
     _ = mpt.generate(input_ids=no_padding_input_ids,
-                        inputs_embeds=inputs_embeds,
-                        attention_mask=no_padding_attention_mask,
-                        max_new_tokens=5,
-                        use_cache=True)
+                     inputs_embeds=inputs_embeds,
+                     attention_mask=no_padding_attention_mask,
+                     max_new_tokens=5,
+                     use_cache=True)
     _ = mpt.generate(input_ids=None,
-                        inputs_embeds=None,
-                        max_new_tokens=5,
-                        use_cache=False,
-                        bos_token_id=50256)
+                     inputs_embeds=None,
+                     max_new_tokens=5,
+                     use_cache=False,
+                     bos_token_id=50256)
     _ = mpt.generate(input_ids=None,
-                        inputs_embeds=None,
-                        max_new_tokens=5,
-                        use_cache=True,
-                        bos_token_id=50256)
+                     inputs_embeds=None,
+                     max_new_tokens=5,
+                     use_cache=True,
+                     bos_token_id=50256)
 
 
 @pytest.mark.gpu