Sourcery refactored main branch #1
base: main
==== file 1 of 10 ====
@@ -42,10 +42,7 @@ def __init__(self, start_index=1):
         self.start_index = start_index

     def forward(self, x):
-        if self.start_index == 2:
-            readout = (x[:, 0] + x[:, 1]) / 2
-        else:
-            readout = x[:, 0]
+        readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0]
         return x[:, self.start_index :] + readout.unsqueeze(1)
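For reference, the assign-if-exp rewrite in isolation (toy names, not from this repository): an if/else whose only job is to assign a single variable is equivalent to one conditional expression.

    # before
    if use_mean:
        value = (a + b) / 2
    else:
        value = a

    # after, identical behaviour
    value = (a + b) / 2 if use_mean else a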
@@ -84,10 +81,10 @@ def forward_vit(pretrained, x):
     layer_3 = pretrained.activations["3"]
     layer_4 = pretrained.activations["4"]

-    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
-    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
-    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
-    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+    layer_1 = pretrained.act_postprocess1[:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[:2](layer_4)

     unflattened_dim = 2
@@ -96,7 +93,7 @@ def forward_vit(pretrained, x):
         int(torch.div(w, pretrained.model.patch_size[0], rounding_mode='floor')),
     )
     unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size))

     if layer_1.ndim == 3:
         layer_1 = unflatten(layer_1)
@@ -107,10 +104,10 @@ def forward_vit(pretrained, x):
     if layer_4.ndim == 3:
         layer_4 = unflatten_with_named_tensor(layer_4, unflattened_dim, unflattened_size)

-    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
-    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
-    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
-    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+    layer_1 = pretrained.act_postprocess1[3:](layer_1)
+    layer_2 = pretrained.act_postprocess2[3:](layer_2)
+    layer_3 = pretrained.act_postprocess3[3:](layer_3)
+    layer_4 = pretrained.act_postprocess4[3:](layer_4)

     return layer_1, layer_2, layer_3, layer_4
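Both slicing hunks rely on Python's slice defaults: a missing start means 0 and a missing stop means len(seq), so the shorter spellings are behaviour-preserving. A quick sanity check:

    seq = [10, 20, 30, 40, 50]
    assert seq[0:2] == seq[:2]          # omitted start defaults to 0
    assert seq[3:len(seq)] == seq[3:]   # omitted stop defaults to len(seq)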
@@ -187,9 +184,7 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1):
     elif use_readout == "add":
         readout_oper = [AddReadout(start_index)] * len(features)
     elif use_readout == "project":
-        readout_oper = [
-            ProjectReadout(vit_features, start_index) for out_feat in features
-        ]
+        readout_oper = [ProjectReadout(vit_features, start_index) for _ in features]
     else:
         assert (
             False
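Renaming the unused comprehension variable to _ only signals intent (the element itself is never read) and quiets unused-variable lints; the list produced is the same. Toy example:

    sizes = [96, 192, 384, 768]
    ops = ["op" for _ in sizes]   # one entry per element, element value ignored
    assert len(ops) == len(sizes)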
@@ -315,7 +310,7 @@ def _make_vit_b16_backbone(
 def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
     model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

-    hooks = [5, 11, 17, 23] if hooks == None else hooks
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
     return _make_vit_b16_backbone(
         model,
         features=[256, 512, 1024, 1024],
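The hooks == None to hooks is None change, repeated in the hunks below, follows PEP 8: compare against the None singleton by identity rather than equality, because == can be overridden and may not return a plain bool. A small illustration (NumPy used only to show the pitfall):

    import numpy as np

    hooks = None
    hooks = [5, 11, 17, 23] if hooks is None else hooks   # identity check

    a = np.array([1, 2, 3])
    print(a == None)   # element-wise array([False, False, False]), not a single bool
    print(a is None)   # False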
@@ -328,7 +323,7 @@ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
 def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
     model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

-    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
     return _make_vit_b16_backbone(
         model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
     )
@@ -337,7 +332,7 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
 def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
     model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)

-    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
     return _make_vit_b16_backbone(
         model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
     )
@@ -348,7 +343,7 @@ def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=
         "vit_deit_base_distilled_patch16_384", pretrained=pretrained
     )

-    hooks = [2, 5, 8, 11] if hooks == None else hooks
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
     return _make_vit_b16_backbone(
         model,
         features=[96, 192, 384, 768],
@@ -498,7 +493,7 @@ def _make_pretrained_vitb_rn50_384(
 ):
     model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

-    hooks = [0, 1, 8, 11] if hooks == None else hooks
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
     return _make_vit_b_rn50_backbone(
         model,
         features=[256, 512, 768, 768],
@@ -589,7 +584,7 @@ def _make_efficientnet_backbone(effnet):
     pretrained = nn.Module()

     pretrained.layer1 = nn.Sequential(
-        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[:2]
     )
     pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
     pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
@@ -897,13 +892,11 @@ def forward(self, x):
         path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
         path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

-        out = self.scratch.output_conv(path_1)
-
-        return out
+        return self.scratch.output_conv(path_1)


 class DPTDepthModel(DPT):
     def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs):
-        features = kwargs["features"] if "features" in kwargs else 256
+        features = kwargs.get("features", 256)

         head = nn.Sequential(
             nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
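Two recurring rewrites appear in this hunk: inline-immediately-returned-variable (fold a temporary straight into the return statement) and dict.get with a default, which is exactly what the kwargs["k"] if "k" in kwargs else default idiom spells out by hand. Both in isolation, with made-up names:

    # inline the immediately returned variable
    def area_before(w, h):
        result = w * h
        return result

    def area_after(w, h):
        return w * h

    # dict.get with a default
    kwargs = {"non_negative": True}
    features = kwargs["features"] if "features" in kwargs else 256
    assert features == kwargs.get("features", 256) == 256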
==== file 2 of 10 ====
@@ -40,16 +40,12 @@ def forward(self, input, max_level=None, **kwargs):

         for i in range(max_level):
             freq = self.freq_bands[i]
-            for p_fn in self.periodic_fns:
-                out.append(p_fn(input * freq))
-
+            out.extend(p_fn(input * freq) for p_fn in self.periodic_fns)
         # append 0
         if self.N_freqs - max_level > 0:
             out.append(torch.zeros(*input.shape[:-1], (self.N_freqs - max_level) * 2 * input.shape[-1], device=input.device, dtype=input.dtype))

-        out = torch.cat(out, dim=-1)
-
-        return out
+        return torch.cat(out, dim=-1)


 def get_encoder(encoding, input_dim=3,
                 multires=6,
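Replacing the inner append loop with list.extend and a generator expression yields the same elements in the same order; it just removes one level of explicit looping. Toy version:

    fns = [abs, float]
    out, out2 = [], []
    for f in fns:                      # before
        out.append(f(4))
    out2.extend(f(4) for f in fns)     # after
    assert out == out2 == [4, 4.0]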
==== file 3 of 10 ====
@@ -86,6 +86,6 @@

 print('+' + '-'*52 + '+')
 for line in wrapped_text.split('\n'):
-    print('| {} |'.format(line.ljust(50)))
+    print(f'| {line.ljust(50)} |')
 print('+' + '-'*52 + '+')
 #print(result)
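The str.format call and the f-string build the same padded cell, so this is a readability-only change:

    line = "hello"
    assert '| {} |'.format(line.ljust(50)) == f'| {line.ljust(50)} |'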
==== file 4 of 10 ====
@@ -19,8 +19,13 @@ def find_cl_path():
     import glob
     for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-            paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
-            if paths:
+            if paths := sorted(
+                glob.glob(
+                    r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
+                    % (program_files, edition)
+                ),
+                reverse=True,
+            ):
                 return paths[0]

 # If cl.exe is not on path, try to find it.
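The := (walrus) operator, available since Python 3.8, binds the value and tests its truthiness in one expression, so the rewritten code behaves exactly like the assign-then-if original (the same change recurs in the other backend files below). In isolation:

    import glob

    # before
    matches = sorted(glob.glob("*.py"), reverse=True)
    if matches:
        print(matches[0])

    # after (Python 3.8+): assign and test in one step
    if matches := sorted(glob.glob("*.py"), reverse=True):
        print(matches[0])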
==== file 5 of 10 ====
@@ -20,8 +20,13 @@ def find_cl_path():
     import glob
     for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-            paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
-            if paths:
+            if paths := sorted(
+                glob.glob(
+                    r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
+                    % (program_files, edition)
+                ),
+                reverse=True,
+            ):
                 return paths[0]

 # If cl.exe is not on path, try to find it.
==== file 6 of 10 ====
@@ -18,8 +18,13 @@ def find_cl_path():
     import glob
     for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-            paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
-            if paths:
+            if paths := sorted(
+                glob.glob(
+                    r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
+                    % (program_files, edition)
+                ),
+                reverse=True,
+            ):
                 return paths[0]
 # If cl.exe is not on path, try to find it.
 if os.system("where cl.exe >nul 2>nul") != 0:
==== file 7 of 10 ====
@@ -194,13 +194,13 @@ def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):

     @torch.cuda.amp.autocast(enabled=False)
     def grad_weight_decay(self, weight=0.1):
+        if self.embeddings.grad is None:
+            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
+
         # level-wise meaned weight decay (ref: zip-nerf)

         B = self.embeddings.shape[0] # size of embedding
         C = self.embeddings.shape[1] # embedding dim for each level
         L = self.offsets.shape[0] - 1 # level

-        if self.embeddings.grad is None:
-            raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!')
-
         _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L)
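Hoisting the grad-is-None check to the top of the method turns it into a guard clause: the function fails fast before computing B, C and L, and since those assignments have no side effects the observable behaviour is unchanged. A minimal sketch with hypothetical names:

    def scale_grad(grad, weight=0.1):
        if grad is None:   # guard clause first, fail before any other work
            raise ValueError("grad is None, call backward() first")
        n = len(grad)      # only reached with valid input
        return [weight * g / n for g in grad]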
==== file 8 of 10 ====
@@ -19,8 +19,13 @@ def find_cl_path():
     import glob
     for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
         for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
-            paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
-            if paths:
+            if paths := sorted(
+                glob.glob(
+                    r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
+                    % (program_files, edition)
+                ),
+                reverse=True,
+            ):
                 return paths[0]

 # If cl.exe is not on path, try to find it.
==== file 9 of 10 ====
@@ -39,7 +39,7 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]):

         self.device = device

-        print(f'[INFO] loading DeepFloyd IF-I-XL...')
+        print('[INFO] loading DeepFloyd IF-I-XL...')

         model_key = "DeepFloyd/IF-I-XL-v1.0"
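An f-string with no {} placeholders is just a plain string literal (linters flag it as F541), so dropping the f prefix changes nothing:

    print('[INFO] loading DeepFloyd IF-I-XL...')    # preferred
    print(f'[INFO] loading DeepFloyd IF-I-XL...')   # same output, needless f prefix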
@@ -70,7 +70,7 @@ def __init__(self, device, vram_O, t_range=[0.02, 0.98]):
         self.max_step = int(self.num_train_timesteps * t_range[1])
         self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

-        print(f'[INFO] loaded DeepFloyd IF-I-XL!')
+        print('[INFO] loaded DeepFloyd IF-I-XL!')

     @torch.no_grad()
     def get_text_embeds(self, prompt):
@@ -79,9 +79,7 @@ def get_text_embeds(self, prompt):
         # TODO: should I add the preprocessing at https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py#LL486C10-L486C28
         prompt = self.pipe._text_preprocessing(prompt, clean_caption=False)
         inputs = self.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, add_special_tokens=True, return_tensors='pt')
-        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
-
-        return embeddings
+        return self.text_encoder(inputs.input_ids.to(self.device))[0]


     def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):
@@ -116,10 +114,7 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1
         grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
         grad = torch.nan_to_num(grad)

-        # since we omitted an item in grad, we need to use the custom function to specify the gradient
-        loss = SpecifyGradient.apply(images, grad)
-
-        return loss
+        return SpecifyGradient.apply(images, grad)

     @torch.no_grad()
     def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps=50, guidance_scale=7.5):

Sourcery note on this hunk: "This removes the following comments (why?)", referring to the deleted '# since we omitted an item in grad ...' explanation above.
@@ -129,7 +124,7 @@ def produce_imgs(self, text_embeddings, height=64, width=64, num_inference_steps

         self.scheduler.set_timesteps(num_inference_steps)

-        for i, t in enumerate(self.scheduler.timesteps):
+        for t in self.scheduler.timesteps:
             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
             model_input = torch.cat([images] * 2)
             model_input = self.scheduler.scale_model_input(model_input, t)
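Dropping enumerate is safe whenever the loop index is never read in the body; the iteration order and the work done per step stay the same (the same rewrite appears again in produce_latents below). Toy version:

    timesteps = [999, 500, 1]
    for i, t in enumerate(timesteps):   # before: `i` unused
        print(t)
    for t in timesteps:                 # after
        print(t)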
==== file 10 of 10 ====
@@ -42,7 +42,7 @@ def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=
         self.device = device
         self.sd_version = sd_version

-        print(f'[INFO] loading stable diffusion...')
+        print('[INFO] loading stable diffusion...')

         if hf_key is not None:
             print(f'[INFO] using hugging face custom model key: {hf_key}')
@@ -84,16 +84,14 @@ def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=
         self.max_step = int(self.num_train_timesteps * t_range[1])
         self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

-        print(f'[INFO] loaded stable diffusion!')
+        print('[INFO] loaded stable diffusion!')

     @torch.no_grad()
     def get_text_embeds(self, prompt):
         # prompt: [str]

         inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
-        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
-
-        return embeddings
+        return self.text_encoder(inputs.input_ids.to(self.device))[0]


     def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
@@ -170,10 +168,7 @@ def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=Fa
             viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
             save_image(viz_images, save_guidance_path)

-        # since we omitted an item in grad, we need to use the custom function to specify the gradient
-        loss = SpecifyGradient.apply(latents, grad)
-
-        return loss
+        return SpecifyGradient.apply(latents, grad)

     @torch.no_grad()
     def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

Sourcery note on this hunk: "This removes the following comments (why?)", again referring to the deleted '# since we omitted an item in grad ...' line.
@@ -183,7 +178,7 @@ def produce_latents(self, text_embeddings, height=512, width=512, num_inference_

         self.scheduler.set_timesteps(num_inference_steps)

-        for i, t in enumerate(self.scheduler.timesteps):
+        for t in self.scheduler.timesteps:
             # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
             latent_model_input = torch.cat([latents] * 2)
             # predict the noise residual
@@ -213,9 +208,7 @@ def encode_imgs(self, imgs):
         imgs = 2 * imgs - 1

         posterior = self.vae.encode(imgs).latent_dist
-        latents = posterior.sample() * self.vae.config.scaling_factor
-
-        return latents
+        return posterior.sample() * self.vae.config.scaling_factor

     def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
Sourcery review comment: Function AddReadout.forward refactored with the following changes: assign-if-exp.