Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Nov 28, 2024
1 parent 8795d1f commit 8356698
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# - assisted_decoding does not support `use_cache = False`
# - assisted_decoding does not support `batch_size > 1`

set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
Expand All @@ -1132,13 +1133,15 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):

# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
set_config_for_less_flaky_test(config)

# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

config.is_decoder = True
model = model_class(config).to(torch_device).eval()
set_model_for_less_flaky_test(model)
# Sets assisted generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) the assistant model always generates two tokens when it is called, to ensure the input preparation of
Expand Down

0 comments on commit 8356698

Please sign in to comment.