Skip to content

Commit

Permalink
yes!
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jan 8, 2025
1 parent b05df66 commit 39d3a35
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,10 @@ def recursive_check(batched_object, single_row_object, model_name, key):
),
)

set_model_tester_for_less_flaky_test(self)

config, batched_input = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)
equivalence = get_tensor_equivalence_function(batched_input)

for model_class in self.all_model_classes:
Expand All @@ -827,6 +830,7 @@ def recursive_check(batched_object, single_row_object, model_name, key):
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
model = model_class(config).to(torch_device).eval()
set_model_for_less_flaky_test(model)

batch_size = self.model_tester.batch_size
single_row_input = {}
Expand Down

0 comments on commit 39d3a35

Please sign in to comment.