Greatly improve lowvram sampling speed by getting rid of accelerate.
Let me know if this breaks anything.
comfyanonymous committed Dec 22, 2023
1 parent 261bcbb commit 36a7953
Showing 5 changed files with 103 additions and 50 deletions.
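In short: instead of letting accelerate build a device_map and dispatch hooks that park offloaded weights on the meta device, lowvram mode now moves whole modules onto the GPU until a VRAM budget is spent, and every module left on the offload device casts its weights to the input's device and dtype on each forward pass via the new comfy_cast_weights path in comfy/ops.py. Illustrative sketches follow the relevant hunks below.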
2 changes: 1 addition & 1 deletion comfy/controlnet.py
@@ -283,7 +283,7 @@ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
cm = self.control_model.state_dict()

for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
weight = sd[k]
try:
comfy.utils.set_attr(self.control_model, k, weight)
except:
6 changes: 1 addition & 5 deletions comfy/model_base.py
@@ -162,11 +162,7 @@ def process_latent_out(self, latent):

def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
52 changes: 30 additions & 22 deletions comfy/model_management.py
@@ -218,15 +218,8 @@ def is_nvidia():
FORCE_FP16 = True

if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to


if cpu_state != CPUState.GPU:
@@ -298,8 +291,20 @@ def model_load(self, lowvram_model_memory=0):

if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = 0
sd = m.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem

self.model_accelerated = True

if is_intel_xpu() and not args.disable_ipex_optimize:
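The new loop in model_load above is the heart of the change: it walks every module, switches any comfy.ops layer into its weight-casting mode, measures that module's weights, and moves it to the GPU only while the running total stays under the lowvram budget. Below is a condensed standalone sketch of that idea; the function name, the budget_bytes argument, and the usage line are illustrative, not part of the commit.

import torch

def plan_lowvram_placement(model: torch.nn.Module, device: torch.device, budget_bytes: int) -> int:
    """Illustrative sketch (not the actual ComfyUI API): modules that fit in the
    byte budget are moved to `device`; the rest stay offloaded and rely on
    per-forward weight casting (comfy_cast_weights)."""
    used = 0
    for m in model.modules():
        if not hasattr(m, "comfy_cast_weights"):
            continue                      # only comfy.ops layers take part
        m.prev_comfy_cast_weights = m.comfy_cast_weights
        m.comfy_cast_weights = True       # cast weights to the input's device/dtype on each call
        size = sum(t.nelement() * t.element_size() for t in m.state_dict().values())
        if used + size < budget_bytes:
            m.to(device)                  # fits: keep this module's weights on the GPU
            used += size
    return used

# e.g. plan_lowvram_placement(unet, torch.device("cuda"), budget_bytes=2 * 1024**3)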
@@ -309,7 +314,11 @@

def model_unload(self):
if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights

self.model_accelerated = False

self.model.unpatch_model(self.model.offload_device)
@@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0

if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
lowvram_model_memory = 64 * 1024 * 1024

cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
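The budget formula keeps its shape but its floor drops from 256 MiB to 64 MiB: take the currently free VRAM, subtract 1 GiB, divide by 1.3, and never go below the floor (NO_VRAM mode pins the budget to 64 MiB outright). A quick worked example with a hypothetical 8 GiB of free VRAM:

free_vram = 8 * 1024**3                                      # hypothetical: 8 GiB reported free
budget = int(max(64 * 1024**2, (free_vram - 1024**3) / 1.3))
print(budget // (1024 * 1024))                               # 5513 MiB of weights may go on the GPU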
Expand Down Expand Up @@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
return True
return False

def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True

def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True

non_blocking = True
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking
non_blocking = device_supports_non_blocking(device)

if device_supports_cast:
if copy:
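device_supports_non_blocking() centralizes the MPS workaround that cast_to_device used to carry inline, and the per-forward casts in comfy/ops.py reuse it so transfers use non-blocking copies wherever the backend allows. A hypothetical helper showing the intended call pattern (the name to_device and the tensor sizes are made up):

import torch
import comfy.model_management as mm

def to_device(tensor, device, dtype):
    nb = mm.device_supports_non_blocking(device)        # False on MPS, True elsewhere
    return tensor.to(device, dtype=dtype, non_blocking=nb)

w = torch.randn(1024, 1024, dtype=torch.float16)
if torch.cuda.is_available():
    w_gpu = to_device(w, torch.device("cuda"), torch.float32)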
@@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight

#TODO: might be cleaner to put this somewhere else
92 changes: 71 additions & 21 deletions comfy/ops.py
@@ -1,27 +1,93 @@
import torch
from contextlib import contextmanager
import comfy.model_management

def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias


class disable_weight_init:
class Linear(torch.nn.Linear):
comfy_cast_weights = False
def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)

def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)

def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)

def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)


class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None

def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)

@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
@@ -31,35 +97,19 @@ def conv_nd(s, dims, *args, **kwargs):
else:
raise ValueError(f"unsupported dimensions: {dims}")

def cast_bias_weight(s, input):
bias = None
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype)
weight = s.weight.to(device=input.device, dtype=input.dtype)
return weight, bias

class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
comfy_cast_weights = True

class Conv2d(disable_weight_init.Conv2d):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
comfy_cast_weights = True

class Conv3d(disable_weight_init.Conv3d):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
comfy_cast_weights = True

class GroupNorm(disable_weight_init.GroupNorm):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
comfy_cast_weights = True

class LayerNorm(disable_weight_init.LayerNorm):
def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
comfy_cast_weights = True
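With the forward logic folded into disable_weight_init, manual_cast is now just the same classes with comfy_cast_weights defaulting to True, and lowvram mode can flip that flag on stock disable_weight_init layers at load time instead of relying on accelerate hooks. A small usage sketch, assuming it is run from a ComfyUI checkout (the layer sizes and values are arbitrary):

import torch
import comfy.ops

# Weights live in fp16; each call casts them to the input's device/dtype.
layer = comfy.ops.manual_cast.Linear(320, 1280, device="cpu", dtype=torch.float16)
layer.weight.data = torch.randn(1280, 320, dtype=torch.float16)  # reset_parameters() is a no-op
layer.bias.data = torch.zeros(1280, dtype=torch.float16)

x = torch.randn(2, 320, dtype=torch.float32)
y = layer(x)                          # weight/bias cast to fp32 for this call only
print(y.dtype, layer.weight.dtype)    # torch.float32 torch.float16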
1 change: 0 additions & 1 deletion requirements.txt
@@ -4,7 +4,6 @@ einops
transformers>=4.25.1
safetensors>=0.3.0
aiohttp
accelerate
pyyaml
Pillow
scipy
