diff --git a/tests/test_model.py b/tests/test_model.py index 65e983f260..16ef0ad645 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -867,7 +867,7 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, no_padding_attention_mask) # inputs_embeds - inputs_embeds = torch.randn(2, 3, 128).to(device) + inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128)) # a single batch with different amounts of left padding in the input batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11],