diff --git a/comfy/model_base.py b/comfy/model_base.py index 421f271b28a..170b1fd4459 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -15,9 +15,10 @@ class ModelType(Enum): V_PREDICTION = 2 V_PREDICTION_EDM = 3 STABLE_CASCADE = 4 + EDM = 5 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling +from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -33,6 +34,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.STABLE_CASCADE: c = EPS s = StableCascadeSampling + elif model_type == ModelType.EDM: + c = EDM + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5d57a31a182..74908216cee 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE): latent_format = latent_formats.SDXL def model_type(self, state_dict, prefix=""): - if "v_pred" in state_dict: + if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 + self.latent_format = latent_formats.SDXL_Playground_2_5() + self.sampling_settings["sigma_data"] = 0.5 + self.sampling_settings["sigma_max"] = 80.0 + self.sampling_settings["sigma_min"] = 0.002 + return model_base.ModelType.EDM + elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: return model_base.ModelType.EPS