diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c7238c42df877c..5ccc8599776102 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1167,6 +1167,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): # case when some are accepted and some not is hard to reproduce, so let's hope this catches most errors :) if assistant_type == "random": assistant_model = model_class(config).to(torch_device).eval() + set_model_for_less_flaky_test(assistant_model) else: assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b)