diff --git a/tests/models/llama/test_inference_llama.py b/tests/models/llama/test_inference_llama.py index 3b267470..3a1c9dc9 100644 --- a/tests/models/llama/test_inference_llama.py +++ b/tests/models/llama/test_inference_llama.py @@ -33,6 +33,6 @@ def test_foo(self): generate_ids = model.generate(inputs.input_ids, max_length=30) output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - expected_output = "Hey, are you conscious? Can you talk to me?\nI'm here, but I'm not sure I'm conscious." + expected_output = "Hey, are you conscious? Can you talk to me?\nI'm not sure if you can hear me, but I'm talking" assert output == expected_output