diff --git a/denoising_diffusion_pytorch/__init__.py b/denoising_diffusion_pytorch/__init__.py
index 05ded85d2..c63fcb09c 100644
--- a/denoising_diffusion_pytorch/__init__.py
+++ b/denoising_diffusion_pytorch/__init__.py
@@ -8,4 +8,4 @@
 from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D
 
-from denoising_diffusion_pytorch.karras_unet import KarrasUnet
+from denoising_diffusion_pytorch.karras_unet import KarrasUnet, InvSqrtDecayLRSched
 
diff --git a/denoising_diffusion_pytorch/karras_unet.py b/denoising_diffusion_pytorch/karras_unet.py
index f25cba636..1139843d4 100644
--- a/denoising_diffusion_pytorch/karras_unet.py
+++ b/denoising_diffusion_pytorch/karras_unet.py
@@ -9,6 +9,7 @@
 import torch
 from torch import nn, einsum
 from torch.nn import Module, ModuleList
+from torch.optim.lr_scheduler import LambdaLR
 import torch.nn.functional as F
 
 from einops import rearrange, repeat, pack, unpack
@@ -680,6 +681,21 @@ def forward(self, x):
         return x
 
+# works best with inverse square root decay schedule
+
+def InvSqrtDecayLRSched(
+    optimizer,
+    t_ref = 70000,
+    sigma_ref = 0.01
+):
+    """
+    Inverse square-root learning-rate decay: constant (factor sigma_ref) until
+    step t_ref, then decaying as 1/sqrt(step / t_ref).
+    Refer to equation 67 and Table 1 of Karras et al., arXiv:2312.02696
+    """
+    def inv_sqrt_decay_fn(step: int):
+        # use the scheduler-supplied step count; `t` was undefined here (NameError)
+        return sigma_ref / sqrt(max(step / t_ref, 1.))
+
+    return LambdaLR(optimizer, lr_lambda = inv_sqrt_decay_fn)
+
 # example
 
 if __name__ == '__main__':
 
diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py
index 420b4d0ed..37f0254fb 100644
--- a/denoising_diffusion_pytorch/version.py
+++ b/denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.5'
+__version__ = '1.10.6'