diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index edd73f29dc9814..48145979e362ff 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -218,52 +218,53 @@ def test_event_flow(self): import warnings # XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested - warnings.simplefilter(action="ignore", category=UserWarning) - - trainer = self.get_trainer(callbacks=[MyTestTrainerCallback]) - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - # Independent log/save/eval - trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5) - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5) - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps") - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch") - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - # A bit of everything - trainer = self.get_trainer( - callbacks=[MyTestTrainerCallback], - logging_steps=3, - save_steps=10, - eval_steps=5, - eval_strategy="steps", - ) - trainer.train() - events = trainer.callback_handler.callbacks[-2].events - self.assertEqual(events, self.get_expected_events(trainer)) - - # warning should be emitted for duplicated callbacks - with patch("transformers.trainer_callback.logger.warning") as warn_mock: + with warnings.catch_warnings(): + warnings.simplefilter(action="ignore", category=UserWarning) + + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback]) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + # Independent log/save/eval + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps") + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch") + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + # A bit of everything trainer = self.get_trainer( - callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], + callbacks=[MyTestTrainerCallback], + logging_steps=3, + save_steps=10, + eval_steps=5, + eval_strategy="steps", ) - assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0] + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + # warning should be emitted for duplicated callbacks + with patch("transformers.trainer_callback.logger.warning") as warn_mock: + trainer = self.get_trainer( + callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], + ) + assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0] def test_stateful_callbacks(self): # Use something with non-defaults