Skip to content

Commit

Permalink
Auto detect playground v2.5 model.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Feb 27, 2024
1 parent d46583e commit 8daedc5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ class ModelType(Enum):
V_PREDICTION = 2
V_PREDICTION_EDM = 3
STABLE_CASCADE = 4
EDM = 5


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling


def model_sampling(model_config, model_type):
Expand All @@ -33,6 +34,9 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.STABLE_CASCADE:
c = EPS
s = StableCascadeSampling
elif model_type == ModelType.EDM:
c = EDM
s = ModelSamplingContinuousEDM

class ModelSampling(s, c):
pass
Expand Down
8 changes: 7 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,13 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL

def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict:
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
self.sampling_settings["sigma_data"] = 0.5
self.sampling_settings["sigma_max"] = 80.0
self.sampling_settings["sigma_min"] = 0.002
return model_base.ModelType.EDM
elif "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION
else:
return model_base.ModelType.EPS
Expand Down

0 comments on commit 8daedc5

Please sign in to comment.