Commit 9f291d7 (parent f45157e)

AuraFlow model implementation.

comfyanonymous committed Jul 11, 2024

Showing 12 changed files with 1,744 additions and 2 deletions.
479 changes: 479 additions & 0 deletions comfy/ldm/aura/mmdit.py

Large diffs are not rendered by default.
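Since the large diff is not rendered, here is an illustrative skeleton of the state-dict layout the new model exposes, inferred purely from the detection code later in this commit (keys like double_layers.0.attn.w1q.weight, positional_encoding, and cond_seq_linear.weight). The class names and every dimension below are placeholders, not AuraFlow's real sizes.

# Illustrative skeleton only; the real comfy/ldm/aura/mmdit.py is 479 lines.
import torch
import torch.nn as nn

class DoubleAttnSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Per-stream q/k/v projections; model_detection.py keys off
        # "double_layers.0.attn.w1q.weight" to recognize the checkpoint.
        self.w1q = nn.Linear(dim, dim, bias=False)
        self.w1k = nn.Linear(dim, dim, bias=False)
        self.w1v = nn.Linear(dim, dim, bias=False)

class MMDiTSketch(nn.Module):
    def __init__(self, max_seq=1024, cond_seq_dim=2048, dim=256):
        super().__init__()
        # Shape [1, max_seq, dim]: detection reads max_seq from shape[1].
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq, dim))
        # Projects text-encoder features into the model width; detection
        # reads cond_seq_dim from the weight's input dimension (shape[1]).
        self.cond_seq_linear = nn.Linear(cond_seq_dim, dim, bias=False)
        self.double_layers = nn.ModuleList(
            nn.ModuleDict({"attn": DoubleAttnSketch(dim)}) for _ in range(2)
        )

sd = MMDiTSketch().state_dict()
assert "double_layers.0.attn.w1q.weight" in sd
assert sd["positional_encoding"].shape[1] == 1024     # -> max_seq
assert sd["cond_seq_linear.weight"].shape[1] == 2048  # -> cond_seq_dim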

12 changes: 12 additions & 0 deletions comfy/model_base.py
@@ -6,6 +6,7 @@
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.aura.mmdit
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.model_management
@@ -598,6 +599,17 @@ def memory_required(self, input_shape):
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)

class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out


class StableAudio1(BaseModel):
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
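For context on the AuraFlow class above, a minimal sketch (not the actual sampler plumbing) of what extra_conds accomplishes: the "cross_attn" tensor attached to the conditioning gets wrapped so the sampling code can later resolve it back into the c_crossattn keyword argument of the model's forward pass. CONDRegularSketch is a stand-in for comfy.conds.CONDRegular.

import torch

class CONDRegularSketch:
    # Stand-in for comfy.conds.CONDRegular: wraps a tensor so batched
    # conditionings can be handled uniformly by the sampler.
    def __init__(self, cond):
        self.cond = cond

def extra_conds_sketch(**kwargs):
    out = {}
    cross_attn = kwargs.get("cross_attn", None)
    if cross_attn is not None:
        out["c_crossattn"] = CONDRegularSketch(cross_attn)
    return out

# 2048 matches the Pile T5-XL embedding width this commit wires up below.
conds = extra_conds_sketch(cross_attn=torch.randn(1, 256, 2048))
print(conds["c_crossattn"].cond.shape)  # torch.Size([1, 256, 2048])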
8 changes: 8 additions & 0 deletions comfy/model_detection.py
@@ -105,6 +105,12 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["audio_model"] = "dit1.0"
return unet_config

if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
return unet_config

if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

@@ -253,6 +259,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
def unet_prefix_from_state_dict(state_dict):
if "model.model.postprocess_conv.weight" in state_dict: #audio models
unet_key_prefix = "model.model."
elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
unet_key_prefix = "model."
else:
unet_key_prefix = "model.diffusion_model."
return unet_key_prefix
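Taken together, the two hooks above are enough to route a raw checkpoint: the w1q key identifies an AuraFlow state dict and selects the "model." prefix, then the config is read straight off tensor shapes. A hedged end-to-end sketch with a fake state dict (all dimensions invented):

import torch

state_dict = {
    "model.double_layers.0.attn.w1q.weight": torch.zeros(256, 256),
    "model.positional_encoding": torch.zeros(1, 1024, 256),
    "model.cond_seq_linear.weight": torch.zeros(256, 2048),
}

# unet_prefix_from_state_dict: AuraFlow weights live under "model."
# rather than the usual "model.diffusion_model."
if "model.double_layers.0.attn.w1q.weight" in state_dict:
    prefix = "model."
else:
    prefix = "model.diffusion_model."

# detect_unet_config: max_seq from the positional encoding's length,
# cond_seq_dim from the input width of cond_seq_linear.
unet_config = {
    "max_seq": state_dict[prefix + "positional_encoding"].shape[1],
    "cond_seq_dim": state_dict[prefix + "cond_seq_linear.weight"].shape[1],
}
print(unet_config)  # {'max_seq': 1024, 'cond_seq_dim': 2048}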
4 changes: 4 additions & 0 deletions comfy/sd.py
@@ -21,6 +21,7 @@
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.text_encoders.aura_t5

import comfy.model_patcher
import comfy.lora
@@ -415,6 +416,9 @@ class EmptyClass:
if weight.shape[-1] == 4096:
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, dtype_t5=dtype_t5)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif weight.shape[-1] == 2048:
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif "encoder.block.0.layer.0.SelfAttention.k.weight" in clip_data[0]:
clip_target.clip = sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer
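The width check above is doing model identification: T5-XXL (used by SD3) has a hidden size of 4096, while Pile T5-XL (used by AuraFlow) has 2048, so the last dimension of the inspected embedding weight is enough to route to the right wrapper. A minimal sketch of that dispatch; pick_t5_target is a made-up helper name:

def pick_t5_target(weight_shape):
    # 4096-wide embeddings -> T5-XXL (SD3); 2048-wide -> Pile T5-XL (AuraFlow).
    if weight_shape[-1] == 4096:
        return "sd3_clip (T5-XXL)"
    if weight_shape[-1] == 2048:
        return "aura_t5 (Pile T5-XL)"
    raise ValueError(f"unrecognized T5 width: {weight_shape[-1]}")

print(pick_t5_target((32128, 2048)))  # aura_t5 (Pile T5-XL)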
24 changes: 23 additions & 1 deletion comfy/supported_models.py
@@ -7,6 +7,7 @@
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
import comfy.text_encoders.aura_t5

from . import supported_models_base
from . import latent_formats
@@ -556,7 +557,28 @@ def process_unet_state_dict_for_saving(self, state_dict):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)

class AuraFlow(supported_models_base.BASE):
unet_config = {
"cond_seq_dim": 2048,
}

sampling_settings = {
"multiplier": 1.0,
}

unet_extra_config = {}
latent_format = latent_formats.SDXL

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
out = model_base.AuraFlow(self, device=device)
return out

def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)

- models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio]
+ models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]

models += [SVD_img2vid]
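Two choices in the AuraFlow class are worth spelling out: latent_format = latent_formats.SDXL means AuraFlow reuses the SDXL latent space, and the one-key unet_config template is all the registry needs, because matching is a subset test against the detected config and no other supported model pins cond_seq_dim. A hedged sketch of that matching, with simplified logic:

def matches(template, detected_config):
    # A supported-model class claims a checkpoint when every key it pins
    # in its unet_config agrees with the detected config.
    return all(detected_config.get(k) == v for k, v in template.items())

detected = {"max_seq": 1024, "cond_seq_dim": 2048}   # from model_detection.py
print(matches({"cond_seq_dim": 2048}, detected))      # True  -> AuraFlow
print(matches({"audio_model": "dit1.0"}, detected))   # False -> not StableAudio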
22 changes: 22 additions & 0 deletions comfy/text_encoders/aura_t5.py
@@ -0,0 +1,22 @@
from comfy import sd1_clip
from transformers import LlamaTokenizerFast
import comfy.t5
import os

class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)

class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=LlamaTokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)

class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None):
super().__init__(embedding_directory=embedding_directory, clip_name="pile_t5xl", tokenizer=PT5XlTokenizer)

class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
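A hedged usage sketch, assuming the tokenize_with_weights / encode_token_weights API these classes inherit from the sd1_clip base classes; in practice comfy.sd wires this up when loading a checkpoint. Note that min_length=256 pads every prompt out to at least 256 tokens with pad_token=1, and enable_attention_masks with zero_out_masked in PT5XlModel keeps that padding from leaking into the embeddings.

from comfy.text_encoders.aura_t5 import AuraT5Tokenizer, AuraT5Model

tokenizer = AuraT5Tokenizer()
model = AuraT5Model()  # weights still need to be loaded before real use

tokens = tokenizer.tokenize_with_weights("a watercolor fox, highly detailed")
cond, pooled = model.encode_token_weights(tokens)
# Expected cond shape [1, 256, 2048] for a short prompt (min_length pad);
# pooled is presumably None for this T5-style encoder (assumption).
print(cond.shape)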
22 changes: 22 additions & 0 deletions comfy/text_encoders/t5_pile_config_xl.json
@@ -0,0 +1,22 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 2,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "umt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 1,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 32128
}
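A quick sanity pass over this config: Pile T5-XL is declared as an encoder-decoder umt5-style model with gated GELU activations, but only the 24-layer encoder matters for conditioning, and its d_model of 2048 is exactly the cond_seq_dim the detection code reads off the checkpoint. The path below assumes the repository root as working directory.

import json

with open("comfy/text_encoders/t5_pile_config_xl.json") as f:
    cfg = json.load(f)

assert cfg["d_model"] == 2048                             # matches cond_seq_dim
assert cfg["num_heads"] * cfg["d_kv"] == cfg["d_model"]   # 32 * 64 = 2048
assert cfg["is_gated_act"] and cfg["dense_act_fn"] == "gelu_pytorch_tanh"
print(f"{cfg['num_layers']} encoder layers, vocab {cfg['vocab_size']}")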
102 changes: 102 additions & 0 deletions comfy/text_encoders/t5_pile_tokenizer/added_tokens.json
@@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}
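The sentinel map follows the standard T5 convention: <extra_id_0> gets the highest ID (32099) and the IDs count down to 32000 for <extra_id_99>, so the whole file is reproducible with one dict comprehension. (special_tokens_map.json below registers the same hundred sentinels as additional special tokens, alongside <s>, </s>, and <unk>.)

# Reproduces added_tokens.json exactly.
added_tokens = {f"<extra_id_{i}>": 32099 - i for i in range(100)}
assert added_tokens["<extra_id_0>"] == 32099
assert added_tokens["<extra_id_99>"] == 32000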
125 changes: 125 additions & 0 deletions comfy/text_encoders/t5_pile_tokenizer/special_tokens_map.json
@@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_99>",
"<extra_id_98>",
"<extra_id_97>",
"<extra_id_96>",
"<extra_id_95>",
"<extra_id_94>",
"<extra_id_93>",
"<extra_id_92>",
"<extra_id_91>",
"<extra_id_90>",
"<extra_id_89>",
"<extra_id_88>",
"<extra_id_87>",
"<extra_id_86>",
"<extra_id_85>",
"<extra_id_84>",
"<extra_id_83>",
"<extra_id_82>",
"<extra_id_81>",
"<extra_id_80>",
"<extra_id_79>",
"<extra_id_78>",
"<extra_id_77>",
"<extra_id_76>",
"<extra_id_75>",
"<extra_id_74>",
"<extra_id_73>",
"<extra_id_72>",
"<extra_id_71>",
"<extra_id_70>",
"<extra_id_69>",
"<extra_id_68>",
"<extra_id_67>",
"<extra_id_66>",
"<extra_id_65>",
"<extra_id_64>",
"<extra_id_63>",
"<extra_id_62>",
"<extra_id_61>",
"<extra_id_60>",
"<extra_id_59>",
"<extra_id_58>",
"<extra_id_57>",
"<extra_id_56>",
"<extra_id_55>",
"<extra_id_54>",
"<extra_id_53>",
"<extra_id_52>",
"<extra_id_51>",
"<extra_id_50>",
"<extra_id_49>",
"<extra_id_48>",
"<extra_id_47>",
"<extra_id_46>",
"<extra_id_45>",
"<extra_id_44>",
"<extra_id_43>",
"<extra_id_42>",
"<extra_id_41>",
"<extra_id_40>",
"<extra_id_39>",
"<extra_id_38>",
"<extra_id_37>",
"<extra_id_36>",
"<extra_id_35>",
"<extra_id_34>",
"<extra_id_33>",
"<extra_id_32>",
"<extra_id_31>",
"<extra_id_30>",
"<extra_id_29>",
"<extra_id_28>",
"<extra_id_27>",
"<extra_id_26>",
"<extra_id_25>",
"<extra_id_24>",
"<extra_id_23>",
"<extra_id_22>",
"<extra_id_21>",
"<extra_id_20>",
"<extra_id_19>",
"<extra_id_18>",
"<extra_id_17>",
"<extra_id_16>",
"<extra_id_15>",
"<extra_id_14>",
"<extra_id_13>",
"<extra_id_12>",
"<extra_id_11>",
"<extra_id_10>",
"<extra_id_9>",
"<extra_id_8>",
"<extra_id_7>",
"<extra_id_6>",
"<extra_id_5>",
"<extra_id_4>",
"<extra_id_3>",
"<extra_id_2>",
"<extra_id_1>",
"<extra_id_0>"
],
"bos_token": {
"content": "<s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}
Binary file not shown.
