Skip to content

Commit

Permalink
[Community] Fix merger (huggingface#2006)
Browse files Browse the repository at this point in the history
* [Community] Fix merger

* finish
  • Loading branch information
patrickvonplaten authored Jan 16, 2023
1 parent 651c5ad commit f6a5c35
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions examples/community/checkpoint_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""

def __init__(self):
self.register_to_config()
super().__init__()

def _compare_model_configs(self, dict0, dict1):
Expand Down Expand Up @@ -167,6 +168,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
final_pipe = DiffusionPipeline.from_pretrained(
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
)
final_pipe.to(self.device)

checkpoint_path_2 = None
if len(cached_folders) > 2:
Expand Down Expand Up @@ -202,9 +204,9 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
theta_0 = theta_0()

update_theta_0 = getattr(module, "load_state_dict")
theta_1 = torch.load(checkpoint_path_1)
theta_1 = torch.load(checkpoint_path_1, map_location="cpu")

theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None
theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None

if not theta_0.keys() == theta_1.keys():
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
Expand Down

0 comments on commit f6a5c35

Please sign in to comment.