From 4d80bbc08db4c5950da94594390a73e70a0a079e Mon Sep 17 00:00:00 2001 From: Sam Havens Date: Thu, 30 Nov 2023 14:40:28 -0800 Subject: [PATCH] fix device merge artifact in test_model.oy --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 65e983f260..16ef0ad645 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -867,7 +867,7 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, no_padding_attention_mask) # inputs_embeds - inputs_embeds = torch.randn(2, 3, 128).to(device) + inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128)) # a single batch with different amounts of left padding in the input batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11],