Make --gpu-only put intermediate values in GPU memory instead of cpu.
comfyanonymous committed Dec 8, 2023
1 parent cdff081 commit 9ac0b48
Showing 9 changed files with 36 additions and 29 deletions.
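In short: every site that pinned intermediate tensors to system RAM with .cpu() now routes through a new comfy.model_management.intermediate_device() helper, which returns the GPU device when ComfyUI is launched with --gpu-only and torch.device("cpu") otherwise. A minimal sketch of the recurring pattern (assuming the comfy package is importable):

    import torch
    import comfy.model_management

    samples = torch.randn(1, 4, 64, 64)  # stand-in for a sampler's output latent

    # Before this commit the intermediate was always copied to system RAM:
    #     samples = samples.cpu()
    # After it, the copy targets whatever intermediate_device() selects,
    # i.e. the GPU under --gpu-only and still the CPU by default:
    samples = samples.to(comfy.model_management.intermediate_device())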
4 changes: 2 additions & 2 deletions comfy/clip_vision.py
@@ -54,10 +54,10 @@ def encode_image(self, image):
             t = outputs[k]
             if t is not None:
                 if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].cpu()
+                    outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
                     outputs["hidden_states"] = None
                 else:
-                    outputs[k] = t.cpu()
+                    outputs[k] = t.to(comfy.model_management.intermediate_device())
 
         return outputs

6 changes: 6 additions & 0 deletions comfy/model_management.py
@@ -508,6 +508,12 @@ def text_encoder_dtype(device=None):
     else:
         return torch.float32
 
+def intermediate_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def vae_device():
     return get_torch_device()

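For reference, the helper is a thin switch on the --gpu-only flag; callers never need to know which mode is active. A minimal usage sketch:

    import torch
    import comfy.model_management

    dev = comfy.model_management.intermediate_device()
    # dev == get_torch_device() if the process was started with --gpu-only,
    # torch.device("cpu") otherwise.
    latent = torch.zeros([1, 4, 64, 64], device=dev)  # intermediate stays on dev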
4 changes: 2 additions & 2 deletions comfy/sample.py
@@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
 
     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())
 
     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
@@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     sigmas = sigmas.to(model.load_device)
 
     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(comfy.model_management.intermediate_device())
     cleanup_additional_models(models)
     cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
23 changes: 12 additions & 11 deletions comfy/sd.py
@@ -190,6 +190,7 @@ def __init__(self, sd=None, device=None, config=None):
         offload_device = model_management.vae_offload_device()
         self.vae_dtype = model_management.vae_dtype()
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)

@@ -201,9 +202,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
 
         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output

@@ -214,9 +215,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         pbar = comfy.utils.ProgressBar(steps)
 
         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
+        samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples

@@ -228,15 +229,15 @@ def decode(self, samples_in):
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
 
-        pixel_samples = pixel_samples.cpu().movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
 
     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -252,10 +253,10 @@ def encode(self, pixel_samples):
             free_memory = model_management.get_free_memory(self.device)
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
-            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+            samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
 
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
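The VAE.decode path above sizes its batches from free memory and now allocates both the output buffer and each decoded chunk on self.output_device. A simplified sketch of that loop (dtype juggling and the tiled OOM fallback omitted; first_stage_model stands in for the real autoencoder):

    import torch

    def batched_decode(first_stage_model, samples_in, output_device, batch_number):
        n, _, h, w = samples_in.shape
        # The output buffer lives on output_device (the GPU under --gpu-only),
        # so no final host copy is needed in that mode.
        pixel_samples = torch.empty((n, 3, h * 8, w * 8), device=output_device)
        for x in range(0, n, batch_number):
            decoded = first_stage_model.decode(samples_in[x:x + batch_number])
            pixel_samples[x:x + batch_number] = torch.clamp(
                (decoded.to(output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
        return pixel_samples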
6 changes: 3 additions & 3 deletions comfy/sd1_clip.py
@@ -39,7 +39,7 @@ def encode_token_weights(self, token_weight_pairs):
 
         out, pooled = self.encode(to_encode)
         if pooled is not None:
-            first_pooled = pooled[0:1].cpu()
+            first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
             first_pooled = pooled
@@ -56,8 +56,8 @@ def encode_token_weights(self, token_weight_pairs):
             output.append(z)
 
         if (len(output) == 0):
-            return out[-1:].cpu(), first_pooled
-        return torch.cat(output, dim=-2).cpu(), first_pooled
+            return out[-1:].to(model_management.intermediate_device()), first_pooled
+        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
12 changes: 6 additions & 6 deletions comfy/utils.py
@@ -376,7 +376,7 @@ def lanczos(samples, width, height):
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
     images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
     result = torch.stack(images)
-    return result
+    return result.to(samples.device, samples.dtype)
 
 def common_upscale(samples, width, height, upscale_method, crop):
     if crop == "center":
@@ -405,17 +405,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
 
 @torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
-    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+    output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
     for b in range(samples.shape[0]):
         s = samples[b:b+1]
-        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
-        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
+        out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
+        out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
                 s_in = s[:,:,y:y+tile_y,x:x+tile_x]
 
-                ps = function(s_in).cpu()
+                ps = function(s_in).to(output_device)
                 mask = torch.ones_like(ps)
                 feather = round(overlap * upscale_amount)
                 for t in range(feather):
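tiled_scale gains an output_device parameter defaulting to "cpu", so existing callers keep their old behavior while the VAE threads self.output_device through. A hedged usage sketch, with a hypothetical stand-in for the per-tile function:

    import torch
    import comfy.utils
    import comfy.model_management

    samples = torch.zeros(1, 3, 64, 64)
    upscale = lambda t: torch.nn.functional.interpolate(t, scale_factor=4)  # hypothetical

    out = comfy.utils.tiled_scale(
        samples, upscale, tile_x=32, tile_y=32, overlap=8, upscale_amount=4,
        output_device=comfy.model_management.intermediate_device())

As a sanity check on get_tiled_scale_steps: a 64×64 input with 32-pixel tiles and an overlap of 8 gives ceil(64 / 24) × ceil(64 / 24) = 3 × 3 = 9 tiles per image.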
2 changes: 1 addition & 1 deletion comfy_extras/nodes_canny.py
@@ -291,7 +291,7 @@ def INPUT_TYPES(s):
 
     def detect_edge(self, image, low_threshold, high_threshold):
         output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
-        img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
+        img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
         return (img_out,)
 
 NODE_CLASS_MAPPINGS = {
2 changes: 1 addition & 1 deletion comfy_extras/nodes_post_processing.py
@@ -226,7 +226,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
         batch_size, height, width, channels = image.shape
 
         kernel_size = sharpen_radius * 2 + 1
-        kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
         center = kernel_size // 2
         kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
         kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
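The device=image.device argument matters because image can now arrive on the GPU; building the kernel on the CPU would make the following conv2d fail with a device mismatch. A small self-contained illustration of the rule (a 3×3 box blur, not the actual gaussian_kernel):

    import torch
    import torch.nn.functional as F

    img = torch.rand(1, 3, 8, 8)
    if torch.cuda.is_available():
        img = img.cuda()
    # Build the kernel where the image lives, as the sharpen fix does.
    kernel = torch.ones(3, 1, 3, 3, device=img.device) / 9.0
    out = F.conv2d(img, kernel, padding=1, groups=3)  # devices match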
6 changes: 3 additions & 3 deletions nodes.py
@@ -947,8 +947,8 @@ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, heigh
         return (c, )
 
 class EmptyLatentImage:
-    def __init__(self, device="cpu"):
-        self.device = device
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
 
     @classmethod
     def INPUT_TYPES(s):
@@ -961,7 +961,7 @@ def INPUT_TYPES(s):
     CATEGORY = "latent"
 
     def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
         return ({"samples":latent}, )


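With this change EmptyLatentImage allocates its zero latent directly on the intermediate device, so under --gpu-only the sampler never has to copy it host-to-device. A minimal sketch of the node's generate path (assuming comfy is importable):

    import torch
    import comfy.model_management

    device = comfy.model_management.intermediate_device()
    batch_size, height, width = 1, 512, 512
    latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=device)
    # EmptyLatentImage.generate returns ({"samples": latent},)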
