-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored main branch #1
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Due to GitHub API limits, only the first 60 comments can be shown.
if self.start_index == 2: | ||
readout = (x[:, 0] + x[:, 1]) / 2 | ||
else: | ||
readout = x[:, 0] | ||
readout = (x[:, 0] + x[:, 1]) / 2 if self.start_index == 2 else x[:, 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function AddReadout.forward
refactored with the following changes:
- Replace if statement with if expression (
assign-if-exp
)
layer_1 = pretrained.act_postprocess1[0:2](layer_1) | ||
layer_2 = pretrained.act_postprocess2[0:2](layer_2) | ||
layer_3 = pretrained.act_postprocess3[0:2](layer_3) | ||
layer_4 = pretrained.act_postprocess4[0:2](layer_4) | ||
layer_1 = pretrained.act_postprocess1[:2](layer_1) | ||
layer_2 = pretrained.act_postprocess2[:2](layer_2) | ||
layer_3 = pretrained.act_postprocess3[:2](layer_3) | ||
layer_4 = pretrained.act_postprocess4[:2](layer_4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function forward_vit
refactored with the following changes:
- Replace a[0:x] with a[:x] and a[x:len(a)] with a[x:] [×8] (
remove-redundant-slice-index
)
readout_oper = [ | ||
ProjectReadout(vit_features, start_index) for out_feat in features | ||
] | ||
readout_oper = [ProjectReadout(vit_features, start_index) for _ in features] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function get_readout_oper
refactored with the following changes:
- Replace unused for index with underscore (
for-index-underscore
)
hooks = [5, 11, 17, 23] if hooks == None else hooks | ||
hooks = [5, 11, 17, 23] if hooks is None else hooks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function _make_pretrained_vitl16_384
refactored with the following changes:
- Use x is None rather than x == None (
none-compare
)
hooks = [2, 5, 8, 11] if hooks == None else hooks | ||
hooks = [2, 5, 8, 11] if hooks is None else hooks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function _make_pretrained_vitb16_384
refactored with the following changes:
- Use x is None rather than x == None (
none-compare
)
state_dict = {} | ||
for k, v in checkpoint['state_dict'].items(): | ||
state_dict[k[6:]] = v | ||
state_dict = {k[6:]: v for k, v in checkpoint['state_dict'].items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function DPT.__init__
refactored with the following changes:
- Convert for loop into dictionary comprehension (
dict-comprehension
)
print(f'[INFO] loading image...') | ||
print('[INFO] loading image...') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines 134-165
refactored with the following changes:
- Replace f-string with no interpolated values with string [×5] (
remove-redundant-fstring
)
print('| {} |'.format(line.ljust(50))) | ||
print(f'| {line.ljust(50)} |') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines 89-89
refactored with the following changes:
- Replace call to format with f-string (
use-fstring-for-formatting
)
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) | ||
if paths: | ||
if paths := sorted( | ||
glob.glob( | ||
r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" | ||
% (program_files, edition) | ||
), | ||
reverse=True, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function find_cl_path
refactored with the following changes:
- Use named expression to simplify assignment and conditional (
use-named-expression
)
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) | ||
if paths: | ||
if paths := sorted( | ||
glob.glob( | ||
r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" | ||
% (program_files, edition) | ||
), | ||
reverse=True, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function find_cl_path
refactored with the following changes:
- Use named expression to simplify assignment and conditional (
use-named-expression
)
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) | ||
if paths: | ||
if paths := sorted( | ||
glob.glob( | ||
r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" | ||
% (program_files, edition) | ||
), | ||
reverse=True, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function find_cl_path
refactored with the following changes:
- Use named expression to simplify assignment and conditional (
use-named-expression
)
if self.embeddings.grad is None: | ||
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') | ||
|
||
# level-wise meaned weight decay (ref: zip-nerf) | ||
|
||
B = self.embeddings.shape[0] # size of embedding | ||
C = self.embeddings.shape[1] # embedding dim for each level | ||
L = self.offsets.shape[0] - 1 # level | ||
|
||
if self.embeddings.grad is None: | ||
raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function GridEncoder.grad_weight_decay
refactored with the following changes:
- Move assignments closer to their usage (
move-assign
) - Lift code into else after jump in control flow (
reintroduce-else
) - Remove unnecessary else after guard condition (
remove-unnecessary-else
)
paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True) | ||
if paths: | ||
if paths := sorted( | ||
glob.glob( | ||
r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" | ||
% (program_files, edition) | ||
), | ||
reverse=True, | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function find_cl_path
refactored with the following changes:
- Use named expression to simplify assignment and conditional (
use-named-expression
)
print(f'[INFO] loading DeepFloyd IF-I-XL...') | ||
print('[INFO] loading DeepFloyd IF-I-XL...') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function IF.__init__
refactored with the following changes:
- Replace f-string with no interpolated values with string [×2] (
remove-redundant-fstring
)
embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] | ||
|
||
return embeddings | ||
return self.text_encoder(inputs.input_ids.to(self.device))[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function IF.get_text_embeds
refactored with the following changes:
- Inline variable that is immediately returned (
inline-immediately-returned-variable
)
print(f'[INFO] loading model ...') | ||
print('[INFO] loading model ...') | ||
zero123 = Zero123(device, opt.fp16, opt=opt) | ||
|
||
print(f'[INFO] running model ...') | ||
print('[INFO] running model ...') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines 328-331
refactored with the following changes:
- Replace f-string with no interpolated values with string [×2] (
remove-redundant-fstring
)
config = list(train_dir.rglob(f"*-project.yaml")) | ||
assert len(ckpt) > 0, f"didn't find any config in {train_dir}" | ||
config = list(train_dir.rglob("*-project.yaml")) | ||
assert ckpt, f"didn't find any config in {train_dir}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function load_training_dir
refactored with the following changes:
- Replace f-string with no interpolated values with string (
remove-redundant-fstring
) - Simplify sequence length comparison (
simplify-len-comparison
)
self.last_lr = lr | ||
return lr | ||
else: | ||
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) | ||
t = min(t, 1.0) | ||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( | ||
1 + np.cos(t * np.pi)) | ||
self.last_lr = lr | ||
return lr | ||
|
||
self.last_lr = lr | ||
return lr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function LambdaWarmUpCosineScheduler.schedule
refactored with the following changes:
- Hoist repeated code outside conditional statement [×2] (
hoist-statement-from-if
)
interval = 0 | ||
for cl in self.cum_cycles[1:]: | ||
for interval, cl in enumerate(self.cum_cycles[1:]): | ||
if n <= cl: | ||
return interval | ||
interval += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function LambdaWarmUpCosineScheduler2.find_in_interval
refactored with the following changes:
- Replace manual loop counter with call to enumerate (
convert-to-enumerate
)
self.last_f = f | ||
return f | ||
else: | ||
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) | ||
t = min(t, 1.0) | ||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( | ||
1 + np.cos(t * np.pi)) | ||
self.last_f = f | ||
return f | ||
|
||
self.last_f = f | ||
return f |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function LambdaWarmUpCosineScheduler2.schedule
refactored with the following changes:
- Hoist repeated code outside conditional statement [×2] (
hoist-statement-from-if
)
self.last_f = f | ||
return f | ||
else: | ||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) | ||
self.last_f = f | ||
return f | ||
|
||
self.last_f = f | ||
return f |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function LambdaLinearScheduler.schedule
refactored with the following changes:
- Hoist repeated code outside conditional statement [×2] (
hoist-statement-from-if
)
txts = list() | ||
txts = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function log_txt_as_img
refactored with the following changes:
- Replace
list()
with[]
(list-literal
)
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) | ||
return len(x.shape) == 4 and x.shape[1] in [3, 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function isimage
refactored with the following changes:
- Replace multiple comparisons of same variable with
in
operator (merge-comparisons
)
if not 0.0 <= lr: | ||
raise ValueError("Invalid learning rate: {}".format(lr)) | ||
if not 0.0 <= eps: | ||
raise ValueError("Invalid epsilon value: {}".format(eps)) | ||
if lr < 0.0: | ||
raise ValueError(f"Invalid learning rate: {lr}") | ||
if eps < 0.0: | ||
raise ValueError(f"Invalid epsilon value: {eps}") | ||
if not 0.0 <= betas[0] < 1.0: | ||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) | ||
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") | ||
if not 0.0 <= betas[1] < 1.0: | ||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) | ||
if not 0.0 <= weight_decay: | ||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) | ||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") | ||
if weight_decay < 0.0: | ||
raise ValueError(f"Invalid weight_decay value: {weight_decay}") | ||
if not 0.0 <= ema_decay <= 1.0: | ||
raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) | ||
raise ValueError(f"Invalid ema_decay value: {ema_decay}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function AdamWwithEMAandWings.__init__
refactored with the following changes:
- Simplify logical expression using De Morgan identities [×3] (
de-morgan
) - Ensure constant in comparison is on the right [×3] (
flip-comparison
) - Replace call to format with f-string [×6] (
use-fstring-for-formatting
)
print("Deleting key {} from state_dict.".format(k)) | ||
print(f"Deleting key {k} from state_dict.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function VQModel.init_from_ckpt
refactored with the following changes:
- Replace call to format with f-string (
use-fstring-for-formatting
)
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() | ||
return x | ||
return x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function AutoencoderKL.get_input
refactored with the following changes:
- Inline variable that is immediately returned (
inline-immediately-returned-variable
)
log = dict() | ||
log = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function AutoencoderKL.log_images
refactored with the following changes:
- Replace
dict()
with{}
(dict-literal
)
if self.vq_interface: | ||
return x, None, [None, None, None] | ||
return x | ||
return (x, None, [None, None, None]) if self.vq_interface else x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function IdentityFirstStage.quantize
refactored with the following changes:
- Lift code into else after jump in control flow (
reintroduce-else
) - Replace if statement with if expression (
assign-if-exp
)
print("Deleting key {} from state_dict.".format(k)) | ||
print(f"Deleting key {k} from state_dict.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function NoisyLatentImageClassifier.init_from_ckpt
refactored with the following changes:
- Replace call to format with f-string (
use-fstring-for-formatting
)
for down in range(self.numd): | ||
for _ in range(self.numd): | ||
h, w = targets.shape[-2:] | ||
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') | ||
|
||
# targets = rearrange(targets,'b c h w -> b h w c') | ||
# targets = rearrange(targets,'b c h w -> b h w c') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function NoisyLatentImageClassifier.get_conditioning
refactored with the following changes:
- Replace unused for index with underscore (
for-index-underscore
)
Branch
main
refactored by Sourcery.If you're happy with these changes, merge this Pull Request using the Squash and merge strategy.
See our documentation here.
Run Sourcery locally
Reduce the feedback loop during development by using the Sourcery editor plugin:
Review changes via command line
To manually merge these changes, make sure you're on the
main
branch, then run:Help us improve this pull request!