Skip to content

Commit

Permalink
Modify warnings in a with block to avoid flaky tests (huggingface…
Browse files Browse the repository at this point in the history
…#31893)

* fix

* [test_all] check before merge

---------

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Jul 10, 2024
1 parent ec03d97 commit 080e14b
Showing 1 changed file with 45 additions and 44 deletions.
89 changes: 45 additions & 44 deletions tests/trainer/test_trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,52 +218,53 @@ def test_event_flow(self):
import warnings

# XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
warnings.simplefilter(action="ignore", category=UserWarning)

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# Independent log/save/eval
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# A bit of everything
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback],
logging_steps=3,
save_steps=10,
eval_steps=5,
eval_strategy="steps",
)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# warning should be emitted for duplicated callbacks
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=UserWarning)

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# Independent log/save/eval
trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], logging_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], save_steps=5)
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_steps=5, eval_strategy="steps")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], eval_strategy="epoch")
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# A bit of everything
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
callbacks=[MyTestTrainerCallback],
logging_steps=3,
save_steps=10,
eval_steps=5,
eval_strategy="steps",
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]
trainer.train()
events = trainer.callback_handler.callbacks[-2].events
self.assertEqual(events, self.get_expected_events(trainer))

# warning should be emitted for duplicated callbacks
with patch("transformers.trainer_callback.logger.warning") as warn_mock:
trainer = self.get_trainer(
callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
)
assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]

def test_stateful_callbacks(self):
# Use something with non-defaults
Expand Down

0 comments on commit 080e14b

Please sign in to comment.