From 8e012043a9d0af3979bbe2cea8dc1ec7768f9d88 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 Jul 2024 17:51:56 -0400 Subject: [PATCH] Add a ModelSamplingAuraFlow node to change the shift value. Set the default AuraFlow shift value to 1.73 (sqrt(3)). --- comfy/supported_models.py | 1 + comfy_extras/nodes_model_advanced.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ccf8c333e45..a030f6229f3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -564,6 +564,7 @@ class AuraFlow(supported_models_base.BASE): sampling_settings = { "multiplier": 1.0, + "shift": 1.73, } unet_extra_config = {} diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 97559cf56e3..22ba9547b89 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -144,7 +144,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/model" - def patch(self, model, shift): + def patch(self, model, shift, multiplier=1000): m = model.clone() sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow @@ -154,10 +154,22 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): pass model_sampling = ModelSamplingAdvanced(model.model.model_config) - model_sampling.set_parameters(shift=shift) + model_sampling.set_parameters(shift=shift, multiplier=multiplier) m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingAuraFlow(ModelSamplingSD3): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + FUNCTION = "patch_aura" + + def patch_aura(self, model, shift): + return self.patch(model, shift, multiplier=1.0) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -271,5 +283,6 @@ def rescale_cfg(args): "ModelSamplingContinuousV": ModelSamplingContinuousV, "ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingSD3": ModelSamplingSD3, + "ModelSamplingAuraFlow": ModelSamplingAuraFlow, "RescaleCFG": RescaleCFG, }