From f6a5c359cc4d6d34df18a6f6ab90b1462395eef9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 16 Jan 2023 14:25:25 +0100 Subject: [PATCH] [Community] Fix merger (#2006) * [Community] Fix merger * finish --- examples/community/checkpoint_merger.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py index 764129e85900..115af6f92e59 100644 --- a/examples/community/checkpoint_merger.py +++ b/examples/community/checkpoint_merger.py @@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline): """ def __init__(self): + self.register_to_config() super().__init__() def _compare_model_configs(self, dict0, dict1): @@ -167,6 +168,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] final_pipe = DiffusionPipeline.from_pretrained( cached_folders[0], torch_dtype=torch_dtype, device_map=device_map ) + final_pipe.to(self.device) checkpoint_path_2 = None if len(cached_folders) > 2: @@ -202,9 +204,9 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike] theta_0 = theta_0() update_theta_0 = getattr(module, "load_state_dict") - theta_1 = torch.load(checkpoint_path_1) + theta_1 = torch.load(checkpoint_path_1, map_location="cpu") - theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None + theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None if not theta_0.keys() == theta_1.keys(): print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")