From f53a5dec7b03eb195dc89c82ae761b033db1ceb6 Mon Sep 17 00:00:00 2001
From: Huazhong Ji
Date: Thu, 25 Jul 2024 17:04:04 +0800
Subject: [PATCH] remove unnecessary guard code related to pytorch versions
 1.4.2 ~ 1.7.0 (#32210)

remove unnecessary guard code related to pytorch versions 1.4.2 ~ 1.7.0
---
 src/transformers/trainer_pt_utils.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 69b547dec572fe..a3c2db27d2f7f2 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -64,12 +64,6 @@
     from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 
-# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
-try:
-    from torch.optim.lr_scheduler import SAVE_STATE_WARNING
-except ImportError:
-    SAVE_STATE_WARNING = ""
-
 
 logger = logging.get_logger(__name__)
 
@@ -251,10 +245,10 @@ def distributed_broadcast_scalars(
 
 
 def reissue_pt_warnings(caught_warnings):
-    # Reissue warnings that are not the SAVE_STATE_WARNING
+    # Reissue warnings
     if len(caught_warnings) > 1:
         for w in caught_warnings:
-            if w.category is not UserWarning or w.message != SAVE_STATE_WARNING:
+            if w.category is not UserWarning:
                 warnings.warn(w.message, w.category)
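
For context, reissue_pt_warnings is meant to be paired with warnings.catch_warnings(record=True)
around a state-dict save, which is how the Trainer checkpoints scheduler state. Below is a
minimal sketch of that pattern; the model, optimizer, scheduler, and output path are illustrative
stand-ins, and it assumes a pytorch release newer than 1.7.0, where SAVE_STATE_WARNING no longer
exists (which is what makes the removed filter dead code):

import warnings

import torch

from transformers.trainer_pt_utils import reissue_pt_warnings

# Illustrative objects; any optimizer/scheduler pair behaves the same.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Record warnings raised during the save instead of emitting them immediately.
with warnings.catch_warnings(record=True) as caught_warnings:
    torch.save(scheduler.state_dict(), "scheduler.pt")

# After this patch the helper reissues every recorded non-UserWarning as-is;
# there is no SAVE_STATE_WARNING left to special-case on supported pytorch versions.
reissue_pt_warnings(caught_warnings)

Since transformers no longer supports pytorch 1.4.2 ~ 1.7.0, the try/except import guard and the
message comparison can both be deleted without changing behavior on any supported version.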