Skip to content

Commit

Permalink
Allow model config to preprocess the vae state dict on load.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 21, 2023
1 parent d66b631 commit 6a491eb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ class WeightsLoader(torch.nn.Module):

if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)

if output_clip:
Expand Down
3 changes: 3 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def process_clip_state_dict(self, state_dict):
def process_unet_state_dict(self, state_dict):
return state_dict

def process_vae_state_dict(self, state_dict):
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
Expand Down

0 comments on commit 6a491eb

Please sign in to comment.