diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6c2a8aa4251895..c7238c42df877c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1109,6 +1109,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): # - assisted_decoding does not support `use_cache = False` # - assisted_decoding does not support `batch_size > 1` + set_model_tester_for_less_flaky_test(self) for model_class in self.all_generative_model_classes: if model_class._is_stateful: self.skipTest(reason="Stateful models don't support assisted generation") @@ -1132,6 +1133,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + set_config_for_less_flaky_test(config) # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config, "use_cache"): @@ -1139,6 +1141,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): config.is_decoder = True model = model_class(config).to(torch_device).eval() + set_model_for_less_flaky_test(model) # Sets assisted generation arguments such that: # a) no EOS is generated, to ensure generation doesn't break early # b) the assistant model always generates two tokens when it is called, to ensure the input preparation of