diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 1743efa6..2ca39edd 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -1,4 +1,3 @@
-
 from dataclasses import dataclass
 
 import pytest
@@ -17,6 +16,7 @@ class DecodeTestParams:
     expected_text: str
     do_sample: bool = False
     max_new_tokens: int = 20
+    top_k: int = 50
 
 
 @pytest.mark.parametrize("params",
@@ -70,7 +70,13 @@ def _test_decode_single(params):
     generator = AutoGenerator.from_pretrained(
         model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length
     )
-    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=params.do_sample)
+    request = create_request(
+        id=0,
+        inputs=input_text,
+        max_new_tokens=max_new_tokens,
+        do_sample=params.do_sample,
+        top_k=params.top_k,
+    )
     batch = Batch(id=0, requests=[request], size=1, max_tokens=params.sequence_length)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
@@ -88,6 +94,7 @@ def _test_decode_single(params):
     output = generations[0].generated_text
     assert output.generated_tokens == max_new_tokens
     assert output.finish_reason == 0
+    print(f"Generated text: {output.text}")
     if params.do_sample:
         assert output.text != params.expected_text
     else:
@@ -102,11 +109,13 @@ def _test_decode_single(params):
         DecodeTestParams(
             model_id="meta-llama/Llama-2-7b-hf",
             sequence_length=256,
             expected_text="\nWinston Smith, his chin nuzzled into his breast in an effort to escape",
+            top_k=100,
         ),
         DecodeTestParams(
             model_id="meta-llama/Meta-Llama-3-8B",
             sequence_length=256,
             expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
+            top_k=100,
         ),
     ],
     ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
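
Note on the top_k plumbing above: the diff threads params.top_k through the create_request test helper, whose definition is not part of this hunk. Below is a minimal sketch of what that helper plausibly looks like, assuming it wraps the TGI protobuf messages from generate_pb2; the module path, field names, and default values are assumptions for illustration, not taken from this diff.

# Hypothetical sketch of the create_request helper (defined elsewhere in
# the test suite); field names follow the TGI generate_pb2 protobufs, and
# the defaults shown here are assumptions.
from text_generation_server.pb import generate_pb2


def create_request(
    id: int,
    inputs: str,
    max_new_tokens: int = 20,
    do_sample: bool = False,
    top_k: int = 50,
) -> generate_pb2.Request:
    # top_k only affects generation when do_sample=True; greedy decoding
    # (do_sample=False) ignores it, which is why the deterministic
    # expected_text assertions in the test stay valid.
    parameters = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        top_k=top_k,
        top_p=1.0,
        do_sample=do_sample,
    )
    stopping_parameters = generate_pb2.StoppingCriteriaParameters(
        max_new_tokens=max_new_tokens,
    )
    return generate_pb2.Request(
        id=id,
        inputs=inputs,
        parameters=parameters,
        stopping_parameters=stopping_parameters,
    )

Under these assumptions the test's behavior matches what the diff asserts: the greedy cases still compare against the exact expected_text, while the do_sample cases only check that the sampled output differs from it.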