diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py
index 1fc8f1d3..f08ff21e 100644
--- a/tests/models/llama/test_inference_llama.py
+++ b/tests/models/llama/test_inference_llama.py
@@ -26,7 +26,8 @@ def test_foo(self):
         model.to(device)

         # Generate
-        generate_ids = model.generate(inputs.input_ids, max_length=30)
+        with torch.no_grad():
+            generate_ids = model.generate(inputs.input_ids, max_length=30)
         output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         expected_output = "Hey, are you conscious? Can you talk to me?\nI'm here, but I'm not sure I'm conscious."
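
For context, the change wraps generation in torch.no_grad() so PyTorch does not record an autograd graph during inference, which reduces memory use and overhead. Below is a minimal standalone sketch of the same pattern; the checkpoint name is illustrative, not necessarily the one used by this test:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative checkpoint; any causal LM checkpoint works the same way.
    model_id = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()  # disable dropout and other training-time behavior

    inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")

    # No gradients are tracked inside this block, matching the diff above.
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids, max_length=30)

    print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])

On recent PyTorch versions, torch.inference_mode() is a stricter drop-in alternative to torch.no_grad() that can be slightly faster for pure inference.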