From 0fc2970363796c36054b5f41ffa6b6aa3906736e Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 20 Dec 2024 15:40:55 +0100 Subject: [PATCH] Use `weights_only=True` with `torch.load` for `transfo_xl` (#35241) fix Co-authored-by: ydshieh --- .../models/deprecated/transfo_xl/tokenization_transfo_xl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py index ca80636b23565d..53dec63cfc4fd8 100644 --- a/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py @@ -222,7 +222,7 @@ def __init__( "from a PyTorch pretrained vocabulary, " "or activate it with environment variables USE_TORCH=1 and USE_TF=0." ) - vocab_dict = torch.load(pretrained_vocab_file) + vocab_dict = torch.load(pretrained_vocab_file, weights_only=True) if vocab_dict is not None: for key, value in vocab_dict.items(): @@ -705,7 +705,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, # Instantiate tokenizer. corpus = cls(*inputs, **kwargs) - corpus_dict = torch.load(resolved_corpus_file) + corpus_dict = torch.load(resolved_corpus_file, weights_only=True) for key, value in corpus_dict.items(): corpus.__dict__[key] = value corpus.vocab = vocab @@ -784,7 +784,7 @@ def get_lm_corpus(datadir, dataset): fn_pickle = os.path.join(datadir, "cache.pkl") if os.path.exists(fn): logger.info("Loading cached dataset...") - corpus = torch.load(fn_pickle) + corpus = torch.load(fn_pickle, weights_only=True) elif os.path.exists(fn): logger.info("Loading cached dataset from pickle...") if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):