diff --git a/.gitignore b/.gitignore
index b6e4761..a121d31 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,5 @@ dmypy.json
# Pyre type checker
.pyre/
+
+weights/*
\ No newline at end of file
diff --git a/README.md b/README.md
index e863d28..33e36d7 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,8 @@ _2[Department of Computing, The Hong Kong Polytechnic University](htt
## News
+(2023-12-16) GPEN can run on Apple Silicon GPU now by using `--use_mps`.
+
(2023-02-15) **GPEN-BFR-1024** and **GPEN-BFR-2048** are now publicly available. Please download them via \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\].
(2023-02-15) We provide online demos via \[[ModelScope1](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement/summary)\] and \[[ModelScope2](https://www.modelscope.cn/models/damo/cv_gpen_image-portrait-enhancement-hires/summary)\].
diff --git a/demo.py b/demo.py
index 3ca8de3..1f9dc20 100644
--- a/demo.py
+++ b/demo.py
@@ -78,6 +78,7 @@ def generate_mask(H, W, img=None):
parser.add_argument('--alpha', type=float, default=1, help='blending the results')
parser.add_argument('--use_sr', action='store_true', help='use sr or not')
parser.add_argument('--use_cuda', action='store_true', help='use cuda or not')
+ parser.add_argument('--use_mps', action='store_true', help='use Apple Silicon or not')
parser.add_argument('--save_face', action='store_true', help='save face or not')
parser.add_argument('--aligned', action='store_true', help='input are aligned faces or not')
parser.add_argument('--sr_model', type=str, default='realesrnet', help='SR model')
@@ -93,14 +94,21 @@ def generate_mask(H, W, img=None):
os.makedirs(args.outdir, exist_ok=True)
+ if args.use_cuda:
+ device = 'cuda'
+ elif args.use_mps:
+ device = 'mps'
+ else:
+ device = 'cpu'
+
if args.task == 'FaceEnhancement':
- processer = FaceEnhancement(args, in_size=args.in_size, model=args.model, use_sr=args.use_sr, device='cuda' if args.use_cuda else 'cpu')
+ processer = FaceEnhancement(args, in_size=args.in_size, model=args.model, use_sr=args.use_sr, device=device)
elif args.task == 'FaceColorization':
- processer = FaceColorization(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu')
+ processer = FaceColorization(in_size=args.in_size, model=args.model, device=device)
elif args.task == 'FaceInpainting':
- processer = FaceInpainting(in_size=args.in_size, model=args.model, device='cuda' if args.use_cuda else 'cpu')
+ processer = FaceInpainting(in_size=args.in_size, model=args.model, device=device)
elif args.task == 'Segmentation2Face':
- processer = Segmentation2Face(in_size=args.in_size, model=args.model, is_norm=False, device='cuda' if args.use_cuda else 'cpu')
+ processer = Segmentation2Face(in_size=args.in_size, model=args.model, is_norm=False, device=device)
files = sorted(glob.glob(os.path.join(args.indir, '*.*g')))
diff --git a/face_parse/face_parsing.py b/face_parse/face_parsing.py
index 63e5e93..b3c06bc 100755
--- a/face_parse/face_parsing.py
+++ b/face_parse/face_parsing.py
@@ -58,7 +58,7 @@ def process_tensor(self, imt):
def img2tensor(self, img):
img = img[..., ::-1]
img = img / 255. * 2 - 1
- img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
+ img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device, dtype=torch.float32)
return img_tensor.float()
def tenor2mask(self, tensor):