From a691ccb0c224f6f76ef585535eec26456236b2e3 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:05:04 +0100 Subject: [PATCH] Change back to `Thread` for SF conversion (#35236) * fix * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/modeling_utils.py | 6 +++--- src/transformers/safetensors_conversion.py | 2 +- src/transformers/testing_utils.py | 21 +++++++++++++++++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f349847b1fd..c86559e62f9 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,7 +29,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps -from multiprocessing import Process +from threading import Thread from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile @@ -3825,11 +3825,11 @@ def from_pretrained( **has_file_kwargs, } if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - Process( + Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs}, - name="Process-auto_conversion", + name="Thread-auto_conversion", ).start() else: # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py index 5c0179350ea..f1612d3ea57 100644 --- a/src/transformers/safetensors_conversion.py +++ b/src/transformers/safetensors_conversion.py @@ -67,7 +67,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): # security breaches. pr = previous_pr(api, model_id, pr_title, token=token) - if pr is None or (not private and pr.author != "SFConvertBot"): + if pr is None or (not private and pr.author != "SFconvertbot"): spawn_conversion(token, private, model_id) pr = previous_pr(api, model_id, pr_title, token=token) else: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 30f7b5a68fb..409f274d41e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -28,6 +28,7 @@ import subprocess import sys import tempfile +import threading import time import unittest from collections import defaultdict @@ -2311,12 +2312,28 @@ class RequestCounter: def __enter__(self): self._counter = defaultdict(int) - self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) + self._thread_id = threading.get_ident() + self._extra_info = [] + + def patched_with_thread_info(func): + def wrap(*args, **kwargs): + self._extra_info.append(threading.get_ident()) + return func(*args, **kwargs) + + return wrap + + self.patcher = patch.object( + urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug) + ) self.mock = self.patcher.start() return self def __exit__(self, *args, **kwargs) -> None: - for call in self.mock.call_args_list: + assert len(self.mock.call_args_list) == len(self._extra_info) + + for thread_id, call in zip(self._extra_info, self.mock.call_args_list): + if thread_id != self._thread_id: + continue log = call.args[0] % call.args[1:] for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): if method in log: