From 02580c7cd207d05de51d67f537d2e898d9c9b17b Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Fri, 2 Jun 2023 06:20:23 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- dpt.py | 45 ++- encoding.py | 8 +- evaluation/Prompt.py | 2 +- freqencoder/backend.py | 9 +- freqencoder/setup.py | 9 +- gridencoder/backend.py | 9 +- gridencoder/grid.py | 8 +- gridencoder/setup.py | 9 +- guidance/if_utils.py | 15 +- guidance/sd_utils.py | 19 +- guidance/zero123_utils.py | 45 ++- ldm/extras.py | 4 +- ldm/lr_scheduler.py | 25 +- ldm/models/autoencoder.py | 73 ++-- ldm/models/diffusion/classifier.py | 13 +- ldm/models/diffusion/ddim.py | 15 +- ldm/models/diffusion/ddpm.py | 317 ++++++++---------- ldm/models/diffusion/plms.py | 13 +- ldm/models/diffusion/sampling_util.py | 3 +- ldm/modules/attention.py | 14 +- ldm/modules/diffusionmodules/model.py | 32 +- ldm/modules/diffusionmodules/openaimodel.py | 17 +- ldm/modules/distributions/distributions.py | 35 +- ldm/modules/ema.py | 4 +- ldm/modules/encoders/modules.py | 49 +-- ldm/modules/evaluate/adm_evaluator.py | 15 +- .../evaluate/evaluate_perceptualsim.py | 99 +++--- .../evaluate/frechet_video_distance.py | 20 +- ldm/modules/evaluate/ssim.py | 10 +- .../evaluate/torch_frechet_video_distance.py | 26 +- ldm/modules/image_degradation/bsrgan.py | 114 +++---- ldm/modules/image_degradation/bsrgan_light.py | 25 +- ldm/modules/image_degradation/utils_image.py | 46 ++- ldm/modules/losses/contperceptual.py | 26 +- ldm/modules/losses/vqperceptual.py | 44 ++- ldm/modules/x_transformer.py | 10 +- ldm/thirdp/psp/helpers.py | 11 +- ldm/thirdp/psp/id_loss.py | 3 +- ldm/thirdp/psp/model_irse.py | 38 ++- ldm/util.py | 22 +- nerf/gui.py | 99 +++--- nerf/network.py | 7 +- nerf/network_grid.py | 19 +- nerf/network_grid_taichi.py | 19 +- nerf/network_grid_tcnn.py | 19 +- nerf/provider.py | 12 +- nerf/renderer.py | 163 ++++----- nerf/utils.py | 6 +- optimizer.py | 23 +- preprocess_image.py | 20 +- raymarching/backend.py | 9 +- raymarching/setup.py | 9 +- shencoder/backend.py | 9 +- shencoder/setup.py | 9 +- taichi_modules/hash_encoder.py | 4 +- taichi_modules/ray_march.py | 26 +- taichi_modules/utils.py | 7 +- tets/generate_tets.py | 7 +- 58 files changed, 804 insertions(+), 964 deletions(-) diff --git a/dpt.py b/dpt.py index 8cc04794..6201d6bb 100644 --- a/dpt.py +++ b/dpt.py @@ -42,10 +42,7 @@ def __init__(self, start_index=1): self.start_index = start_index def forward(self, x): - if self.start_index == 2: - readout = (x[:, 0] + x[:, 1]) / 2 - else: - readout = x[:, 0] + readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0] return x[:, self.start_index :] + readout.unsqueeze(1) @@ -84,10 +81,10 @@ def forward_vit(pretrained, x): layer_3 = pretrained.activations["3"] layer_4 = pretrained.activations["4"] - layer_1 = pretrained.act_postprocess1[0:2](layer_1) - layer_2 = pretrained.act_postprocess2[0:2](layer_2) - layer_3 = pretrained.act_postprocess3[0:2](layer_3) - layer_4 = pretrained.act_postprocess4[0:2](layer_4) + layer_1 = pretrained.act_postprocess1[:2](layer_1) + layer_2 = pretrained.act_postprocess2[:2](layer_2) + layer_3 = pretrained.act_postprocess3[:2](layer_3) + layer_4 = pretrained.act_postprocess4[:2](layer_4) unflattened_dim = 2 @@ -96,7 +93,7 @@ def forward_vit(pretrained, x): int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')), ) unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size)) - + if layer_1.ndim == 3: layer_1 = unflatten(layer_1) @@ -107,10 +104,10 @@ def forward_vit(pretrained, x): if 
layer_4.ndim == 3: layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size) - layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) - layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) - layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) - layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + layer_1 = pretrained.act_postprocess1[3:](layer_1) + layer_2 = pretrained.act_postprocess2[3:](layer_2) + layer_3 = pretrained.act_postprocess3[3:](layer_3) + layer_4 = pretrained.act_postprocess4[3:](layer_4) return layer_1, layer_2, layer_3, layer_4 @@ -187,9 +184,7 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for _ in features] else: assert ( False @@ -315,7 +310,7 @@ def _make_vit_b16_backbone( def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) - hooks = [5, 11, 17, 23] if hooks == None else hooks + hooks = [5, 11, 17, 23] if hooks is None else hooks return _make_vit_b16_backbone( model, features=[256, 512, 1024, 1024], @@ -328,7 +323,7 @@ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) - hooks = [2, 5, 8, 11] if hooks == None else hooks + hooks = [2, 5, 8, 11] if hooks is None else hooks return _make_vit_b16_backbone( model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout ) @@ -337,7 +332,7 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) - hooks = [2, 5, 8, 11] if hooks == None else hooks + hooks = [2, 5, 8, 11] if hooks is None else hooks return _make_vit_b16_backbone( model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout ) @@ -348,7 +343,7 @@ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks= "vit_deit_base_distilled_patch16_384", pretrained=pretrained ) - hooks = [2, 5, 8, 11] if hooks == None else hooks + hooks = [2, 5, 8, 11] if hooks is None else hooks return _make_vit_b16_backbone( model, features=[96, 192, 384, 768], @@ -498,7 +493,7 @@ def _make_pretrained_vitb_rn50_384( ): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) - hooks = [0, 1, 8, 11] if hooks == None else hooks + hooks = [0, 1, 8, 11] if hooks is None else hooks return _make_vit_b_rn50_backbone( model, features=[256, 512, 768, 768], @@ -589,7 +584,7 @@ def _make_efficientnet_backbone(effnet): pretrained = nn.Module() pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[:2] ) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) @@ -897,13 +892,11 @@ def forward(self, x): path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = 
self.scratch.refinenet1(path_2, layer_1_rn) - out = self.scratch.output_conv(path_1) - - return out + return self.scratch.output_conv(path_1) class DPTDepthModel(DPT): def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs): - features = kwargs["features"] if "features" in kwargs else 256 + features = kwargs.get("features", 256) head = nn.Sequential( nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), diff --git a/encoding.py b/encoding.py index 7edd0960..d97a1bbe 100644 --- a/encoding.py +++ b/encoding.py @@ -40,16 +40,12 @@ def forward(self, input, max_level=None, **kwargs): for i in range(max_level): freq = self.freq_bands[i] - for p_fn in self.periodic_fns: - out.append(p_fn(input * freq)) - + out.extend(p_fn(input * freq) for p_fn in self.periodic_fns) # append 0 if self.N_freqs - max_level > 0: out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * 2 * input.shape[-1], device=input.device, dtype=input.dtype)) - - out = torch.cat(out, dim=-1) - return out + return torch.cat(out, dim=-1) def get_encoder(encoding, input_dim=3, multires=6, diff --git a/evaluation/Prompt.py b/evaluation/Prompt.py index 53603dbb..7451b971 100644 --- a/evaluation/Prompt.py +++ b/evaluation/Prompt.py @@ -86,6 +86,6 @@ print('+' + '-'*52 + '+') for line in wrapped_text.split('\n'): - print('| {} |'.format(line.ljust(50))) + print(f'| {line.ljust(50)} |') print('+' + '-'*52 + '+') #print(result) diff --git a/freqencoder/backend.py b/freqencoder/backend.py index fa0e820b..ed72746e 100644 --- a/freqencoder/backend.py +++ b/freqencoder/backend.py @@ -19,8 +19,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. diff --git a/freqencoder/setup.py b/freqencoder/setup.py index ea641129..3121f82b 100644 --- a/freqencoder/setup.py +++ b/freqencoder/setup.py @@ -20,8 +20,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. 
diff --git a/gridencoder/backend.py b/gridencoder/backend.py index b403f345..10bf2937 100644 --- a/gridencoder/backend.py +++ b/gridencoder/backend.py @@ -18,8 +18,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: diff --git a/gridencoder/grid.py b/gridencoder/grid.py index 3f91dafc..c9f53bab 100644 --- a/gridencoder/grid.py +++ b/gridencoder/grid.py @@ -194,13 +194,13 @@ def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): @torch.cuda.amp.autocast(enabled=False) def grad_weight_decay(self, weight=0.1): + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + # level-wise meaned weight decay (ref: zip-nerf) - + B = self.embeddings.shape[0] # size of embedding C = self.embeddings.shape[1] # embedding dim for each level L = self.offsets.shape[0] - 1 # level - - if self.embeddings.grad is None: - raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L) \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py index a91b0c19..d211834d 100644 --- a/gridencoder/setup.py +++ b/gridencoder/setup.py @@ -19,8 +19,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. 
diff --git a/guidance/if_utils.py b/guidance/if_utils.py index 0dcce221..d41e1444 100644 --- a/guidance/if_utils.py +++ b/guidance/if_utils.py @@ -39,7 +39,7 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]): self.device = device - print(f'[INFO] loading DeepFloyd IF-I-XL...') + print('[INFO] loading DeepFloyd IF-I-XL...') model_key = "DeepFloyd/IF-I-XL-v1.0" @@ -70,7 +70,7 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]): self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience - print(f'[INFO] loaded DeepFloyd IF-I-XL!') + print('[INFO] loaded DeepFloyd IF-I-XL!') @torch.no_grad() def get_text_embeds(self, prompt): @@ -79,9 +79,7 @@ def get_text_embeds(self, prompt): # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28 prompt = self.pipe._text_preprocessing(prompt, clean_caption=False) inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt') - embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] - - return embeddings + return self.text_encoder(inputs.input_ids.to(self.device))[0] def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1): @@ -116,10 +114,7 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1 grad = grad_scale * w[:, None, None, None] * (noise_pred - noise) grad = torch.nan_to_num(grad) - # since we omitted an item in grad, we need to use the custom function to specify the gradient - loss = SpecifyGradient.apply(images, grad) - - return loss + return SpecifyGradient.apply(images, grad) @torch.no_grad() def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5): @@ -129,7 +124,7 @@ def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
model_input = torch.cat([images] * 2) model_input = self.scheduler.scale_model_input(model_input, t) diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py index 3a00ab9f..b7b2760d 100644 --- a/guidance/sd_utils.py +++ b/guidance/sd_utils.py @@ -42,7 +42,7 @@ def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range= self.device = device self.sd_version = sd_version - print(f'[INFO] loading stable diffusion...') + print('[INFO] loading stable diffusion...') if hf_key is not None: print(f'[INFO] using hugging face custom model key: {hf_key}') @@ -84,16 +84,14 @@ def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range= self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience - print(f'[INFO] loaded stable diffusion!') + print('[INFO] loaded stable diffusion!') @torch.no_grad() def get_text_embeds(self, prompt): # prompt: [str] inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') - embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] - - return embeddings + return self.text_encoder(inputs.input_ids.to(self.device))[0] def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1, @@ -170,10 +168,7 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0) save_image(viz_images, save_guidance_path) - # since we omitted an item in grad, we need to use the custom function to specify the gradient - loss = SpecifyGradient.apply(latents, grad) - - return loss + return SpecifyGradient.apply(latents, grad) @torch.no_grad() def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): @@ -183,7 +178,7 @@ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_ self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. latent_model_input = torch.cat([latents] * 2) # predict the noise residual @@ -213,9 +208,7 @@ def encode_imgs(self, imgs): imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist - latents = posterior.sample() * self.vae.config.scaling_factor - - return latents + return posterior.sample() * self.vae.config.scaling_factor def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): diff --git a/guidance/zero123_utils.py b/guidance/zero123_utils.py index 15dd8563..9d4a4b64 100644 --- a/guidance/zero123_utils.py +++ b/guidance/zero123_utils.py @@ -140,7 +140,7 @@ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scal grad_scale = 1.0 # claforte: I think this might converge faster...? 
else: assert False, f'Unrecognized `zero123_grad_scale`: {self.opt.zero123_grad_scale}' - + if as_latent: latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 else: @@ -172,7 +172,7 @@ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scal noise_preds = [] # Loop through each ref image - for (zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius) in zip(zero123_ws.T, + for zero123_w, c_crossattn, c_concat, ref_polar, ref_azimuth, ref_radius in zip(zero123_ws.T, embeddings['c_crossattn'], embeddings['c_concat'], ref_polars, ref_azimuths, ref_radii): # polar,azimuth,radius are all actually delta wrt default @@ -183,9 +183,15 @@ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scal # T = torch.tensor([math.radians(p), math.sin(math.radians(-a)), math.cos(math.radians(a)), r]) # T = T[None, None, :].to(self.device) T = torch.stack([torch.deg2rad(p), torch.sin(torch.deg2rad(-a)), torch.cos(torch.deg2rad(a)), r], dim=-1)[:, None, :] - cond = {} clip_emb = self.model.cc_projection(torch.cat([c_crossattn.repeat(len(T), 1, 1), T], dim=-1)) - cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)] + cond = { + 'c_crossattn': [ + torch.cat( + [torch.zeros_like(clip_emb).to(self.device), clip_emb], + dim=0, + ) + ] + } cond['c_concat'] = [torch.cat([torch.zeros_like(c_concat).repeat(len(T), 1, 1, 1).to(self.device), c_concat.repeat(len(T), 1, 1, 1)], dim=0)] noise_pred = self.model.apply_model(x_in, t_in, cond) noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) @@ -239,10 +245,7 @@ def train_step(self, embeddings, pred_rgb, polar, azimuth, radius, guidance_scal viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1) save_image(viz_images, save_guidance_path) - # since we omitted an item in grad, we need to use the custom function to specify the gradient - loss = SpecifyGradient.apply(latents, grad) - - return loss + return SpecifyGradient.apply(latents, grad) # verification @torch.no_grad() @@ -259,16 +262,21 @@ def __call__(self, T = torch.tensor([math.radians(polar), math.sin(math.radians(azimuth)), math.cos(math.radians(azimuth)), radius]) T = T[None, None, :].to(self.device) - cond = {} clip_emb = self.model.cc_projection(torch.cat([embeddings['c_crossattn'] if c_crossattn is None else c_crossattn, T], dim=-1)) - cond['c_crossattn'] = [torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)] + cond = { + 'c_crossattn': [ + torch.cat( + [torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0 + ) + ] + } cond['c_concat'] = [torch.cat([torch.zeros_like(embeddings['c_concat']).to(self.device), embeddings['c_concat']], dim=0)] if c_concat is None else [torch.cat([torch.zeros_like(c_concat).to(self.device), c_concat], dim=0)] # produce latents loop latents = torch.randn((1, 4, h // 8, w // 8), device=self.device) self.scheduler.set_timesteps(ddim_steps) - for i, t in enumerate(self.scheduler.timesteps): + for t in self.scheduler.timesteps: x_in = torch.cat([latents] * 2) t_in = torch.cat([t.view(1)] * 2).to(self.device) @@ -295,8 +303,15 @@ def encode_imgs(self, imgs): # imgs: [B, 3, 256, 256] RGB space image # with self.model.ema_scope(): imgs = imgs * 2 - 1 - latents = torch.cat([self.model.get_first_stage_encoding(self.model.encode_first_stage(img.unsqueeze(0))) for img in imgs], dim=0) - return latents # [B, 4, 32, 32] Latent space image + return torch.cat( + [ + 
self.model.get_first_stage_encoding( + self.model.encode_first_stage(img.unsqueeze(0)) + ) + for img in imgs + ], + dim=0, + ) if __name__ == '__main__': @@ -325,10 +340,10 @@ def encode_imgs(self, imgs): image = image.astype(np.float32) / 255.0 image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device) - print(f'[INFO] loading model ...') + print('[INFO] loading model ...') zero123 = Zero123(device, opt.fp16, opt=opt) - print(f'[INFO] running model ...') + print('[INFO] running model ...') outputs = zero123(image, polar=opt.polar, azimuth=opt.azimuth, radius=opt.radius) plt.imshow(outputs[0]) plt.show() \ No newline at end of file diff --git a/ldm/extras.py b/ldm/extras.py index 62e654b3..45961a9a 100755 --- a/ldm/extras.py +++ b/ldm/extras.py @@ -39,8 +39,8 @@ def load_training_dir(train_dir, device, epoch="last"): train_dir = Path(train_dir) ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" - config = list(train_dir.rglob(f"*-project.yaml")) - assert len(ckpt) > 0, f"didn't find any config in {train_dir}" + config = list(train_dir.rglob("*-project.yaml")) + assert ckpt, f"didn't find any config in {train_dir}" if len(config) > 1: print(f"found {len(config)} matching config files") config = sorted(config)[-1] diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py index be39da9c..7d88c94d 100755 --- a/ldm/lr_scheduler.py +++ b/ldm/lr_scheduler.py @@ -19,15 +19,14 @@ def schedule(self, n, **kwargs): if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 1 + np.cos(t * np.pi)) - self.last_lr = lr - return lr + + self.last_lr = lr + return lr def __call__(self, n, **kwargs): return self.schedule(n,**kwargs) @@ -50,11 +49,9 @@ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosit self.verbosity_interval = verbosity_interval def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: + for interval, cl in enumerate(self.cum_cycles[1:]): if n <= cl: return interval - interval += 1 def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) @@ -64,15 +61,14 @@ def schedule(self, n, **kwargs): f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 1 + np.cos(t * np.pi)) - self.last_f = f - return f + + self.last_f = f + return f def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) @@ -89,10 +85,9 @@ def schedule(self, n, **kwargs): if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f else: f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) - self.last_f = f - return f + + self.last_f = f + return f diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py index 
6a9c4f45..98abda7e 100755 --- a/ldm/models/autoencoder.py +++ b/ldm/models/autoencoder.py @@ -81,7 +81,7 @@ def init_from_ckpt(self, path, ignore_keys=list()): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") @@ -106,20 +106,16 @@ def encode_to_prequant(self, x): def decode(self, quant): quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec + return self.decoder(quant) def decode_code(self, code_b): quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec + return self.decode(quant_b) def forward(self, input, return_pred_indices=False): quant, diff, (_,_,ind) = self.encode(input) dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff + return (dec, diff, ind) if return_pred_indices else (dec, diff) def get_input(self, batch, k): x = batch[k] @@ -170,19 +166,27 @@ def validation_step(self, batch, batch_idx): def _validation_step(self, batch, batch_idx, suffix=""): x = self.get_input(batch, self.image_key) xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split=f"val{suffix}", + predicted_indices=ind, + ) + + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split=f"val{suffix}", + predicted_indices=ind, + ) rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] self.log(f"val{suffix}/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) @@ -231,7 +235,7 @@ def get_last_layer(self): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -278,8 +282,7 @@ def decode(self, h, force_not_quantize=False): else: quant = h quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec + return self.decoder(quant) class AutoencoderKL(pl.LightningModule): @@ -316,7 +319,7 @@ def init_from_ckpt(self, path, ignore_keys=list()): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] self.load_state_dict(sd, strict=False) print(f"Restored from {path}") @@ -324,20 +327,15 @@ def init_from_ckpt(self, path, ignore_keys=list()): def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior + return DiagonalGaussianDistribution(moments) def decode(self, z): z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec + return self.decoder(z) def forward(self, input, sample_posterior=True): posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() + z = posterior.sample() if 
sample_posterior else posterior.mode() dec = self.decode(z) return dec, posterior @@ -345,8 +343,7 @@ def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x + return x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() def training_step(self, batch, batch_idx, optimizer_idx): inputs = self.get_input(batch, self.image_key) @@ -399,7 +396,7 @@ def get_last_layer(self): @torch.no_grad() def log_images(self, batch, only_inputs=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if not only_inputs: @@ -435,9 +432,7 @@ def decode(self, x, *args, **kwargs): return x def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x + return (x, None, [None, None, None]) if self.vq_interface else x def forward(self, x, *args, **kwargs): return x diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py index 67e98b9d..e7b75d21 100755 --- a/ldm/models/diffusion/classifier.py +++ b/ldm/models/diffusion/classifier.py @@ -75,7 +75,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): for k in keys: for ik in ignore_keys: if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) + print(f"Deleting key {k} from state_dict.") del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) @@ -139,11 +139,11 @@ def get_conditioning(self, batch, k=None): if self.label_key == 'segmentation': targets = rearrange(targets, 'b h w c -> b c h w') - for down in range(self.numd): + for _ in range(self.numd): h, w = targets.shape[-2:] targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') - # targets = rearrange(targets,'b c h w -> b h w c') + # targets = rearrange(targets,'b c h w -> b h w c') return targets @@ -161,8 +161,7 @@ def on_train_epoch_start(self): @torch.no_grad() def write_logs(self, loss, logits, targets): log_prefix = 'train' if self.training else 'val' - log = {} - log[f"{log_prefix}/loss"] = loss.mean() + log = {f"{log_prefix}/loss": loss.mean()} log[f"{log_prefix}/acc@1"] = self.compute_top_k( logits, targets, k=1, reduction="mean" ) @@ -236,10 +235,8 @@ def configure_optimizers(self): @torch.no_grad() def log_images(self, batch, N=8, *args, **kwargs): - log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x - + log = {'inputs': x} y = self.get_conditioning(batch) if self.label_key == 'class_label': diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 0683d16f..f8dcd339 100755 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -96,9 +96,8 @@ def sample(self, if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + elif conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling @@ -134,14 +133,10 @@ def ddim_sampling(self, cond, shape, t_start=-1): device = self.model.betas.device b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - + img = 
torch.randn(shape, device=device) if x_T is None else x_T if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: + elif not ddim_use_original_steps: subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] @@ -287,7 +282,7 @@ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=No out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} if return_intermediates: - out.update({'intermediates': intermediates}) + out['intermediates'] = intermediates return x_next, out @torch.no_grad() diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 3fcb7adc..d06a65a4 100755 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -119,7 +119,7 @@ def __init__(self, if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) - self.ucg_training = ucg_training or dict() + self.ucg_training = ucg_training or {} if self.ucg_training: self.ucg_prng = np.random.RandomState() @@ -209,7 +209,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): desc="Fitting old weights to new weights", total=n_params ): - if not name in sd: + if name not in sd: continue old_shape = sd[name].shape new_shape = param.shape @@ -218,7 +218,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): # we only modify first two axes assert new_shape[2:] == old_shape[2:] # assumes first axis corresponds to output dim - if not new_shape == old_shape: + if new_shape != old_shape: new_param = param.clone() old_param = sd[name] if len(new_shape) == 1: @@ -310,9 +310,7 @@ def p_sample_loop(self, shape, return_intermediates=False): clip_denoised=self.clip_denoised) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) - if return_intermediates: - return img, intermediates - return img + return (img, intermediates) if return_intermediates else img @torch.no_grad() def sample(self, batch_size=16, return_intermediates=False): @@ -346,7 +344,6 @@ def p_losses(self, x_start, t, noise=None): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) - loss_dict = {} if self.parameterization == "eps": target = noise elif self.parameterization == "x0": @@ -358,15 +355,15 @@ def p_losses(self, x_start, t, noise=None): log_prefix = 'train' if self.training else 'val' - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_dict = {f'{log_prefix}/loss_simple': loss.mean()} loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + loss_dict[f'{log_prefix}/loss_vlb'] = loss_vlb loss = loss_simple + self.original_elbo_weight * loss_vlb - loss_dict.update({f'{log_prefix}/loss': loss}) + loss_dict[f'{log_prefix}/loss'] = loss return loss, loss_dict @@ -418,7 +415,7 @@ def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + loss_dict_ema = {f'{key}_ema': loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) 
@@ -435,15 +432,13 @@ def _get_rows_from_list(self, samples): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] - log["inputs"] = x - + log = {"inputs": x} # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -475,9 +470,8 @@ def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt + params += [self.logvar] + return torch.optim.AdamW(params, lr=lr) class LatentDiffusion(DDPM): @@ -575,31 +569,34 @@ def instantiate_first_stage(self, config): param.requires_grad = False def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: + if self.cond_stage_trainable: assert config != '__is_first_stage__' assert config != '__is_unconditional__' model = instantiate_from_config(config) self.cond_stage_model = model + elif config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) + denoise_row = [ + self.decode_first_stage( + zd.to(self.device), + force_not_quantize=force_no_decoder_quantization, + ) + for zd in tqdm(samples, desc=desc) + ] n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') @@ -633,8 +630,7 @@ def meshgrid(self, h, w): y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - arr = torch.cat([y, x], dim=-1) - return arr + return torch.cat([y, x], dim=-1) def delta_border(self, h, w): """ @@ -647,8 +643,7 @@ def delta_border(self, h, w): arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] - return edge_dist + return torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] def get_weighting(self, h, w, Ly, Lx, device): weighting = 
self.delta_border(h, w) @@ -769,100 +764,100 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): z = 1. / self.scale_factor * z - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) + if ( + hasattr(self, "split_input_params") + and not self.split_input_params["patch_distributed_vq"] + and isinstance(self.first_stage_model, VQModelInterface) + or not hasattr(self, "split_input_params") + and isinstance(self.first_stage_model, VQModelInterface) + ): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + elif ( + hasattr(self, "split_input_params") + and not self.split_input_params["patch_distributed_vq"] + or not hasattr(self, "split_input_params") + ): + return self.first_stage_model.decode(z) else: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. 
apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] else: - return self.first_stage_model.decode(z) + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded # @torch.no_grad() # wasted two hours to find this bug... why no grad here! def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - df = self.split_input_params["vqf"] - self.split_input_params['original_image_size'] = x.shape[-2:] - bs, nc, h, w = x.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) - z = unfold(x) # (bn, nc * prod(**ks), L) - # Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] + if ( + not hasattr(self, "split_input_params") + or not self.split_input_params["patch_distributed_vq"] + ): + return self.first_stage_model.encode(x) + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") - o = torch.stack(output_list, axis=-1) - o = o * weighting + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization - return decoded + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - else: - return self.first_stage_model.encode(x) - else: - return self.first_stage_model.encode(x) + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) - loss = self(x, c) - return loss + return self(x, c) def forward(self, x, c, *args, **kwargs): t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() @@ -887,10 +882,7 @@ def rescale_bbox(bbox): def apply_model(self, x_noisy, t, cond, return_ids=False): - if isinstance(cond, dict): - # hybrid case, cond is exptected to be a dict - pass - else: + if not isinstance(cond, dict): if not isinstance(cond, list): cond = [cond] key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' @@ -978,10 +970,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False): else: x_recon = self.model(x_noisy, t, **cond) - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon + return x_recon[0] if isinstance(x_recon, tuple) and not return_ids else x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ @@ -1006,7 +995,6 @@ def p_losses(self, x_start, cond, t, noise=None): x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) - loss_dict = {} prefix = 'train' if self.training else 'val' if self.parameterization == "x0": @@ -1017,22 +1005,21 @@ def p_losses(self, x_start, cond, t, noise=None): raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - + loss_dict = {f'{prefix}/loss_simple': loss_simple.mean()} logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) + loss_dict[f'{prefix}/loss_gamma'] = loss.mean() + loss_dict['logvar'] = self.logvar.data.mean() loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - 
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss_dict[f'{prefix}/loss_vlb'] = loss_vlb loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) + loss_dict[f'{prefix}/loss'] = loss return loss, loss_dict @@ -1111,10 +1098,7 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti shape = [batch_size] + list(shape) else: b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T + img = torch.randn(shape, device=self.device) if x_T is None else x_T intermediates = [] if cond is not None: if isinstance(cond, dict): @@ -1164,11 +1148,7 @@ def p_sample_loop(self, cond, shape, return_intermediates=False, log_every_t = self.log_every_t device = self.betas.device b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - + img = torch.randn(shape, device=device) if x_T is None else x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps @@ -1201,9 +1181,7 @@ def p_sample_loop(self, cond, shape, return_intermediates=False, if callback: callback(i) if img_callback: img_callback(img, i) - if return_intermediates: - return img, intermediates - return img + return (img, intermediates) if return_intermediates else img @torch.no_grad() def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, @@ -1239,24 +1217,25 @@ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): @torch.no_grad() def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512): - if null_label is not None: - xc = null_label - if isinstance(xc, ListConfig): - xc = list(xc) - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - if hasattr(xc, "to"): - xc = xc.to(self.device) - c = self.get_learned_conditioning(xc) - else: + if null_label is None: # todo: get null label from cond_stage_model raise NotImplementedError() + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if not isinstance(xc, dict) and not isinstance(xc, list): + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) c = repeat(c, '1 ... 
-> b ...', b=batch_size).to(self.device) - cond = {} - cond["c_crossattn"] = [c] - cond["c_concat"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)] - return cond + return { + "c_crossattn": [c], + "c_concat": [ + torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to( + self.device + ) + ], + } @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, @@ -1267,7 +1246,6 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None - log = dict() z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1275,8 +1253,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= bs=N) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec + log = {"inputs": x, "reconstruction": xrec} if self.model.conditioning_key is not None: if hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) @@ -1294,7 +1271,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta= if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1403,13 +1380,13 @@ def configure_optimizers(self): if self.cond_stage_trainable: print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) + params += list(self.cond_stage_model.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') params.append(self.logvar) if self.cc_projection is not None: - params = params + list(self.cc_projection.parameters()) + params += list(self.cc_projection.parameters()) print('========== optimizing for cc projection weight ==========') opt = torch.optim.AdamW([{"params": self.model.parameters(), "lr": lr}, diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 080edeec..190c89e5 100755 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -88,9 +88,8 @@ def sample(self, cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + elif conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling @@ -126,14 +125,10 @@ def plms_sampling(self, cond, shape, dynamic_threshold=None): device = self.model.betas.device b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - + img = torch.randn(shape, device=device) if x_T is None else x_T if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: + elif not ddim_use_original_steps: subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py index 
a0ae00fe..e3a648b2 100755 --- a/ldm/models/diffusion/sampling_util.py +++ b/ldm/models/diffusion/sampling_util.py @@ -35,8 +35,7 @@ def renorm_thresholding(x0, value): # re.renorm pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 - pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range - return pred_x0 + return (pred_max - pred_min) * pred_x0 + pred_min def norm_thresholding(x0, value): diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 124effbe..5f8e43c6 100755 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -241,9 +241,17 @@ def __init__(self, in_channels, n_heads, d_head, padding=0) self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, - disable_self_attn=disable_self_attn) - for d in range(depth)] + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim, + disable_self_attn=disable_self_attn, + ) + for _ in range(depth) + ] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a..2c39a0ba 100755 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -133,11 +133,7 @@ def forward(self, x, temb): h = self.conv2(h) if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) return x+h @@ -252,7 +248,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): + for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -395,7 +391,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): + for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -480,8 +476,9 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, block_in = ch*ch_mult[self.num_resolutions-1] curr_res = resolution // 2**(self.num_resolutions-1) self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) + print( + f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions." 
+ ) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, @@ -508,7 +505,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): + for _ in range(self.num_res_blocks+1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -593,11 +590,7 @@ def __init__(self, in_channels, out_channels, *args, **kwargs): def forward(self, x): for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - + x = layer(x, None) if i in [1,2,3] else layer(x) h = self.norm_out(x) h = nonlinearity(h) x = self.conv_out(h) @@ -619,7 +612,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, for i_level in range(self.num_resolutions): res_block = [] block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): + for _ in range(self.num_res_blocks + 1): res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, @@ -752,13 +745,6 @@ def __init__(self, in_channels=None, learned=False, mode="bilinear"): if self.with_conv: print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) def forward(self, x, scale_factor=1.0): if scale_factor==1.0: diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py index 09f0ae19..ab3ea3a8 100755 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -495,11 +495,11 @@ def __init__( self.out_channels = out_channels if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError("provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult") + elif len(num_res_blocks) == len(channel_mult): self.num_res_blocks = num_res_blocks + else: + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") #self.num_res_blocks = num_res_blocks if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not @@ -771,10 +771,7 @@ def forward(self, x, timesteps=None, context=None, y=None,**kwargs): h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) + return self.id_predictor(h) if self.predict_codebook_ids else self.out(h) class EncoderUNetModel(nn.Module): @@ -989,8 +986,8 @@ def forward(self, x, timesteps): if self.pool.startswith("spatial"): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = th.cat(results, axis=-1) - return self.out(h) else: h = h.type(x.dtype) - return self.out(h) + + return self.out(h) diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py index f2b8ef90..5ac4e6a9 100755 --- a/ldm/modules/distributions/distributions.py +++ b/ldm/modules/distributions/distributions.py @@ -33,22 +33,22 @@ def 
__init__(self, parameters, deterministic=False): self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x + return self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) def kl(self, other=None): if self.deterministic: return torch.Tensor([0.]) + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) def nll(self, sample, dims=[1,2,3]): if self.deterministic: @@ -69,11 +69,14 @@ def normal_kl(mean1, logvar1, mean2, logvar2): Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break + tensor = next( + ( + obj + for obj in (mean1, logvar1, mean2, logvar2) + if isinstance(obj, torch.Tensor) + ), + None, + ) assert tensor is not None, "at least one argument must be a Tensor" # Force variances to be Tensors. Broadcasting helps convert scalars to diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py index c8c75af4..6c559f2a 100755 --- a/ldm/modules/ema.py +++ b/ldm/modules/ema.py @@ -41,7 +41,7 @@ def forward(self,model): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -50,7 +50,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index b1afccfc..046ff43b 100755 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -99,8 +99,7 @@ def forward(self, batch, key=None): key = self.key # this is for use in crossattn c = batch[key][:, None] - c = self.embedding(c) - return c + return self.embedding(c) class TransformerEmbedder(AbstractEncoder): @@ -113,8 +112,7 @@ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): def forward(self, tokens): tokens = tokens.to(self.device) # meh - z = self.transformer(tokens, return_embeddings=True) - return z + return self.transformer(tokens, return_embeddings=True) def encode(self, x): return self(x) @@ -133,15 +131,12 @@ def __init__(self, device="cuda", vq_interface=True, max_length=77): def forward(self, text): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - return tokens + return batch_encoding["input_ids"].to(self.device) @torch.no_grad() def encode(self, text): 
tokens = self(text) - if not self.vq_interface: - return tokens - return None, None, [None, None, tokens] + return tokens if not self.vq_interface else (None, None, [None, None, tokens]) def decode(self, text): return text @@ -161,12 +156,8 @@ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, emb_dropout=embedding_dropout) def forward(self, text): - if self.use_tknz_fn: - tokens = self.tknz_fn(text)#.to(self.device) - else: - tokens = text - z = self.transformer(tokens, return_embeddings=True) - return z + tokens = self.tknz_fn(text) if self.use_tknz_fn else text + return self.transformer(tokens, return_embeddings=True) def encode(self, text): # output of length 77 @@ -203,8 +194,7 @@ def forward(self, text): tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z + return outputs.last_hidden_state def encode(self, text): return self(text) @@ -221,8 +211,8 @@ def __init__(self, model_path, augment=False): p.requires_grad = False # Mapper is trainable self.mapper = torch.nn.Linear(512, 768) - p = 0.25 if augment: + p = 0.25 self.augment = K.AugmentationSequential( K.RandomHorizontalFlip(p=0.5), K.RandomEqualize(p=p), @@ -273,8 +263,7 @@ def forward(self, text): tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z + return outputs.last_hidden_state def encode(self, text): return self(text) @@ -301,8 +290,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # c def get_null_cond(self, version, max_length): device = self.mean.device embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) - null_cond = embedder([""]) - return null_cond + return embedder([""]) def preprocess(self, x): # Expects inputs in the range -1, 1 @@ -310,9 +298,7 @@ def preprocess(self, x): interpolation='bicubic',align_corners=True, antialias=self.antialias) x = (x + 1.) / 2. - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x + return kornia.enhance.normalize(x, self.mean, self.std) def forward(self, x): if isinstance(x, list): @@ -366,9 +352,7 @@ def preprocess(self, x): interpolation='bicubic',align_corners=True, antialias=self.antialias) x = (x + 1.) / 2. - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x + return kornia.enhance.normalize(x, self.mean, self.std) def forward(self, x): # x is assumed to be in range [-1,1] @@ -411,14 +395,11 @@ def preprocess(self, x): # Expects inputs in the range -1, 1 randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1)) max_crops = self.max_crops - patches = [] crops = [randcrop(x) for _ in range(max_crops)] - patches.extend(crops) + patches = list(crops) x = torch.cat(patches, dim=0) x = (x + 1.) / 2. 
- # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x + return kornia.enhance.normalize(x, self.mean, self.std) def forward(self, x): # x is assumed to be in range [-1,1] @@ -460,7 +441,7 @@ def __init__(self, self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) def forward(self,x): - for stage in range(self.n_stages): + for _ in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) diff --git a/ldm/modules/evaluate/adm_evaluator.py b/ldm/modules/evaluate/adm_evaluator.py index 508cddf2..f01fd3fe 100755 --- a/ldm/modules/evaluate/adm_evaluator.py +++ b/ldm/modules/evaluate/adm_evaluator.py @@ -116,10 +116,7 @@ def frechet_distance(self, other, eps=1e-6): # product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): - msg = ( - "fid calculation produces singular product; adding %s to diagonal of cov estimates" - % eps - ) + msg = f"fid calculation produces singular product; adding {eps} to diagonal of cov estimates" warnings.warn(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) @@ -128,7 +125,7 @@ def frechet_distance(self, other, eps=1e-6): if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) - raise ValueError("Imaginary component {}".format(m)) + raise ValueError(f"Imaginary component {m}") covmean = covmean.real tr_covmean = np.trace(covmean) @@ -598,7 +595,7 @@ def _download_inception_model(): print("downloading InceptionV3 model...") with requests.get(INCEPTION_V3_URL, stream=True) as r: r.raise_for_status() - tmp_path = INCEPTION_V3_PATH + ".tmp" + tmp_path = f"{INCEPTION_V3_PATH}.tmp" with open(tmp_path, "wb") as f: for chunk in tqdm(r.iter_content(chunk_size=8192)): f.write(chunk) @@ -613,7 +610,7 @@ def _create_feature_graph(input_batch): graph_def.ParseFromString(f.read()) pool3, spatial = tf.import_graph_def( graph_def, - input_map={f"ExpandDims:0": input_batch}, + input_map={"ExpandDims:0": input_batch}, return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], name=prefix, ) @@ -629,7 +626,7 @@ def _create_softmax_graph(input_batch): graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) (matmul,) = tf.import_graph_def( - graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix + graph_def, return_elements=["softmax/logits/MatMul"], name=prefix ) w = matmul.inputs[1] logits = tf.matmul(input_batch, w) @@ -644,7 +641,7 @@ def _update_shapes(pool3): shape = o.get_shape() if shape._dims is not None: # pylint: disable=protected-access # shape = [s.value for s in shape] TF 1.x - shape = [s for s in shape] # TF 2.x + shape = list(shape) new_shape = [] for j, s in enumerate(shape): if s == 1 and j == 0: diff --git a/ldm/modules/evaluate/evaluate_perceptualsim.py b/ldm/modules/evaluate/evaluate_perceptualsim.py index c85fef96..d7db4e98 100755 --- a/ldm/modules/evaluate/evaluate_perceptualsim.py +++ b/ldm/modules/evaluate/evaluate_perceptualsim.py @@ -88,12 +88,10 @@ def forward(self, X): "SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], ) - out = vgg_outputs( + return vgg_outputs( h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7 ) - return out - class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): @@ -135,9 +133,7 @@ def forward(self, X): alexnet_outputs = namedtuple( "AlexnetOutputs", ["relu1", "relu2", "relu3", 
"relu4", "relu5"] ) - out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) - - return out + return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) class vgg16(torch.nn.Module): @@ -179,9 +175,7 @@ def forward(self, X): "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"], ) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - - return out + return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) class resnet(torch.nn.Module): @@ -226,9 +220,7 @@ def forward(self, X): outputs = namedtuple( "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"] ) - out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) - - return out + return outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) # Off-the-shelf deep network class PNet(torch.nn.Module): @@ -280,17 +272,11 @@ def forward(self, in0, in1, retPerLayer=False): all_scores = [] for (kk, out0) in enumerate(outs0): cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk]) - if kk == 0: - val = 1.0 * cur_score - else: - val = val + cur_score + val = 1.0 * cur_score if kk == 0 else val + cur_score if retPerLayer: all_scores += [cur_score] - if retPerLayer: - return (val, all_scores) - else: - return val + return (val, all_scores) if retPerLayer else val @@ -303,33 +289,27 @@ def ssim_metric(img1, img2, mask=None): # The PSNR metric def psnr(img1, img2, mask=None,reshape=False): b = img1.size(0) - if not (mask is None): + if mask is not None: b = img1.size(0) mse_err = (img1 - img2).pow(2) * mask - if reshape: - mse_err = mse_err.reshape(b, -1).sum(dim=1) / ( - 3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1) - ) - else: - mse_err = mse_err.view(b, -1).sum(dim=1) / ( - 3 * mask.view(b, -1).sum(dim=1).clamp(min=1) - ) + mse_err = ( + mse_err.reshape(b, -1).sum(dim=1) + / (3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)) + if reshape + else mse_err.view(b, -1).sum(dim=1) + / (3 * mask.view(b, -1).sum(dim=1).clamp(min=1)) + ) + elif reshape: + mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1) else: - if reshape: - mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1) - else: - mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) + mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) - psnr = 10 * (1 / mse_err).log10() - return psnr + return 10 * (1 / mse_err).log10() # The perceptual similarity metric def perceptual_sim(img1, img2, vgg16): - # First extract features - dist = vgg16(img1 * 2 - 1, img2 * 2 - 1) - - return dist + return vgg16(img1 * 2 - 1, img2 * 2 - 1) def load_img(img_name, size=None): try: @@ -343,7 +323,7 @@ def load_img(img_name, size=None): img = transform(img).cuda() img = img.unsqueeze(0) except Exception as e: - print("Failed at loading %s " % img_name) + print(f"Failed at loading {img_name} ") print(e) img = torch.zeros(1, 3, 256, 256).cuda() raise @@ -452,13 +432,12 @@ def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list, values_ssim += [ssim_sim] if psnr_sim != np.float("inf"): values_psnr += [psnr_sim] + elif torch.allclose(p_img, t_img): + equal_count += 1 + print(f"{equal_count} equal src and wrp images.") else: - if torch.allclose(p_img, t_img): - equal_count += 1 - print("{} equal src and wrp images.".format(equal_count)) - else: - ambig_count += 1 - print("{} ambiguous src and wrp images.".format(ambig_count)) + ambig_count += 1 + print(f"{ambig_count} ambiguous src and wrp images.") if take_every_other: n_valuespercsim = [] @@ -525,9 +504,9 @@ def compute_perceptual_similarity_from_list_topk(pred_imgs_list, 
tgt_imgs_list, perc_sim = 10000 ssim_sim = -10 psnr_sim = -10 - sample_percsim = list() - sample_ssim = list() - sample_psnr = list() + sample_percsim = [] + sample_ssim = [] + sample_psnr = [] for p_img in pred_imgs: if resize: t_img = load_img(tgt_imgs[0], size=(256,256)) @@ -615,16 +594,14 @@ def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list, folder, pred_img, tgt_img, opts.take_every_other ) - f = open(opts.output_file, 'w') - for key in results: - print("%s for %s: \n" % (key, opts.folder)) - print( - "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) - ) - - f.write("%s for %s: \n" % (key, opts.folder)) - f.write( - "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) - ) + with open(opts.output_file, 'w') as f: + for key in results: + print("%s for %s: \n" % (key, opts.folder)) + print( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) - f.close() + f.write("%s for %s: \n" % (key, opts.folder)) + f.write( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) diff --git a/ldm/modules/evaluate/frechet_video_distance.py b/ldm/modules/evaluate/frechet_video_distance.py index d9e13c41..0ebb0d65 100755 --- a/ldm/modules/evaluate/frechet_video_distance.py +++ b/ldm/modules/evaluate/frechet_video_distance.py @@ -50,8 +50,7 @@ def preprocess(videos, target_resolution): resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] output_videos = tf.reshape(resized_videos, target_shape) - scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 - return scaled_videos + return 2. * tf.cast(output_videos, tf.float32) / 255. - 1 def _is_in_graph(tensor_name): @@ -81,10 +80,6 @@ def create_id3_embedding(videos,warmup=False,batch_size=16): ValueError: when a provided embedding_layer is not supported. """ - # batch_size = 16 - module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" - - # Making sure that we import the graph separately for # each different input video tensor. module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( @@ -108,7 +103,7 @@ def create_id3_embedding(videos,warmup=False,batch_size=16): with tf.control_dependencies(assert_ops): videos = tf.identity(videos) - module_scope = "%s_apply_default/" % module_name + module_scope = f"{module_name}_apply_default/" # To check whether the module has already been loaded into the graph, we look # for a given tensor name. 
If this tensor name exists, we assume the function @@ -121,15 +116,18 @@ def create_id3_embedding(videos,warmup=False,batch_size=16): if warmup: video_batch_size = int(videos.shape[0]) assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" - tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + tensor_name = f"{module_scope}RGB/inception_i3d/Mean:0" if not _is_in_graph(tensor_name): + # batch_size = 16 + module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" + + i3d_model = hub.Module(module_spec, name=module_name) i3d_model(videos) # gets the kinetics-i3d-400-logits layer - tensor_name = module_scope + "RGB/inception_i3d/Mean:0" - tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) - return tensor + tensor_name = f"{module_scope}RGB/inception_i3d/Mean:0" + return tf.get_default_graph().get_tensor_by_name(tensor_name) def calculate_fvd(real_activations, diff --git a/ldm/modules/evaluate/ssim.py b/ldm/modules/evaluate/ssim.py index 4e8883cc..9705ffec 100755 --- a/ldm/modules/evaluate/ssim.py +++ b/ldm/modules/evaluate/ssim.py @@ -22,10 +22,9 @@ def gaussian(window_size, sigma): def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = Variable( + return Variable( _2D_window.expand(channel, 1, window_size, window_size).contiguous() ) - return window def _ssim( @@ -58,7 +57,7 @@ def _ssim( (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) ) - if not (mask is None): + if mask is not None: b = mask.size(0) ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( @@ -70,10 +69,7 @@ def _ssim( pdb.set_trace - if size_average: - return ssim_map.mean() - else: - return ssim_map.mean(1).mean(1).mean(1) + return ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1) class SSIM(torch.nn.Module): diff --git a/ldm/modules/evaluate/torch_frechet_video_distance.py b/ldm/modules/evaluate/torch_frechet_video_distance.py index 04856b82..4afb065b 100755 --- a/ldm/modules/evaluate/torch_frechet_video_distance.py +++ b/ldm/modules/evaluate/torch_frechet_video_distance.py @@ -73,7 +73,7 @@ def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_file url_data = None with requests.Session() as session: if verbose: - print("Downloading %s ..." 
% url, end="", flush=True) + print(f"Downloading {url} ...", end="", flush=True) for attempts_left in reversed(range(num_attempts)): try: with session.get(url) as res: @@ -130,11 +130,13 @@ def get_data_from_str(input_str,nprc = None): pool = mp.Pool(processes=nprc) - vids = [] - for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): - vids.append(v) - - + vids = list( + tqdm( + pool.imap_unordered(load_video, vid_filelist), + total=len(vid_filelist), + desc='Loading videos...', + ) + ) vids = torch.stack(vids,dim=0).float() return vids @@ -277,18 +279,16 @@ def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only sample_embed.append(feats_sample) ref_embed.append(feats_ref) - out = dict() - if len(sample_embed) > 0: + out = {} + if sample_embed: sample_embed = np.concatenate(sample_embed,axis=0) mu_sample, sigma_sample = compute_stats(sample_embed) - out.update({'mu_sample': mu_sample, - 'sigma_sample': sigma_sample}) + out |= {'mu_sample': mu_sample, 'sigma_sample': sigma_sample} - if len(ref_embed) > 0: + if ref_embed: ref_embed = np.concatenate(ref_embed,axis=0) mu_ref, sigma_ref = compute_stats(ref_embed) - out.update({'mu_ref': mu_ref, - 'sigma_ref': sigma_ref}) + out |= {'mu_ref': mu_ref, 'sigma_ref': sigma_ref} return out diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py index 32ef5616..c3f2f761 100755 --- a/ldm/modules/image_degradation/bsrgan.py +++ b/ldm/modules/image_degradation/bsrgan.py @@ -78,9 +78,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k + return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) def gm_blur_kernel(mean, cov, size=15): @@ -175,13 +173,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel + return raw_kernel / np.sum(raw_kernel) def fspecial_gaussian(hsize, sigma): @@ -203,8 +195,7 @@ def fspecial_laplacian(alpha): h1 = alpha / (alpha + 1) h2 = (1 - alpha) / (alpha + 1) h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h + return np.array(h) def fspecial(filter_type, *args, **kwargs): @@ -476,10 +467,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: + if i in [0, 1]: img = add_blur(img, sf=sf) elif i == 2: @@ -565,10 +553,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): for i in shuffle_order: - if i == 0: - image = add_blur(image, sf=sf) - - elif i == 1: + if i in [0, 1]: image = add_blur(image, sf=sf) elif i == 2: @@ -600,17 +585,16 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): if random.random() < jpeg_prob: image = add_JPEG_noise(image) - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) + # elif i == 6: 
+ # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} - return example + return {"image":image} # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... @@ -651,40 +635,24 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 for i in shuffle_order: - if i == 0: + if i in [0, 7]: img = add_blur(img, sf=sf) - elif i == 1: + elif i in [1, 8]: img = add_resize(img, sf=sf) - elif i == 2: + elif i in [2, 9]: img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 3: + elif i in [3, 10]: if random.random() < poisson_prob: img = add_Poisson_noise(img) - elif i == 4: + elif i in [4, 11]: if random.random() < speckle_prob: img = add_speckle_noise(img) - elif i == 5: + elif i in [5, 12]: if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) elif i == 6: img = add_JPEG_noise(img) - elif i == 7: - img = add_blur(img, sf=sf) - elif i == 8: - img = add_resize(img, sf=sf) - elif i == 9: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 10: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 11: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 12: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) else: print('check the shuffle!') @@ -702,29 +670,29 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), 
(int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, f'{str(i)}.png') diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py index dfa76068..ad4e7ec0 100755 --- a/ldm/modules/image_degradation/bsrgan_light.py +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -78,9 +78,7 @@ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k + return gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) def gm_blur_kernel(mean, cov, size=15): @@ -175,13 +173,7 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel + return raw_kernel / np.sum(raw_kernel) def fspecial_gaussian(hsize, sigma): @@ -203,8 +195,7 @@ def fspecial_laplacian(alpha): h1 = alpha / (alpha + 1) h2 = (1 - alpha) / (alpha + 1) h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h + return np.array(h) def fspecial(filter_type, *args, **kwargs): @@ -480,10 +471,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: + if i in [0, 1]: img = add_blur(img, sf=sf) elif i == 2: @@ -617,8 +605,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image": image} - return example + return {"image": image} @@ -647,4 +634,4 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0) img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') + util.imsave(img_concat, f'{str(i)}.png') diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py index 0175f155..3779f7c5 100755 --- a/ldm/modules/image_degradation/utils_image.py +++ b/ldm/modules/image_degradation/utils_image.py @@ -65,10 +65,11 @@ def surf(Z, cmap='rainbow', figsize=None): def get_image_paths(dataroot): - paths = None # return None if dataroot is None - if dataroot is not None: - paths = sorted(_get_paths_from_images(dataroot)) - return paths + return ( + sorted(_get_paths_from_images(dataroot)) + if dataroot is not None + else None + ) def _get_paths_from_images(path): @@ -101,8 +102,7 @@ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): # print(w1) # print(h1) for i in w1: - for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.extend(img[i:i+p_size, j:j+p_size,:] for j in h1) else: patches.append(img) @@ -118,7 +118,9 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), 
img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join( + os.path.dirname(img_path), img_name + '_s{:04d}'.format(i) + '.png' + ) cv2.imwrite(new_path, img) @@ -165,7 +167,7 @@ def mkdirs(paths): def mkdir_and_rename(path): if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() + new_name = f'{path}_archived_{get_timestamp()}' print('Path already exists. Rename it to [{:s}]'.format(new_name)) os.rename(path, new_name) os.makedirs(path) @@ -447,23 +449,19 @@ def augment_img_np3(img, mode=0): return img[::-1, :, :] elif mode == 3: img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img + return img.transpose(1, 0, 2) elif mode == 4: return img[:, ::-1, :] elif mode == 5: img = img[:, ::-1, :] - img = img.transpose(1, 0, 2) - return img + return img.transpose(1, 0, 2) elif mode == 6: img = img[:, ::-1, :] - img = img[::-1, :, :] - return img + return img[::-1, :, :] elif mode == 7: img = img[:, ::-1, :] img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img + return img.transpose(1, 0, 2) def augment_imgs(img_list, hflip=True, rot=True): @@ -622,7 +620,7 @@ def calculate_psnr(img1, img2, border=0): # img1 and img2 have range [0, 255] #img1 = img1.squeeze() #img2 = img2.squeeze() - if not img1.shape == img2.shape: + if img1.shape != img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] img1 = img1[border:h-border, border:w-border] @@ -631,9 +629,7 @@ def calculate_psnr(img1, img2, border=0): img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) mse = np.mean((img1 - img2)**2) - if mse == 0: - return float('inf') - return 20 * math.log10(255.0 / math.sqrt(mse)) + return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse)) # -------------------------------------------- @@ -646,7 +642,7 @@ def calculate_ssim(img1, img2, border=0): ''' #img1 = img1.squeeze() #img2 = img2.squeeze() - if not img1.shape == img2.shape: + if img1.shape != img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] img1 = img1[border:h-border, border:w-border] @@ -656,9 +652,7 @@ def calculate_ssim(img1, img2, border=0): return ssim(img1, img2) elif img1.ndim == 3: if img1.shape[2] == 3: - ssims = [] - for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims = [ssim(img1[:,:,i], img2[:,:,i]) for i in range(3)] return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) @@ -767,7 +761,7 @@ def imresize(img, scale, antialiasing=True): # Now the scale should be the same for H and W # input: img: pytorch tensor, CHW or HW [0,1] # output: CHW or HW [0,1] w/o round - need_squeeze = True if img.dim() == 2 else False + need_squeeze = img.dim() == 2 if need_squeeze: img.unsqueeze_(0) in_C, in_H, in_W = img.size() @@ -841,7 +835,7 @@ def imresize_np(img, scale, antialiasing=True): # input: img: Numpy, HWC or HW [0,1] # output: HWC or HW [0,1] w/o round img = torch.from_numpy(img) - need_squeeze = True if img.dim() == 2 else False + need_squeeze = img.dim() == 2 if need_squeeze: img.unsqueeze_(2) diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py index 672c1e32..238a78b2 100755 --- a/ldm/modules/losses/contperceptual.py +++ b/ldm/modules/losses/contperceptual.py @@ -82,13 +82,16 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = weighted_nll_loss + 
self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + f"{split}/total_loss": loss.clone().detach().mean(), + f"{split}/logvar": self.logvar.detach(), + f"{split}/kl_loss": kl_loss.detach().mean(), + f"{split}/nll_loss": nll_loss.detach().mean(), + f"{split}/rec_loss": rec_loss.detach().mean(), + f"{split}/d_weight": d_weight.detach(), + f"{split}/disc_factor": torch.tensor(disc_factor), + f"{split}/g_loss": g_loss.detach().mean(), + } return loss, log if optimizer_idx == 1: @@ -103,9 +106,10 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + f"{split}/disc_loss": d_loss.clone().detach().mean(), + f"{split}/logits_real": logits_real.detach().mean(), + f"{split}/logits_fake": logits_fake.detach().mean(), + } return d_loss, log diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py index f6998176..15cdeecd 100755 --- a/ldm/modules/losses/vqperceptual.py +++ b/ldm/modules/losses/vqperceptual.py @@ -14,8 +14,7 @@ def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss + return 0.5 * (loss_real + loss_fake) def adopt_weight(weight, global_step, threshold=0, value=0.): if global_step < threshold: @@ -52,18 +51,13 @@ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, assert pixel_loss in ["l1", "l2"] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") - self.perceptual_loss = LPIPS().eval() - else: + if perceptual_loss != "lpips": raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight - if pixel_loss == "l1": - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - + self.pixel_loss = l1 if pixel_loss == "l1" else l2 self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, @@ -131,15 +125,16 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } + log = { + f"{split}/total_loss": loss.clone().detach().mean(), + f"{split}/quant_loss": codebook_loss.detach().mean(), + f"{split}/nll_loss": nll_loss.detach().mean(), + f"{split}/rec_loss": rec_loss.detach().mean(), + f"{split}/p_loss": p_loss.detach().mean(), + f"{split}/d_weight": d_weight.detach(), + f"{split}/disc_factor": torch.tensor(disc_factor), + f"{split}/g_loss": g_loss.detach().mean(), + } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): @@ -160,8 +155,9 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } + log = { + f"{split}/disc_loss": d_loss.clone().detach().mean(), + f"{split}/logits_real": logits_real.detach().mean(), + f"{split}/logits_fake": logits_fake.detach().mean(), + } return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py index 5fc15bf9..eba7bddd 100755 --- a/ldm/modules/x_transformer.py +++ b/ldm/modules/x_transformer.py @@ -91,7 +91,7 @@ def pick_and_pop(keys, d): def group_dict_by_key(cond, d): - return_val = [dict(), dict()] + return_val = [{}, {}] for key in d.keys(): match = bool(cond(key)) ind = int(not match) @@ -423,7 +423,7 @@ def __init__( if cross_attend and not only_cross: default_block = ('a', 'c', 
'f') - elif cross_attend and only_cross: + elif cross_attend: default_block = ('c', 'f') else: default_block = ('a', 'f') @@ -467,11 +467,7 @@ def __init__( if isinstance(layer, Attention) and exists(branch_fn): layer = branch_fn(layer) - if gate_residual: - residual_fn = GRUGating(dim) - else: - residual_fn = Residual() - + residual_fn = GRUGating(dim) if gate_residual else Residual() self.layers.append(nn.ModuleList([ norm_fn(), layer, diff --git a/ldm/thirdp/psp/helpers.py b/ldm/thirdp/psp/helpers.py index 983baaa5..929b42ef 100755 --- a/ldm/thirdp/psp/helpers.py +++ b/ldm/thirdp/psp/helpers.py @@ -16,8 +16,7 @@ def forward(self, input): def l2_norm(input, axis=1): norm = torch.norm(input, 2, axis, True) - output = torch.div(input, norm) - return output + return torch.div(input, norm) class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): @@ -25,7 +24,9 @@ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): def get_block(in_channel, depth, num_units, stride=2): - return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + return [Bottleneck(in_channel, depth, stride)] + [ + Bottleneck(depth, depth, 1) for _ in range(num_units - 1) + ] def get_blocks(num_layers): @@ -51,7 +52,9 @@ def get_blocks(num_layers): get_block(in_channel=256, depth=512, num_units=3) ] else: - raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + raise ValueError( + f"Invalid number of layers: {num_layers}. Must be one of [50, 100, 152]" + ) return blocks diff --git a/ldm/thirdp/psp/id_loss.py b/ldm/thirdp/psp/id_loss.py index e08ee095..0a763ec7 100755 --- a/ldm/thirdp/psp/id_loss.py +++ b/ldm/thirdp/psp/id_loss.py @@ -19,5 +19,4 @@ def forward(self, x, crop=False): x = torch.nn.functional.interpolate(x, (256, 256), mode="area") x = x[:, :, 35:223, 32:220] x = self.face_pool(x) - x_feats = self.facenet(x) - return x_feats + return self.facenet(x) diff --git a/ldm/thirdp/psp/model_irse.py b/ldm/thirdp/psp/model_irse.py index 21cedd29..c6b02973 100755 --- a/ldm/thirdp/psp/model_irse.py +++ b/ldm/thirdp/psp/model_irse.py @@ -37,10 +37,10 @@ def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=Tru modules = [] for block in blocks: - for bottleneck in block: - modules.append(unit_module(bottleneck.in_channel, - bottleneck.depth, - bottleneck.stride)) + modules.extend( + unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride) + for bottleneck in block + ) self.body = Sequential(*modules) def forward(self, x): @@ -52,35 +52,41 @@ def forward(self, x): def IR_50(input_size): """Constructs a ir-50 model.""" - model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) - return model + return Backbone( + input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False + ) def IR_101(input_size): """Constructs a ir-101 model.""" - model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) - return model + return Backbone( + input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False + ) def IR_152(input_size): """Constructs a ir-152 model.""" - model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) - return model + return Backbone( + input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False + ) def IR_SE_50(input_size): """Constructs a ir_se-50 model.""" - model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) - 
return model + return Backbone( + input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False + ) def IR_SE_101(input_size): """Constructs a ir_se-101 model.""" - model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) - return model + return Backbone( + input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False + ) def IR_SE_152(input_size): """Constructs a ir_se-152 model.""" - model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) - return model \ No newline at end of file + return Backbone( + input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False + ) \ No newline at end of file diff --git a/ldm/util.py b/ldm/util.py index 7dcad706..440eafcd 100755 --- a/ldm/util.py +++ b/ldm/util.py @@ -42,7 +42,7 @@ def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) - txts = list() + txts = [] for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) @@ -71,7 +71,7 @@ def ismap(x): def isimage(x): if not isinstance(x,torch.Tensor): return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + return len(x.shape) == 4 and x.shape[1] in [3, 1] def exists(x): @@ -123,18 +123,18 @@ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: che weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code ema_power=1., param_names=()): """AdamW that saves EMA versions of the parameters.""" - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if eps < 0.0: + raise ValueError(f"Invalid epsilon value: {eps}") if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") if not 0.0 <= ema_decay <= 1.0: - raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + raise ValueError(f"Invalid ema_decay value: {ema_decay}") defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, ema_power=ema_power, param_names=param_names) diff --git a/nerf/gui.py b/nerf/gui.py index 65faa5cd..7d73c775 100644 --- a/nerf/gui.py +++ b/nerf/gui.py @@ -87,7 +87,7 @@ def __init__(self, opt, trainer, loader=None, debug=True): self.mode = 'image' # choose from ['image', 'depth'] self.shading = 'albedo' - self.dynamic_resolution = True if not self.opt.dmtet else False + self.dynamic_resolution = not self.opt.dmtet self.downscale = 1 self.train_steps = 16 @@ -128,45 +128,44 @@ def train_step(self): def prepare_buffer(self, outputs): if self.mode == 'image': return outputs['image'].astype(np.float32) - else: - depth = outputs['depth'].astype(np.float32) - depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6) - return np.expand_dims(depth, -1).repeat(3, -1) + depth = outputs['depth'].astype(np.float32) + depth = (depth - depth.min()) / (depth.max() - 
depth.min() + 1e-6) + return np.expand_dims(depth, -1).repeat(3, -1) def test_step(self): - if self.need_update or self.spp < self.opt.max_spp: - - starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - starter.record() - - outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading) - - ender.record() - torch.cuda.synchronize() - t = starter.elapsed_time(ender) - - # update dynamic resolution - if self.dynamic_resolution: - # max allowed infer time per-frame is 200 ms - full_t = t / (self.downscale ** 2) - downscale = min(1, max(1/4, math.sqrt(200 / full_t))) - if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: - self.downscale = downscale - - if self.need_update: - self.render_buffer = self.prepare_buffer(outputs) - self.spp = 1 - self.need_update = False - else: - self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) - self.spp += 1 - - dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') - dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') - dpg.set_value("_log_spp", self.spp) - dpg.set_value("_texture", self.render_buffer) + if not self.need_update and self.spp >= self.opt.max_spp: + return + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.cam.mvp, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) def register_dpg(self): @@ -191,10 +190,13 @@ def register_dpg(self): # text prompt if self.opt.text is not None: - dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text") - + dpg.add_text(f"text: {self.opt.text}", tag="_log_prompt_text") + if self.opt.negative != '': - dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text") + dpg.add_text( + f"negative text: {self.opt.negative}", + tag="_log_prompt_negative_text", + ) # button theme with dpg.theme() as theme_button: @@ -214,7 +216,7 @@ def register_dpg(self): with dpg.group(horizontal=True): dpg.add_text("Infer time: ") dpg.add_text("no data", tag="_log_infer_time") - + with dpg.group(horizontal=True): dpg.add_text("SPP: ") dpg.add_text("1", tag="_log_spp") @@ -269,7 +271,10 @@ def callback_save(sender, app_data): def callback_mesh(sender, app_data): self.trainer.save_mesh() - dpg.set_value("_log_mesh", "saved " + 
f'{self.trainer.name}_{self.trainer.epoch}.ply') + dpg.set_value( + "_log_mesh", + f'saved {self.trainer.name}_{self.trainer.epoch}.ply', + ) self.trainer.epoch += 1 # use epoch to indicate different calls. dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) @@ -280,7 +285,7 @@ def callback_mesh(sender, app_data): with dpg.group(horizontal=True): dpg.add_text("", tag="_log_train_log") - + # rendering options with dpg.collapsing_header(label="Options", default_open=True): @@ -302,7 +307,7 @@ def callback_set_dynamic_resolution(sender, app_data): def callback_change_mode(sender, app_data): self.mode = app_data self.need_update = True - + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) # bg_color picker @@ -383,7 +388,7 @@ def callback_set_abm_ratio(sender, app_data): def callback_change_shading(sender, app_data): self.shading = app_data self.need_update = True - + dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading) @@ -447,9 +452,9 @@ def callback_camera_drag_pan(sender, app_data): dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Right, callback=callback_camera_drag_pan) - + dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) - + # TODO: seems dearpygui doesn't support resizing texture... # def callback_resize(sender, app_data): # self.W = app_data[0] @@ -465,7 +470,7 @@ def callback_camera_drag_pan(sender, app_data): dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) - + dpg.bind_item_theme("_primary_window", theme_no_padding) dpg.setup_dearpygui() diff --git a/nerf/network.py b/nerf/network.py index aceea26b..dbc28535 100644 --- a/nerf/network.py +++ b/nerf/network.py @@ -214,13 +214,10 @@ def density(self, x): def background(self, d): h = self.encoder_bg(d) # [N, C] - - h = self.bg_net(h) - # sigmoid activation for rgb - rgbs = torch.sigmoid(h) + h = self.bg_net(h) - return rgbs + return torch.sigmoid(h) # optimizer utils def get_params(self, lr): diff --git a/nerf/network_grid.py b/nerf/network_grid.py index c308f3df..565d7b4b 100644 --- a/nerf/network_grid.py +++ b/nerf/network_grid.py @@ -18,10 +18,14 @@ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): self.dim_hidden = dim_hidden self.num_layers = num_layers - net = [] - for l in range(num_layers): - net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) - + net = [ + nn.Linear( + self.dim_in if l == 0 else self.dim_hidden, + self.dim_out if l == num_layers - 1 else self.dim_hidden, + bias=bias, + ) + for l in range(num_layers) + ] self.net = nn.ModuleList(net) def forward(self, x): @@ -144,13 +148,10 @@ def density(self, x): def background(self, d): h = self.encoder_bg(d) # [N, C] - - h = self.bg_net(h) - # sigmoid activation for rgb - rgbs = torch.sigmoid(h) + h = self.bg_net(h) - return rgbs + return torch.sigmoid(h) # optimizer utils def get_params(self, lr): diff --git a/nerf/network_grid_taichi.py b/nerf/network_grid_taichi.py index 8fa2efdd..8dee1b0d 100644 --- a/nerf/network_grid_taichi.py +++ b/nerf/network_grid_taichi.py @@ -18,10 +18,14 @@ def 
__init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): self.dim_hidden = dim_hidden self.num_layers = num_layers - net = [] - for l in range(num_layers): - net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) - + net = [ + nn.Linear( + self.dim_in if l == 0 else self.dim_hidden, + self.dim_out if l == num_layers - 1 else self.dim_hidden, + bias=bias, + ) + for l in range(num_layers) + ] self.net = nn.ModuleList(net) def forward(self, x): @@ -142,13 +146,10 @@ def density(self, x): def background(self, d): h = self.encoder_bg(d) # [N, C] - - h = self.bg_net(h) - # sigmoid activation for rgb - rgbs = torch.sigmoid(h) + h = self.bg_net(h) - return rgbs + return torch.sigmoid(h) # optimizer utils def get_params(self, lr): diff --git a/nerf/network_grid_tcnn.py b/nerf/network_grid_tcnn.py index e270789b..5dac3be9 100644 --- a/nerf/network_grid_tcnn.py +++ b/nerf/network_grid_tcnn.py @@ -20,10 +20,14 @@ def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): self.dim_hidden = dim_hidden self.num_layers = num_layers - net = [] - for l in range(num_layers): - net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) - + net = [ + nn.Linear( + self.dim_in if l == 0 else self.dim_hidden, + self.dim_out if l == num_layers - 1 else self.dim_hidden, + bias=bias, + ) + for l in range(num_layers) + ] self.net = nn.ModuleList(net) def forward(self, x): @@ -152,13 +156,10 @@ def density(self, x): def background(self, d): h = self.encoder_bg(d) # [N, C] - - h = self.bg_net(h) - # sigmoid activation for rgb - rgbs = torch.sigmoid(h) + h = self.bg_net(h) - return rgbs + return torch.sigmoid(h) # optimizer utils def get_params(self, lr): diff --git a/nerf/provider.py b/nerf/provider.py index bc92c8ca..cfc2cfc2 100644 --- a/nerf/provider.py +++ b/nerf/provider.py @@ -231,7 +231,7 @@ def get_default_view_data(self): # sample a low-resolution but full image rays = get_rays(poses, intrinsics, H, W, -1) - data = { + return { 'H': H, 'W': W, 'rays_o': rays['rays_o'], @@ -243,13 +243,11 @@ def get_default_view_data(self): 'radius': self.opt.ref_radii, } - return data - def collate(self, index): - B = len(index) - if self.training: + B = len(index) + # random pose on the fly poses, dirs, thetas, phis, radius = rand_poses(B, self.device, self.opt, radius_range=self.opt.radius_range, theta_range=self.opt.theta_range, phi_range=self.opt.phi_range, return_dirs=True, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, uniform_sphere_rate=self.opt.uniform_sphere_rate) @@ -299,7 +297,7 @@ def collate(self, index): delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] delta_radius = radius - self.opt.default_radius - data = { + return { 'H': self.H, 'W': self.W, 'rays_o': rays['rays_o'], @@ -311,8 +309,6 @@ def collate(self, index): 'radius': delta_radius, } - return data - def dataloader(self, batch_size=None): batch_size = batch_size or self.opt.batch_size loader = DataLoader(list(range(self.size)), batch_size=batch_size, collate_fn=self.collate, shuffle=self.training, num_workers=0) diff --git a/nerf/renderer.py b/nerf/renderer.py index 514ab373..a48fa740 100644 --- a/nerf/renderer.py +++ b/nerf/renderer.py @@ -48,9 +48,7 @@ def sample_pdf(bins, weights, n_samples, det=False): denom = (cdf_g[..., 1] - cdf_g[..., 0]) denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) t = (u - cdf_g[..., 0]) / 
denom - samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) - - return samples + return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) @torch.cuda.amp.autocast(enabled=False) def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05): @@ -287,10 +285,10 @@ def __init__(self, opt): self.register_buffer('density_bitfield', density_bitfield) self.mean_density = 0 self.iter_density = 0 - + if self.opt.dmtet: # load dmtet vertices - tets = np.load('tets/{}_tets.npz'.format(self.opt.tet_grid_size)) + tets = np.load(f'tets/{self.opt.tet_grid_size}_tets.npz') self.verts = - torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * 2 # covers [-1, 1] self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') self.tet_scale = torch.tensor([1, 1, 1], dtype=torch.float32, device='cuda') @@ -311,7 +309,7 @@ def __init__(self, opt): self.glctx = dr.RasterizeCudaContext() else: self.glctx = dr.RasterizeGLContext() - + if self.taichi_ray: from einops import rearrange from taichi_modules import RayMarcherTaichi @@ -338,15 +336,14 @@ def __init__(self, opt): @torch.no_grad() def density_blob(self, x): # x: [B, N, 3] - + d = (x ** 2).sum(-1) - - if self.opt.density_activation == 'exp': - g = self.opt.blob_density * torch.exp(- d / (2 * self.opt.blob_radius ** 2)) - else: - g = self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius) - return g + return ( + self.opt.blob_density * torch.exp(-d / (2 * self.opt.blob_radius**2)) + if self.opt.density_activation == 'exp' + else self.opt.blob_density * (1 - torch.sqrt(d) / self.opt.blob_radius) + ) def forward(self, x, d): raise NotImplementedError() @@ -664,7 +661,7 @@ def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', # calculate weight_sum (mask) weights_sum = weights.sum(dim=-1) # [N] - + # calculate depth depth = torch.sum(weights * z_vals, dim=-1) @@ -673,12 +670,7 @@ def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', # mix background color if bg_color is None: - if self.opt.bg_radius > 0: - # use the bg model to calculate bg_color - bg_color = self.background(rays_d) # [N, 3] - else: - bg_color = 1 - + bg_color = self.background(rays_d) if self.opt.bg_radius > 0 else 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, 3) @@ -690,15 +682,15 @@ def run(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='albedo', # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.sum(-1).mean() - + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() - + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: normal_image = torch.sum(weights.unsqueeze(-1) * (normals + 1) / 2, dim=-2) # [N, 3], in [0, 1] results['normal_image'] = normal_image - + results['image'] = image results['depth'] = depth results['weights'] = weights @@ -735,42 +727,42 @@ def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='alb if light_d.shape[0] > 1: flatten_rays = raymarching.flatten_rays(rays, xyzs.shape[0]).long() light_d = light_d[flatten_rays] - + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, 
rays, T_thresh, binarize) - + # normals related regularizations if self.opt.lambda_orient > 0 and normals is not None: # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.mean() - + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() - + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: _, _, _, normal_image = raymarching.composite_rays_train(sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize) results['normal_image'] = normal_image - + # weights normalization results['weights'] = weights else: - + # allocate outputs dtype = torch.float32 - + weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) - + n_alive = N rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] rays_t = nears.clone() # [N] step = 0 - + while step < self.opt.max_steps: # hard coded max step # count alive rays @@ -795,12 +787,7 @@ def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='alb # mix background color if bg_color is None: - if self.opt.bg_radius > 0: - # use the bg model to calculate bg_color - bg_color = self.background(rays_d) # [N, 3] - else: - bg_color = 1 - + bg_color = self.background(rays_d) if self.opt.bg_radius > 0 else 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, 3) @@ -811,7 +798,7 @@ def run_cuda(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='alb results['image'] = image results['depth'] = depth results['weights_sum'] = weights_sum - + return results @torch.no_grad() @@ -870,8 +857,6 @@ def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) light_d = safe_normalize(campos + torch.randn_like(campos)).view(-1, 1, 1, 3) # [B, 1, 1, 3] - results = {} - # get mesh sdf = self.sdf deform = torch.tanh(self.deform) / self.opt.tet_grid_size @@ -883,10 +868,10 @@ def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :] faces = faces.int() - + face_normals = torch.cross(v1 - v0, v2 - v0) face_normals = safe_normalize(face_normals) - + vn = torch.zeros_like(verts) vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals) @@ -898,7 +883,7 @@ def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, verts_clip = torch.bmm(F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1), mvp.permute(0,2,1)).float() # [B, N, 4] rast, rast_db = dr.rasterize(self.glctx, verts_clip, faces, (h, w)) - + alpha = (rast[..., 3:] > 0).float() xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3] normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces) @@ -934,26 +919,18 @@ def run_dmtet(self, rays_o, rays_d, mvp, h, w, light_d=None, ambient_ratio=1.0, # mix background color if bg_color is None: - if self.opt.bg_radius > 0: - # use the bg model to calculate bg_color - bg_color = self.background(rays_d) # [N, 3] - else: - bg_color = 1 - + bg_color = self.background(rays_d) if self.opt.bg_radius > 0 
else 1 if torch.is_tensor(bg_color) and len(bg_color.shape) > 1: bg_color = bg_color.view(-1, h, w, 3) - + depth = rast[:, :, :, [2]] # [B, H, W] color = color + (1 - alpha) * bg_color - results['depth'] = depth - results['image'] = color - results['weights_sum'] = alpha.squeeze(-1) - + results = {'depth': depth, 'image': color, 'weights_sum': alpha.squeeze(-1)} if self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0: normal_image = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] results['normal_image'] = normal_image - + # regularizations if self.training: if self.opt.lambda_mesh_normal > 0: @@ -998,38 +975,38 @@ def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='a # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) - + # normals related regularizations if self.opt.lambda_orient > 0 and normals is not None: # orientation loss loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 results['loss_orient'] = loss_orient.mean() - + if self.opt.lambda_3d_normal_smooth > 0 and normals is not None: normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) results['loss_normal_perturb'] = (normals - normals_perturb).abs().mean() - + if (self.opt.lambda_2d_normal_smooth > 0 or self.opt.lambda_normal > 0) and normals is not None: _, _, _, normal_image, _ = self.volume_render(sigmas.detach(), (normals + 1) / 2, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4)) results['normal_image'] = normal_image - + # weights normalization results['weights'] = weights else: - + # allocate outputs dtype = torch.float32 - + weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) - + n_alive = N rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] rays_t = hits_t[:, 0, 0] step = 0 - + min_samples = 1 if exp_step_factor == 0 else 4 while step < self.opt.max_steps: # hard coded max step @@ -1046,7 +1023,7 @@ def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='a n_step = max(min(N // n_alive, 64), min_samples) xyzs, dirs, deltas, ts, N_eff_samples = \ - self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive, + self.raymarching_test_taichi(rays_o, rays_d, hits_t[:, 0], rays_alive, self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES, n_step) @@ -1079,12 +1056,7 @@ def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='a # mix background color if bg_color is None: - if self.opt.bg_radius > 0: - # use the bg model to calculate bg_color - bg_color = self.background(rays_d) # [N, 3] - else: - bg_color = 1 - + bg_color = self.background(rays_d) if self.opt.bg_radius > 0 else 1 image = image + self.rearrange(1 - weights_sum, 'n -> n 1') * bg_color image = image.view(*prefix, 3) @@ -1095,7 +1067,7 @@ def run_taichi(self, rays_o, rays_d, light_d=None, ambient_ratio=1.0, shading='a results['image'] = image results['depth'] = depth results['weights_sum'] = weights_sum - + return results @@ -1158,33 +1130,26 @@ def render(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, ** device = rays_o.device if self.dmtet: - results = self.run_dmtet(rays_o, rays_d, 
mvp, h, w, **kwargs) + return self.run_dmtet(rays_o, rays_d, mvp, h, w, **kwargs) elif self.cuda_ray: - results = self.run_cuda(rays_o, rays_d, **kwargs) + return self.run_cuda(rays_o, rays_d, **kwargs) elif self.taichi_ray: - results = self.run_taichi(rays_o, rays_d, **kwargs) - else: - if staged: - depth = torch.empty((B, N), device=device) - image = torch.empty((B, N, 3), device=device) - weights_sum = torch.empty((B, N), device=device) - - for b in range(B): - head = 0 - while head < N: - tail = min(head + max_ray_batch, N) - results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) - depth[b:b+1, head:tail] = results_['depth'] - weights_sum[b:b+1, head:tail] = results_['weights_sum'] - image[b:b+1, head:tail] = results_['image'] - head += max_ray_batch - - results = {} - results['depth'] = depth - results['image'] = image - results['weights_sum'] = weights_sum + return self.run_taichi(rays_o, rays_d, **kwargs) + elif staged: + depth = torch.empty((B, N), device=device) + image = torch.empty((B, N, 3), device=device) + weights_sum = torch.empty((B, N), device=device) - else: - results = self.run(rays_o, rays_d, **kwargs) - - return results + for b in range(B): + head = 0 + while head < N: + tail = min(head + max_ray_batch, N) + results_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) + depth[b:b+1, head:tail] = results_['depth'] + weights_sum[b:b+1, head:tail] = results_['weights_sum'] + image[b:b+1, head:tail] = results_['image'] + head += max_ray_batch + + return {'depth': depth, 'image': image, 'weights_sum': weights_sum} + else: + return self.run(rays_o, rays_d, **kwargs) diff --git a/nerf/utils.py b/nerf/utils.py index 7983fdda..263a6f85 100644 --- a/nerf/utils.py +++ b/nerf/utils.py @@ -130,7 +130,7 @@ def srgb_to_linear(x): class Trainer(object): def __init__(self, - argv, # command line args + argv, # command line args name, # name of this experiment opt, # extra conf model, # network @@ -258,7 +258,9 @@ def __init__(self, self.log(f'[INFO] Cmdline: {self.argv}') self.log(f'[INFO] opt: {self.opt}') self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') - self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + self.log( + f'[INFO] #parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}' + ) if self.workspace is not None: if self.use_checkpoint == "scratch": diff --git a/optimizer.py b/optimizer.py index f5bb64fc..b10df4fc 100644 --- a/optimizer.py +++ b/optimizer.py @@ -53,21 +53,18 @@ def __init__(self, max_grad_norm=0.0, no_prox=False, foreach: bool = True): - if not 0.0 <= max_grad_norm: - raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm)) - if not 0.0 <= lr: - raise ValueError('Invalid learning rate: {}'.format(lr)) - if not 0.0 <= eps: - raise ValueError('Invalid epsilon value: {}'.format(eps)) + if max_grad_norm < 0.0: + raise ValueError(f'Invalid Max grad norm: {max_grad_norm}') + if lr < 0.0: + raise ValueError(f'Invalid learning rate: {lr}') + if eps < 0.0: + raise ValueError(f'Invalid epsilon value: {eps}') if not 0.0 <= betas[0] < 1.0: - raise ValueError('Invalid beta parameter at index 0: {}'.format( - betas[0])) + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') if not 0.0 <= betas[1] < 1.0: - raise ValueError('Invalid beta parameter at index 1: {}'.format( - betas[1])) + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') if 
not 0.0 <= betas[2] < 1.0: - raise ValueError('Invalid beta parameter at index 2: {}'.format( - betas[2])) + raise ValueError(f'Invalid beta parameter at index 2: {betas[2]}') defaults = dict(lr=lr, betas=betas, eps=eps, @@ -276,7 +273,7 @@ def _multi_tensor_adan( no_prox: bool, clip_global_grad_norm: Tensor, ): - if len(params) == 0: + if not params: return torch._foreach_mul_(grads, clip_global_grad_norm) diff --git a/preprocess_image.py b/preprocess_image.py index f7937b20..887a6102 100644 --- a/preprocess_image.py +++ b/preprocess_image.py @@ -51,9 +51,9 @@ def __call__(self, image): inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16) generated_ids = self.model.generate(**inputs, max_new_tokens=20) - generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() - - return generated_text + return self.processor.batch_decode( + generated_ids, skip_special_tokens=True + )[0].strip() class DPT(): @@ -84,9 +84,7 @@ def __init__(self, task='depth', device='cuda'): # load model checkpoint = torch.load(path, map_location='cpu') if 'state_dict' in checkpoint: - state_dict = {} - for k, v in checkpoint['state_dict'].items(): - state_dict[k[6:]] = v + state_dict = {k[6:]: v for k, v in checkpoint['state_dict'].items()} else: state_dict = checkpoint self.model.load_state_dict(state_dict) @@ -131,7 +129,7 @@ def __call__(self, image): out_caption = os.path.join(out_dir, os.path.basename(opt.path).split('.')[0] + '_caption.txt') # load image - print(f'[INFO] loading image...') + print('[INFO] loading image...') image = cv2.imread(opt.path, cv2.IMREAD_UNCHANGED) if image.shape[-1] == 4: image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) @@ -139,12 +137,12 @@ def __call__(self, image): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # carve background - print(f'[INFO] background removal...') + print('[INFO] background removal...') carved_image = BackgroundRemoval()(image) # [H, W, 4] mask = carved_image[..., -1] > 0 # predict depth - print(f'[INFO] depth estimation...') + print('[INFO] depth estimation...') dpt_depth_model = DPT(task='depth') depth = dpt_depth_model(image)[0] depth[mask] = (depth[mask] - depth[mask].min()) / (depth[mask].max() - depth[mask].min() + 1e-9) @@ -153,7 +151,7 @@ def __call__(self, image): del dpt_depth_model # predict normal - print(f'[INFO] normal estimation...') + print('[INFO] normal estimation...') dpt_normal_model = DPT(task='normal') normal = dpt_normal_model(image)[0] normal = (normal * 255).astype(np.uint8).transpose(1, 2, 0) @@ -162,7 +160,7 @@ def __call__(self, image): # recenter if opt.recenter: - print(f'[INFO] recenter...') + print('[INFO] recenter...') final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) final_depth = np.zeros((opt.size, opt.size), dtype=np.uint8) final_normal = np.zeros((opt.size, opt.size, 3), dtype=np.uint8) diff --git a/raymarching/backend.py b/raymarching/backend.py index 7cc0d76c..2dbf5162 100644 --- a/raymarching/backend.py +++ b/raymarching/backend.py @@ -18,8 +18,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, 
+ ): return paths[0] # If cl.exe is not on path, try to find it. diff --git a/raymarching/setup.py b/raymarching/setup.py index 4d32fa7b..fea35bdb 100644 --- a/raymarching/setup.py +++ b/raymarching/setup.py @@ -19,8 +19,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. diff --git a/shencoder/backend.py b/shencoder/backend.py index 4971d5e3..d74ad9ee 100644 --- a/shencoder/backend.py +++ b/shencoder/backend.py @@ -18,8 +18,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. diff --git a/shencoder/setup.py b/shencoder/setup.py index 4633ebda..34d2aa58 100644 --- a/shencoder/setup.py +++ b/shencoder/setup.py @@ -19,8 +19,13 @@ def find_cl_path(): import glob for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]: for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: - paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) - if paths: + if paths := sorted( + glob.glob( + r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" + % (program_files, edition) + ), + reverse=True, + ): return paths[0] # If cl.exe is not on path, try to find it. 
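The same `find_cl_path` helper is duplicated across raymarching/backend.py, raymarching/setup.py, shencoder/backend.py, and shencoder/setup.py, and every copy receives the identical rewrite: the separate `paths = sorted(glob.glob(...))` assignment followed by `if paths:` collapses into one assignment expression. The sketch below isolates that pattern in a self-contained form; the function name, the os.path.join-built pattern, and the `None` fallback are illustrative only, while the MSVC directory layout is assumed from the globs in the hunks above (walrus syntax requires Python 3.8+).

    import glob
    import os

    def find_msvc_hostx64_path():
        # Sketch of the assignment-expression pattern used above: the glob
        # result is bound and truth-tested in a single step instead of two.
        for program_files in [r"C:\Program Files (x86)", r"C:\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                pattern = os.path.join(
                    program_files, "Microsoft Visual Studio", "*", edition,
                    "VC", "Tools", "MSVC", "*", "bin", "Hostx64", "x64",
                )
                if paths := sorted(glob.glob(pattern), reverse=True):
                    return paths[0]
        return None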
diff --git a/taichi_modules/hash_encoder.py b/taichi_modules/hash_encoder.py index 9a1b7a76..7ed57d09 100644 --- a/taichi_modules/hash_encoder.py +++ b/taichi_modules/hash_encoder.py @@ -173,9 +173,7 @@ def __init__(self, super(HashEncoderTaichi, self).__init__() self.per_level_scale = b - if batch_size < 2048: - batch_size = 2048 - + batch_size = max(batch_size, 2048) # per_level_scale = 1.3195079565048218 print("per_level_scale: ", b) self.offsets = ti.field(ti.i32, shape=(16, )) diff --git a/taichi_modules/ray_march.py b/taichi_modules/ray_march.py index d159d03b..70b3b376 100644 --- a/taichi_modules/ray_march.py +++ b/taichi_modules/ray_march.py @@ -39,7 +39,7 @@ def raymarching_train(rays_o: ti.types.ndarray(ndim=2), t = t1 N_samples = 0 - while (0 <= t) & (t < t2) & (N_samples < max_samples): + while (t >= 0) & (t < t2) & (N_samples < max_samples): xyz = ray_o + t * ray_d dt = calc_dt(t, exp_step_factor, grid_size, scale) mip = ti.max(mip_from_pos(xyz, cascades), @@ -55,11 +55,9 @@ def raymarching_train(rays_o: ti.types.ndarray(ndim=2), # nxyz = ti.ceil(nxyz) idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) - occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) - # idx = __morton3D(ti.cast(nxyz, ti.uint32)) - # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32)) - - if occ: + if occ := density_bitfield[ti.u32(idx // 8)] & ( + 1 << ti.u32(idx % 8) + ): t += dt N_samples += 1 else: @@ -98,11 +96,9 @@ def raymarching_train(rays_o: ti.types.ndarray(ndim=2), # nxyz = ti.ceil(nxyz) idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) - occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) - # idx = __morton3D(ti.cast(nxyz, ti.uint32)) - # occ = density_bitfield[mip, idx//8] & (1 << ti.cast(idx%8, ti.uint32)) - - if occ: + if occ := density_bitfield[ti.u32(idx // 8)] & ( + 1 << ti.u32(idx % 8) + ): s = start_idx + samples xyzs[s, 0] = xyz[0] xyzs[s, 1] = xyz[1] @@ -260,7 +256,7 @@ def raymarching_test_kernel( s = 0 - while (0 <= t) & (t < t2) & (s < N_samples): + while (t >= 0) & (t < t2) & (s < N_samples): xyz = ray_o + t * ray_d dt = calc_dt(t, exp_step_factor, grid_size, scale) mip = ti.max(mip_from_pos(xyz, cascades), @@ -276,9 +272,9 @@ def raymarching_test_kernel( # nxyz = ti.ceil(nxyz) idx = mip * grid_size3 + __morton3D(ti.cast(nxyz, ti.u32)) - occ = density_bitfield[ti.u32(idx // 8)] & (1 << ti.u32(idx % 8)) - - if occ: + if occ := density_bitfield[ti.u32(idx // 8)] & ( + 1 << ti.u32(idx % 8) + ): xyzs[n, s, 0] = xyz[0] xyzs[n, s, 1] = xyz[1] xyzs[n, s, 2] = xyz[2] diff --git a/taichi_modules/utils.py b/taichi_modules/utils.py index 02c2f2a6..9b331b2c 100644 --- a/taichi_modules/utils.py +++ b/taichi_modules/utils.py @@ -218,7 +218,6 @@ def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): def depth2img(depth): depth = (depth - depth.min()) / (depth.max() - depth.min()) - depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8), - cv2.COLORMAP_TURBO) - - return depth_img \ No newline at end of file + return cv2.applyColorMap( + (depth * 255).astype(np.uint8), cv2.COLORMAP_TURBO + ) \ No newline at end of file diff --git a/tets/generate_tets.py b/tets/generate_tets.py index 94c5241f..01fc77c2 100644 --- a/tets/generate_tets.py +++ b/tets/generate_tets.py @@ -20,8 +20,11 @@ def generate_tetrahedron_grid_file(res=32, root='..'): frac = 1.0 / res - command = 'cd %s/quartet; ' % (root) + \ - './quartet_release meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res) 
+ command = ( + f'cd {root}/quartet; ' + + './quartet_release meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' + % (frac, res, res) + ) os.system(command)
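The rewritten command in tets/generate_tets.py still mixes an f-string with %-formatting. As a sketch only (not part of the patch), an all-f-string equivalent is shown below; `root`, `res`, and `frac` follow the defaults of the surrounding `generate_tetrahedron_grid_file(res=32, root='..')`, and the `:f` format specs reproduce the `%f` output exactly, so the resulting command string is unchanged.

    # Illustrative all-f-string version of the quartet command (assumed values
    # follow the surrounding function; not taken from the patch itself).
    root, res = '..', 32
    frac = 1.0 / res
    command = (
        f'cd {root}/quartet; '
        f'./quartet_release meshes/cube.obj {frac:f} '
        f'meshes/cube_{res:f}_tet.tet -s meshes/cube_boundary_{res:f}.obj'
    )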