From 6d4bc9a44d98154e2a8a39b65072912f21e66328 Mon Sep 17 00:00:00 2001 From: mrhan1993 <50648276+mrhan1993@users.noreply.github.com> Date: Fri, 12 Apr 2024 16:25:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8--gpu-device-id=20=E6=8C=87?= =?UTF-8?q?=E5=AE=9AGPU=E5=90=AF=E5=8A=A8=E7=9A=84=E9=97=AE=E9=A2=98=20Fix?= =?UTF-8?q?es=20#270?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 62f6d02..dc16519 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,14 @@ sys.path.append(module_path) -print("[System ARGV] " + str(sys.argv)) +logger.std_info("[System ARGV] " + str(sys.argv)) + +try: + index = sys.argv.index('--gpu-device-id') + os.environ["CUDA_VISIBLE_DEVICES"] = str(sys.argv[index+1]) + logger.std_info(f"[Fooocus] Set device to: {str(sys.argv[index+1])}") +except ValueError: + pass os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0" @@ -74,9 +81,6 @@ def prepare_environments(args) -> bool: Args: args: command line arguments """ - if args.gpu_device_id is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_device_id) - print("Set device to:", args.gpu_device_id) if args.base_url is None or len(args.base_url.strip()) == 0: host = args.host