Skip to content

Commit

Permalink
fix device merge artifact in test_model.oy
Browse files Browse the repository at this point in the history
  • Loading branch information
samhavens committed Nov 30, 2023
1 parent b10b106 commit 4d80bbc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
no_padding_attention_mask)

# inputs_embeds
inputs_embeds = torch.randn(2, 3, 128).to(device)
inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128))

# a single batch with different amounts of left padding in the input
batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11],
Expand Down

0 comments on commit 4d80bbc

Please sign in to comment.