Skip to content

Commit

Permalink
Add a simple experimental TorchCompileModel node.
Browse files Browse the repository at this point in the history
It probably only works on Linux.

For maximum speed on Flux with Nvidia 40 series/ada and newer try using
this node with fp8_e4m3fn and the --fast argument.
  • Loading branch information
comfyanonymous committed Sep 12, 2024
1 parent 405b529 commit d0b7ab8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
21 changes: 21 additions & 0 deletions comfy_extras/nodes_torch_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch

class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "_for_testing"
EXPERIMENTAL = True

def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )

NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,6 +2102,7 @@ def init_builtin_extra_nodes():
"nodes_hunyuan.py",
"nodes_flux.py",
"nodes_lora_extract.py",
"nodes_torch_compile.py",
]

import_failed = []
Expand Down

0 comments on commit d0b7ab8

Please sign in to comment.