From 66aaa14001be9f7cf2b52f84c0dff588e36aabbf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Jun 2024 17:02:05 -0400 Subject: [PATCH] Controlnet refactor. --- comfy/cldm/cldm.py | 9 +++++---- comfy/controlnet.py | 35 ++++++++++----------------------- comfy/ldm/cascade/controlnet.py | 2 +- comfy/t2i_adapter/adapter.py | 10 ++++++++-- 4 files changed, 24 insertions(+), 32 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 28076dd9251..f8a16159452 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -289,7 +289,8 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs): guided_hint = self.input_hint_block(hint, emb, context) - outs = [] + out_output = [] + out_middle = [] hs = [] if self.num_classes is not None: @@ -304,10 +305,10 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs): guided_hint = None else: h = module(h, emb, context) - outs.append(zero_conv(h, emb, context)) + out_output.append(zero_conv(h, emb, context)) h = self.middle_block(h, emb, context) - outs.append(self.middle_block_out(h, emb, context)) + out_middle.append(self.middle_block_out(h, emb, context)) - return outs + return {"middle": out_middle, "output": out_output} diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8cf4a61a683..f50df68357d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -89,27 +89,12 @@ def inference_memory_requirements(self, dtype): return self.previous_controlnet.inference_memory_requirements(dtype) return 0 - def control_merge(self, control_input, control_output, control_prev, output_dtype): + def control_merge(self, control, control_prev, output_dtype): out = {'input':[], 'middle':[], 'output': []} - if control_input is not None: - for i in range(len(control_input)): - key = 'input' - x = control_input[i] - if x is not None: - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - out[key].insert(0, x) - - if control_output is not None: + for key in control: + control_output = control[key] for i in range(len(control_output)): - if i == (len(control_output) - 1): - key = 'middle' - index = 0 - else: - key = 'output' - index = i x = control_output[i] if x is not None: if self.global_average_pooling: @@ -120,6 +105,7 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp x = x.to(output_dtype) out[key].append(x) + if control_prev is not None: for x in ['input', 'middle', 'output']: o = out[x] @@ -182,7 +168,7 @@ def get_control(self, x_noisy, t, cond, batched_number): x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) - return self.control_merge(None, control, control_prev, output_dtype) + return self.control_merge(control, control_prev, output_dtype) def copy(self): c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) @@ -490,12 +476,11 @@ def get_control(self, x_noisy, t, cond, batched_number): self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) self.t2i_model.cpu() - control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) - mid = None - if self.t2i_model.xl == True: - mid = control_input[-1:] - control_input = control_input[:-1] - return self.control_merge(control_input, mid, control_prev, x_noisy.dtype) + control_input = {} + for k in self.control_input: + control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k])) + + return self.control_merge(control_input, control_prev, x_noisy.dtype) def copy(self): c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) diff --git a/comfy/ldm/cascade/controlnet.py b/comfy/ldm/cascade/controlnet.py index 5dac5939409..7a52c3c263f 100644 --- a/comfy/ldm/cascade/controlnet.py +++ b/comfy/ldm/cascade/controlnet.py @@ -90,4 +90,4 @@ def forward(self, x): proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)] for i, idx in enumerate(self.proj_blocks): proj_outputs[idx] = self.projections[i](x) - return proj_outputs + return {"input": proj_outputs[::-1]} diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index e9a606b1cd6..10ea18e3266 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -153,7 +153,13 @@ def forward(self, x): features.append(None) features.append(x) - return features + features = features[::-1] + + if self.xl: + return {"input": features[1:], "middle": features[:1]} + else: + return {"input": features} + class LayerNorm(nn.LayerNorm): @@ -290,4 +296,4 @@ def forward(self, x): features.append(None) features.append(x) - return features + return {"input": features[::-1]}