forked from comfyanonymous/ComfyUI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsupported_models_base.py
95 lines (77 loc) · 3.33 KB
/
supported_models_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from . import model_base
from . import utils
from . import latent_formats
class ClipTarget:
    """Bundle a tokenizer with the CLIP object it belongs to.

    Attributes:
        clip: The ``clip`` argument, stored as given.
        tokenizer: The ``tokenizer`` argument, stored as given.
        params: A fresh, initially empty dict created per instance.
    """

    def __init__(self, tokenizer, clip):
        """Store ``tokenizer`` and ``clip`` and start with empty ``params``."""
        self.clip = clip
        self.tokenizer = tokenizer
        # Created here (not as a class attribute) so instances never share it.
        self.params = {}
class BASE:
    """Base class describing one supported model configuration.

    Class attributes hold the defaults for the configuration; subclasses are
    presumably expected to override them (confirm against the concrete model
    definitions). Instances take a private copy of the UNet config so that
    per-instance changes do not leak back into the class.
    """

    # UNet config entries that `matches` requires to be present and equal.
    unet_config = {}
    # Entries written into the instance's UNet config by `__init__`,
    # overriding any value of the same key passed by the caller.
    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
    }

    # Keys that must exist in a state dict for `matches` to succeed.
    required_keys = {}

    clip_prefix = []
    # Prefix used when saving CLIP-vision weights; None means "no prefix map".
    clip_vision_prefix = None
    # When not None, `get_model` builds an SD21UNCLIP model with this config.
    noise_aug_config = None
    sampling_settings = {}
    # Stored as a class here; `__init__` replaces it with an instance.
    latent_format = latent_formats.LatentFormat
    # Only the first entry of each list is used when saving.
    vae_key_prefix = ["first_stage_model."]
    text_encoder_key_prefix = ["cond_stage_model."]
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    manual_cast_dtype = None

    @classmethod
    def matches(cls, unet_config, state_dict=None):
        """Return True if ``unet_config`` (and ``state_dict``) fit this class.

        Every key in the class-level ``unet_config`` must be present in the
        given ``unet_config`` with an equal value. If ``state_dict`` is not
        None, every key in ``required_keys`` must also be present in it.
        """
        for key, expected in cls.unet_config.items():
            if key not in unet_config or expected != unet_config[key]:
                return False
        if state_dict is not None:
            return all(key in state_dict for key in cls.required_keys)
        return True

    def model_type(self, state_dict, prefix=""):
        """Return the model type; this base implementation always gives EPS.

        ``state_dict`` and ``prefix`` are unused here; presumably subclasses
        inspect them to pick a different type (verify in subclasses).
        """
        return model_base.ModelType.EPS

    def inpaint_model(self):
        """Return True when the UNet config has more than 4 input channels.

        Raises:
            KeyError: If ``"in_channels"`` is missing from ``unet_config``.
        """
        return self.unet_config["in_channels"] > 4

    def __init__(self, unet_config):
        """Build an instance from a detected ``unet_config``.

        The given config and the class-level ``sampling_settings`` are
        shallow-copied, ``latent_format`` is instantiated, and
        ``unet_extra_config`` is merged over the copied config.
        """
        self.unet_config = unet_config.copy()
        # Shadow the class attribute with a per-instance copy.
        self.sampling_settings = self.sampling_settings.copy()
        # The class attribute is a class/factory; call it to get an instance.
        self.latent_format = self.latent_format()
        # Extra entries win over caller-provided values for the same key.
        self.unet_config.update(self.unet_extra_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Create the model object for this configuration.

        Builds ``model_base.SD21UNCLIP`` when ``noise_aug_config`` is set,
        otherwise ``model_base.BaseModel``. If ``inpaint_model()`` is true,
        ``set_inpaint()`` is called on the result before returning it.
        """
        model_type = self.model_type(state_dict, prefix)
        if self.noise_aug_config is not None:
            out = model_base.SD21UNCLIP(
                self, self.noise_aug_config, model_type=model_type, device=device
            )
        else:
            out = model_base.BaseModel(self, model_type=model_type, device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        """Return ``state_dict`` after text-encoder prefix replacement.

        Each entry of ``text_encoder_key_prefix`` is mapped to the empty
        string and passed to ``utils.state_dict_prefix_replace`` with
        ``filter_keys=True`` (presumably dropping keys without a listed
        prefix; confirm against ``utils``).
        """
        replace_prefix = {k: "" for k in self.text_encoder_key_prefix}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

    def process_unet_state_dict(self, state_dict):
        """Return the UNet ``state_dict`` unchanged (hook for subclasses)."""
        return state_dict

    def process_vae_state_dict(self, state_dict):
        """Return the VAE ``state_dict`` unchanged (hook for subclasses)."""
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to the first text-encoder prefix for saving."""
        replace_prefix = {"": self.text_encoder_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_clip_vision_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to ``clip_vision_prefix`` when one is set.

        With ``clip_vision_prefix`` left as None, an empty replacement map is
        passed to ``utils.state_dict_prefix_replace``.
        """
        replace_prefix = {}
        if self.clip_vision_prefix is not None:
            replace_prefix[""] = self.clip_vision_prefix
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_unet_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to ``"model.diffusion_model."`` for saving."""
        replace_prefix = {"": "model.diffusion_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def process_vae_state_dict_for_saving(self, state_dict):
        """Map the empty prefix to the first VAE key prefix for saving."""
        replace_prefix = {"": self.vae_key_prefix[0]}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def set_inference_dtype(self, dtype, manual_cast_dtype):
        """Record ``dtype`` in the UNet config and store ``manual_cast_dtype``."""
        self.unet_config['dtype'] = dtype
        self.manual_cast_dtype = manual_cast_dtype