From 23de08d02d06e9d42c0fe936e779874ac80e0e4b Mon Sep 17 00:00:00 2001 From: ydshieh Date: Wed, 7 Feb 2024 11:31:33 +0100 Subject: [PATCH] docker --- tests/models/llama/test_inference_llama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py index 1fc8f1d3..f08ff21e 100644 --- a/tests/models/llama/test_inference_llama.py +++ b/tests/models/llama/test_inference_llama.py @@ -26,7 +26,8 @@ def test_foo(self): model.to(device) # Generate - generate_ids = model.generate(inputs.input_ids, max_length=30) + with torch.no_grad(): + generate_ids = model.generate(inputs.input_ids, max_length=30) output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] expected_output = "Hey, are you conscious? Can you talk to me?\nI'm here, but I'm not sure I'm conscious."