From ca86836cd47aebc5fc8271acd28358c0c74136aa Mon Sep 17 00:00:00 2001
From: dmMaze
Date: Fri, 10 Nov 2023 20:39:37 +0800
Subject: [PATCH 1/2] add lama large

---
 README.md                                    |  2 +-
 README_CN.md                                 |  2 +-
 manga_translator/inpainting/__init__.py      |  3 +-
 .../inpainting/inpainting_lama_mpe.py        | 53 ++++++++++++++-----
 4 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 6b33861d5..dd4de41ad 100644
--- a/README.md
+++ b/README.md
@@ -399,7 +399,7 @@ THA: Thai
                       image, DO NOT use craft for manga, it's not designed for it
 --ocr {32px,48px_ctc}
                       Optical character recognition (OCR) model to use
---inpainter {default,lama_mpe,sd,none,original}
+--inpainter {default,lama_large,lama_mpe,sd,none,original}
                       Inpainting model to use
 --upscaler {waifu2x,esrgan,4xultrasharp}
                       Upscaler to use. --upscale-ratio has to be set for it to take effect

diff --git a/README_CN.md b/README_CN.md
index 587413491..3b6527194 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -131,7 +131,7 @@ THA: Thai
                       image, DO NOT use craft for manga, it's not designed for it
 --ocr {32px,48px_ctc}
                       Optical character recognition (OCR) model to use
---inpainter {default,lama_mpe,sd,none,original}
+--inpainter {default,lama_large,lama_mpe,sd,none,original}
                       Inpainting model to use
 --upscaler {waifu2x,esrgan,4xultrasharp}
                       Upscaler to use. --upscale-ratio has to be set for it to take effect

diff --git a/manga_translator/inpainting/__init__.py b/manga_translator/inpainting/__init__.py
index db709c096..5b7a9b161 100644
--- a/manga_translator/inpainting/__init__.py
+++ b/manga_translator/inpainting/__init__.py
@@ -2,13 +2,14 @@
 
 from .common import CommonInpainter, OfflineInpainter
 from .inpainting_aot import AotInpainter
-from .inpainting_lama_mpe import LamaMPEInpainter
+from .inpainting_lama_mpe import LamaMPEInpainter, LamaLargeInpainter
 from .inpainting_sd import StableDiffusionInpainter
 from .none import NoneInpainter
 from .original import OriginalInpainter
 
 INPAINTERS = {
     'default': AotInpainter,
+    'lama_large': LamaLargeInpainter,
     'lama_mpe': LamaMPEInpainter,
     'sd': StableDiffusionInpainter,
     'none': NoneInpainter,

diff --git a/manga_translator/inpainting/inpainting_lama_mpe.py b/manga_translator/inpainting/inpainting_lama_mpe.py
index 754c9dbb8..76232a353 100644
--- a/manga_translator/inpainting/inpainting_lama_mpe.py
+++ b/manga_translator/inpainting/inpainting_lama_mpe.py
@@ -15,6 +15,11 @@ from ..utils import resize_keep_aspect
 
 
 class LamaMPEInpainter(OfflineInpainter):
+
+    '''
+    Deprecated in favor of LamaLargeInpainter (lama_large).
+    '''
+
     _MODEL_MAPPING = {
         'model': {
             'url': 'https://github.com/zyddnys/manga-image-translator/releases/download/beta-0.3/inpainting_lama_mpe.ckpt',
@@ -85,6 +90,24 @@ async def _infer(self, image: np.ndarray, mask: np.ndarray, inpainting_size: int
             img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR)
         ans = img_inpainted * mask_original + img_original * (1 - mask_original)
         return ans
+
+
+class LamaLargeInpainter(LamaMPEInpainter):
+
+    _MODEL_MAPPING = {
+        'model': {
+            'url': 'https://huggingface.co/dreMaz/AnimeMangaInpainting/resolve/main/lama_large_512px.ckpt',
+            'hash': '11d30fbb3000fb2eceae318b75d9ced9229d99ae990a7f8b3ac35c8d31f2c935',
+            'file': '.',
+        },
+    }
+
+    async def _load(self, device: str):
+        self.model = load_lama_mpe(self._get_file_path('lama_large_512px.ckpt'), device='cpu', use_mpe=False, large_arch=True)
+        self.model.eval()
+        self.use_cuda = device == 'cuda'
+        if self.use_cuda:
+            self.model = self.model.cuda()
 
 
 def set_requires_grad(module, value):
@@ -171,12 +194,10 @@ def forward(self, x):
         r_size = x.size()
         # (batch, c, h, w/2+1, 2)
         fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
-        # x: torch.float16
-        if x.dtype == torch.float16:
-            half = True
+
+        if x.dtype in (torch.float16, torch.bfloat16):
             x = x.type(torch.float32)
-        else:
-            half = False
+
         ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
         ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
         ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
@@ -196,7 +217,7 @@ def forward(self, x):
         ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
             0, 1, 3, 4, 2).contiguous()  # (batch,c, t, h, w/2+1, 2)
 
-        if ffted.dtype == torch.float16:
+        if ffted.dtype in (torch.float16, torch.bfloat16):
             ffted = ffted.type(torch.float32)
         ffted = torch.complex(ffted[..., 0], ffted[..., 1])
@@ -586,9 +607,15 @@ def forward(self, rel_pos=None, direct=None):
 
 class LamaFourier:
-    def __init__(self, build_discriminator=True, use_mpe=False) -> None:
+    def __init__(self, build_discriminator=True, use_mpe=False, large_arch: bool = False) -> None:
         # super().__init__()
+
+        n_blocks = 9
+        if large_arch:
+            n_blocks = 18
+
         self.generator = FFCResNetGenerator(4, 3, add_out_act='sigmoid',
+            n_blocks = n_blocks,
             init_conv_kwargs={
                 'ratio_gin': 0,
                 'ratio_gout': 0,
@@ -601,9 +628,9 @@ def __init__(self, build_discriminator=True, use_mpe=False) -> None:
                 'ratio_gin': 0.75,
                 'ratio_gout': 0.75,
                 'enable_lfu': False
-            }
+            },
         )
-        self.enable_fp16 = False
+
         self.discriminator = NLayerDiscriminator() if build_discriminator else None
         self.inpaint_only = False
         if use_mpe:
@@ -759,10 +786,12 @@ def load_masked_position_encoding(self, mask):
 
         return rel_pos, abs_pos, direct
 
-def load_lama_mpe(model_path, device) -> LamaFourier:
-    model = LamaFourier(build_discriminator=False, use_mpe=True)
+
+def load_lama_mpe(model_path, device, use_mpe: bool = True, large_arch: bool = False) -> LamaFourier:
+    model = LamaFourier(build_discriminator=False, use_mpe=use_mpe, large_arch=large_arch)
     sd = torch.load(model_path, map_location = 'cpu')
     model.generator.load_state_dict(sd['gen_state_dict'])
-    model.mpe.load_state_dict(sd['str_state_dict'])
+    if use_mpe:
+        model.mpe.load_state_dict(sd['str_state_dict'])
     model.eval().to(device)
     return model
\ No newline at end of file

From 17998fbfd8ecec0f1ec4d3a6cd55192091d7817f Mon Sep 17 00:00:00 2001
From: dmMaze
Date: Fri, 10 Nov 2023 21:39:29 +0800
Subject: [PATCH 2/2] support bf16 inference for lama

---
 README.md                                    |  2 ++
 README_CN.md                                 |  2 ++
 manga_translator/args.py                     |  3 ++-
 .../inpainting/inpainting_lama_mpe.py        | 25 +++++++++++++++++--
 manga_translator/manga_translator.py         |  2 ++
 5 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index dd4de41ad..228e2ce63 100644
--- a/README.md
+++ b/README.md
@@ -433,6 +433,8 @@ THA: Thai
 --min-text-length MIN_TEXT_LENGTH
                       Minimum text length of a text region
 --inpainting-size INPAINTING_SIZE
                       Size of image used for inpainting (too large will result in OOM)
+--inpainting-precision INPAINTING_PRECISION
+                      Inpainting precision for lama, use bf16 if your device supports it.
 --colorization-size COLORIZATION_SIZE
                       Size of image used for colorization. Set to -1 to use full image size
 --denoise-sigma DENOISE_SIGMA
                       Used by colorizer and affects color strength, range

diff --git a/README_CN.md b/README_CN.md
index 3b6527194..06e7d239d 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -165,6 +165,8 @@ THA: Thai
 --min-text-length MIN_TEXT_LENGTH
                       Minimum text length of a text region
 --inpainting-size INPAINTING_SIZE
                       Size of image used for inpainting (too large will result in OOM)
+--inpainting-precision INPAINTING_PRECISION
+                      Inpainting precision for lama, use bf16 if your device supports it.
 --colorization-size COLORIZATION_SIZE
                       Size of image used for colorization. Set to -1 to use full image size
 --denoise-sigma DENOISE_SIGMA
                       Used by colorizer and affects color strength, range

diff --git a/manga_translator/args.py b/manga_translator/args.py
index ec42b591a..24238a011 100644
--- a/manga_translator/args.py
+++ b/manga_translator/args.py
@@ -105,7 +105,7 @@ def _format_action_invocation(self, action: argparse.Action) -> str:
 parser.add_argument('--detector', default='default', type=str, choices=DETECTORS, help='Text detector used for creating a text mask from an image, DO NOT use craft for manga, it\'s not designed for it')
 parser.add_argument('--ocr', default='48px_ctc', type=str, choices=OCRS, help='Optical character recognition (OCR) model to use')
-parser.add_argument('--inpainter', default='lama_mpe', type=str, choices=INPAINTERS, help='Inpainting model to use')
+parser.add_argument('--inpainter', default='lama_large', type=str, choices=INPAINTERS, help='Inpainting model to use')
 parser.add_argument('--upscaler', default='esrgan', type=str, choices=UPSCALERS, help='Upscaler to use. --upscale-ratio has to be set for it to take effect')
 parser.add_argument('--upscale-ratio', default=None, type=float, help='Image upscale ratio applied before detection. Can improve text detection.')
 parser.add_argument('--colorizer', default=None, type=str, choices=COLORIZERS, help='Colorization model to use.')
@@ -126,6 +126,7 @@ def _format_action_invocation(self, action: argparse.Action) -> str:
 parser.add_argument('--text-threshold', default=0.5, type=float, help='Threshold for text detection')
 parser.add_argument('--min-text-length', default=0, type=int, help='Minimum text length of a text region')
 parser.add_argument('--inpainting-size', default=2048, type=int, help='Size of image used for inpainting (too large will result in OOM)')
+parser.add_argument('--inpainting-precision', default='fp32', type=str, help='Inpainting precision for lama, use bf16 if your device supports it.', choices=['fp32', 'fp16', 'bf16'])
 parser.add_argument('--colorization-size', default=576, type=int, help='Size of image used for colorization. Set to -1 to use full image size')
 parser.add_argument('--denoise-sigma', default=30, type=int, help='Used by colorizer and affects color strength, range from 0 to 255 (default 30). -1 turns it off.')
 parser.add_argument('--mask-dilation-offset', default=0, type=int, help='By how much to extend the text mask to remove left-over text pixels of the original image.')

diff --git a/manga_translator/inpainting/inpainting_lama_mpe.py b/manga_translator/inpainting/inpainting_lama_mpe.py
index 76232a353..8a2d99de4 100644
--- a/manga_translator/inpainting/inpainting_lama_mpe.py
+++ b/manga_translator/inpainting/inpainting_lama_mpe.py
@@ -14,6 +14,14 @@
 from .common import OfflineInpainter
 from ..utils import resize_keep_aspect
 
+
+TORCH_DTYPE_MAP = {
+    'fp32': torch.float32,
+    'fp16': torch.float16,
+    'bf16': torch.bfloat16,
+}
+
+
 class LamaMPEInpainter(OfflineInpainter):
 
     '''
     Deprecated in favor of LamaLargeInpainter (lama_large).
     '''
@@ -81,7 +89,20 @@ async def _infer(self, image: np.ndarray, mask: np.ndarray, inpainting_size: int
             mask_torch = mask_torch.cuda()
         with torch.no_grad():
             img_torch *= (1 - mask_torch)
-            img_inpainted_torch = self.model(img_torch, mask_torch)
+            if not self.use_cuda:
+                img_inpainted_torch = self.model(img_torch, mask_torch)
+            else:
+                # Note: lama's weights shouldn't be converted to fp16 or bf16, otherwise it produces darkened results,
+                # but it can run inference under torch.autocast.
+                precision = TORCH_DTYPE_MAP[os.environ.get("INPAINTING_PRECISION", "fp32")]
+
+                if precision == torch.float16:
+                    precision = torch.bfloat16
+                    self.logger.warning('Switching to bf16 because lama is only compatible with bf16 and fp32.')
+
+                with torch.autocast(device_type="cuda", dtype=precision):
+                    img_inpainted_torch = self.model(img_torch, mask_torch)
+
         if isinstance(self.model, LamaFourier):
             img_inpainted = (img_inpainted_torch.cpu().squeeze_(0).permute(1, 2, 0).numpy() * 255.).astype(np.uint8)
         else:
@@ -578,7 +599,7 @@ def forward(self, img, mask, rel_pos=None, direct=None) -> Tensor:
         if rel_pos is None:
             return self.model(masked_img)
         else:
-            
+
             x_l, x_g = self.model[:2](masked_img)
             x_l = x_l.to(torch.float32)
             x_l += rel_pos

diff --git a/manga_translator/manga_translator.py b/manga_translator/manga_translator.py
index 8930121ac..104e9eab9 100644
--- a/manga_translator/manga_translator.py
+++ b/manga_translator/manga_translator.py
@@ -108,6 +108,8 @@ def parse_init_params(self, params: dict):
         if params.get('model_dir'):
             ModelWrapper._MODEL_DIR = params.get('model_dir')
 
+        os.environ['INPAINTING_PRECISION'] = params.get('inpainting_precision', 'fp32')
+
     @property
     def using_cuda(self):
         return self.device.startswith('cuda')
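
Note on the bf16 path in PATCH 2/2: lama's weights stay in fp32 (casting them to half precision darkens the output, per the comment in `_infer`); only the forward pass runs under `torch.autocast`. Below is a minimal, self-contained sketch of that pattern outside the patched classes — the `model` argument is a stand-in; only `TORCH_DTYPE_MAP`, the fp16-to-bf16 fallback, and the autocast call mirror the patch:

```python
import torch

# Same mapping the patch adds to inpainting_lama_mpe.py.
TORCH_DTYPE_MAP = {
    'fp32': torch.float32,
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
}

def autocast_inference(model: torch.nn.Module, img: torch.Tensor, mask: torch.Tensor,
                       precision: str = 'bf16') -> torch.Tensor:
    """Run inpainting with fp32 weights, optionally autocasting activations."""
    dtype = TORCH_DTYPE_MAP[precision]
    # The patch falls back to bf16 here, since lama is only stable under fp32/bf16.
    if dtype == torch.float16:
        dtype = torch.bfloat16
    with torch.no_grad():
        if img.is_cuda and dtype is not torch.float32:
            # Weights remain fp32; autocast only lowers activation precision.
            with torch.autocast(device_type='cuda', dtype=dtype):
                return model(img, mask)
        return model(img, mask)
```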
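
And a small usage sketch of how the new pieces connect once both patches are applied; the names below all exist in the patched files, only the standalone-script framing is illustrative:

```python
import os
from manga_translator.inpainting import INPAINTERS, LamaLargeInpainter

# --inpainting-precision is forwarded through an environment variable
# (see parse_init_params in manga_translator.py) ...
os.environ['INPAINTING_PRECISION'] = 'bf16'

# ... and --inpainter lama_large resolves through the registry that
# manga_translator/inpainting/__init__.py now exposes.
assert INPAINTERS['lama_large'] is LamaLargeInpainter
```

From the CLI the equivalent is `--inpainter lama_large --inpainting-precision bf16`; with PATCH 2/2, `lama_large` is also the new default inpainter. Architecturally it is the same FFC generator as `lama_mpe` with 18 residual blocks instead of 9 (`large_arch=True`) and no masked position encoding.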