Add --fast argument to enable experimental optimizations.
Optimizations that might break things or lower quality will be put behind
this flag first and might be enabled by default in the future.

Currently the only optimization is float8_e4m3fn matrix multiplication on
RTX 4000/Ada series Nvidia cards or newer. If you have one of these cards,
you will see a speed boost when using an fp8_e4m3fn Flux model, for example.
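
To check whether a given machine will take the new fp8 path, here is a minimal sketch using only public PyTorch APIs (the (8, 9) threshold mirrors the supports_fp8_compute check added below in comfy/model_management.py):

import torch

# Ada (RTX 40xx) reports compute capability (8, 9); Hopper reports (9, 0).
# The fp8 path added in this commit requires (8, 9) or newer.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    ok = (major, minor) >= (8, 9)
    print(f"sm_{major}{minor}:", "fp8 compute available" if ok else "fp8 compute unavailable")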
comfyanonymous committed Aug 20, 2024
1 parent d1a6bd6 commit 9953f22
Showing 4 changed files with 52 additions and 5 deletions.
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -123,6 +123,7 @@ class LatentPreviewMethod(enum.Enum):

 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
+parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
5 changes: 1 addition & 4 deletions comfy/model_base.py
@@ -96,10 +96,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod

         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                if self.manual_cast_dtype is not None:
-                    operations = comfy.ops.manual_cast
-                else:
-                    operations = comfy.ops.disable_weight_init
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
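The call above delegates the choice of operation classes to the new comfy.ops.pick_operations (full definition in the comfy/ops.py diff below). A self-contained sketch of its decision table, with strings standing in for the comfy.ops classes:

def pick(weight_dtype, compute_dtype, fast=False, fp8_compute=False):
    # Weights already in the compute dtype: no casting machinery needed.
    if compute_dtype is None or weight_dtype == compute_dtype:
        return "disable_weight_init"
    # --fast on an fp8-capable card: use the fp8 matmul path.
    if fast and fp8_compute:
        return "fp8_ops"
    # Default: cast weights to the compute dtype per layer.
    return "manual_cast"

assert pick("fp8_e4m3fn", "bf16", fast=True, fp8_compute=True) == "fp8_ops"
assert pick("fp8_e4m3fn", "bf16") == "manual_cast"
assert pick("bf16", "bf16") == "disable_weight_init"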
10 changes: 10 additions & 0 deletions comfy/model_management.py
@@ -1048,6 +1048,16 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

     return False

+def supports_fp8_compute(device=None):
+    props = torch.cuda.get_device_properties(device)
+    if props.major >= 9:
+        return True
+    if props.major < 8:
+        return False
+    if props.minor < 9:
+        return False
+    return True
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
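Restating the gate above as a pure function against a few known GPU families (a sketch; the compute-capability pairs are Nvidia's published specs, not part of this commit):

def fp8_ok(major, minor):
    # True iff major >= 9, or major == 8 with minor >= 9 (sm_89 and up).
    return major >= 9 or (major == 8 and minor >= 9)

assert fp8_ok(8, 9)       # Ada, e.g. RTX 4090 (sm_89)
assert fp8_ok(9, 0)       # Hopper, e.g. H100 (sm_90)
assert not fp8_ok(8, 6)   # Ampere, e.g. RTX 3090 (sm_86)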
41 changes: 40 additions & 1 deletion comfy/ops.py
@@ -18,7 +18,7 @@

 import torch
 import comfy.model_management
-
+from comfy.cli_args import args

 def cast_to(weight, dtype=None, device=None, non_blocking=False):
     if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
@@ -242,3 +242,42 @@ class ConvTranspose1d(disable_weight_init.ConvTranspose1d):

 class Embedding(disable_weight_init.Embedding):
     comfy_cast_weights = True
+
+
+def fp8_linear(self, input):
+    dtype = self.weight.dtype
+    if dtype not in [torch.float8_e4m3fn]:
+        return None
+
+    if len(input.shape) == 3:
+        out = torch.empty((input.shape[0], input.shape[1], self.weight.shape[0]), device=input.device, dtype=input.dtype)
+        inn = input.to(dtype)
+        non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        for i in range(input.shape[0]):
+            if self.bias is not None:
+                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking))
+            else:
+                o, _ = torch._scaled_mm(inn[i], w, out_dtype=input.dtype)
+            out[i] = o
+        return out
+    return None
+
+class fp8_ops(manual_cast):
+    class Linear(manual_cast.Linear):
+        def forward_comfy_cast_weights(self, input):
+            out = fp8_linear(self, input)
+            if out is not None:
+                return out
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None):
+    if compute_dtype is None or weight_dtype == compute_dtype:
+        return disable_weight_init
+    if args.fast:
+        if comfy.model_management.supports_fp8_compute(load_device):
+            return fp8_ops
+    return manual_cast
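
Where the potential quality loss comes from: fp8_linear casts the layer input to float8_e4m3fn before the matmul, and e4m3 keeps only 3 mantissa bits. A minimal illustration, assuming a PyTorch build with float8 support:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
# Round-trip through the fp8 dtype used by fp8_linear.
roundtrip = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print((x - roundtrip).abs().max())  # nonzero rounding error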
