From 5713d63dc5840d6726e12e130e00b47162c4ebf4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 6 Dec 2023 23:08:02 +0900 Subject: [PATCH] add temporary workaround for playground-v2 --- library/sdxl_model_util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 2f0154cae..a844927cd 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -133,6 +133,12 @@ def convert_key(key): # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + # temporary workaround for text_projection.weight.weight for Playground-v2 + if "text_projection.weight.weight" in new_sd: + print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight") + new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"] + del new_sd["text_projection.weight.weight"] + return new_sd, logit_scale @@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) - + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models if "text_model.embeddings.position_ids" not in te1_sd: te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)