Skip to content

Commit

Permalink
add temporary workaround for playground-v2
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Dec 6, 2023
1 parent 46cf41c commit 5713d63
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion library/sdxl_model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def convert_key(key):
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)

# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]

return new_sd, logit_scale


Expand Down Expand Up @@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)

# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
Expand Down

0 comments on commit 5713d63

Please sign in to comment.