diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index de7a2745ca073f..753f0aa0c6df9e 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -41,6 +41,9 @@ require_torch_gpu, require_torch_sdpa, require_torchaudio, + set_config_for_less_flaky_test, + set_model_for_less_flaky_test, + set_model_tester_for_less_flaky_test, slow, torch_device, ) @@ -516,8 +519,11 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): def get_mean_reldiff(failcase, x, ref, atol, rtol): return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + set_model_tester_for_less_flaky_test(self) + for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) model = model_class(config) is_encoder_decoder = model.config.is_encoder_decoder @@ -534,6 +540,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) + set_model_for_less_flaky_test(model_eager) + set_model_for_less_flaky_test(model_sdpa) + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] @@ -1522,8 +1531,11 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): def get_mean_reldiff(failcase, x, ref, atol, rtol): return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + set_model_tester_for_less_flaky_test(self) + for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + set_config_for_less_flaky_test(config) model = model_class(config) is_encoder_decoder = model.config.is_encoder_decoder @@ -1540,6 +1552,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) + set_model_for_less_flaky_test(model_eager) + set_model_for_less_flaky_test(model_sdpa) + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = []