diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9bf35147307..26e62db31c0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -819,14 +819,18 @@ def recursive_check(batched_object, single_row_object, model_name, key): config, batched_input = self.model_tester.prepare_config_and_inputs_for_common() equivalence = get_tensor_equivalence_function(batched_input) + set_model_tester_for_less_flaky_test(self) + for model_class in self.all_model_classes: config.output_hidden_states = True + set_config_for_less_flaky_test(config) model_name = model_class.__name__ if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"): config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class) batched_input_prepared = self._prepare_for_class(batched_input, model_class) model = model_class(config).to(torch_device).eval() + set_model_for_less_flaky_test(model) batch_size = self.model_tester.batch_size single_row_input = {}