diff --git a/tests/models/textnet/test_modeling_textnet.py b/tests/models/textnet/test_modeling_textnet.py index b881ecfec9f650..ffcf6c2e13f4e5 100644 --- a/tests/models/textnet/test_modeling_textnet.py +++ b/tests/models/textnet/test_modeling_textnet.py @@ -324,7 +324,8 @@ def test_inference_textnet_image_classification(self): inputs = processor(images=image, return_tensors="pt").to(torch_device) # forward pass - output = model(**inputs) + with torch.no_grad(): + output = model(**inputs) # verify logits self.assertEqual(output.logits.shape, torch.Size([1, 2]))