Skip to content

Commit

Permalink
Playground V2.5 support with ModelSamplingContinuousEDM node.
Browse files Browse the repository at this point in the history
Use ModelSamplingContinuousEDM with edm_playground_v2.5 selected.
  • Loading branch information
comfyanonymous committed Feb 27, 2024
1 parent 1e0fcc9 commit d46583e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 7 deletions.
27 changes: 27 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch

class LatentFormat:
scale_factor = 1.0
Expand Down Expand Up @@ -34,6 +35,32 @@ def __init__(self):
]
self.taesd_decoder_name = "taesdxl_decoder"

class SDXL_Playground_2_5(LatentFormat):
def __init__(self):
self.scale_factor = 0.5
self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)

self.latent_rgb_factors = [
# R G B
[ 0.3920, 0.4054, 0.4549],
[-0.2634, -0.0196, 0.0653],
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"

def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return (latent - latents_mean) * self.scale_factor / latents_std

def process_out(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean


class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333
Expand Down
13 changes: 9 additions & 4 deletions comfy/model_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5


class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
Expand Down Expand Up @@ -92,18 +97,18 @@ def percent_to_sigma(self, percent):
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
self.sigma_data = 1.0

if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}

sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_parameters(sigma_min, sigma_max, sigma_data)

def set_sigma_range(self, sigma_min, sigma_max):
def set_parameters(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()

self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
Expand Down
2 changes: 1 addition & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)

if latent_image is not None:
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
latent_image = model.process_latent_in(latent_image)

if hasattr(model, 'extra_conds'):
Expand Down
13 changes: 11 additions & 2 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import torch

class LCM(comfy.model_sampling.EPS):
Expand Down Expand Up @@ -135,7 +136,7 @@ class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["v_prediction", "eps"],),
"sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
"sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
}}
Expand All @@ -148,17 +149,25 @@ def INPUT_TYPES(s):
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()

latent_format = None
sigma_data = 1.0
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "edm_playground_v2.5":
sampling_type = comfy.model_sampling.EDM
sigma_data = 0.5
latent_format = comfy.latent_formats.SDXL_Playground_2_5()

class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass

model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max)
model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
m.add_object_patch("model_sampling", model_sampling)
if latent_format is not None:
m.add_object_patch("latent_format", latent_format)
return (m, )

class RescaleCFG:
Expand Down

0 comments on commit d46583e

Please sign in to comment.