Skip to content

Commit

Permalink
remove unnecessary guard code related with pytorch versions 1.4.2 ~ 1…
Browse files Browse the repository at this point in the history
….7.0 (#32210)

remove unnecessary guard code related with pytorch versions 1.4.2 ~
1.7.0
  • Loading branch information
statelesshz authored Jul 25, 2024
1 parent 5658e74 commit f53a5de
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler


# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
try:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
SAVE_STATE_WARNING = ""

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -251,10 +245,10 @@ def distributed_broadcast_scalars(


def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the SAVE_STATE_WARNING
# Reissue warnings
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category is not UserWarning or w.message != SAVE_STATE_WARNING:
if w.category is not UserWarning:
warnings.warn(w.message, w.category)


Expand Down

0 comments on commit f53a5de

Please sign in to comment.