Skip to content

Commit

Permalink
add reuse_model option, code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
yusing committed Jun 16, 2024
1 parent 6bf3b54 commit cfcf467
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions tensorrt_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
{".engine"},
)


class TQDMProgressMonitor(trt.IProgressMonitor):
def __init__(self):
trt.IProgressMonitor.__init__(self)
Expand Down Expand Up @@ -93,14 +94,18 @@ def step_complete(self, phase_name, step):
except KeyboardInterrupt:
# There is no need to propagate this exception to TensorRT. We can simply cancel the build.
return False


class TRT_MODEL_CONVERSION_BASE:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.temp_dir = folder_paths.get_temp_directory()
self.timing_cache_path = os.path.normpath(
os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
os.path.join(
os.path.join(
os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
)
)
)

RETURN_TYPES = ()
Expand Down Expand Up @@ -148,26 +153,30 @@ def _convert(
context_max,
num_video_frames,
is_static: bool,
reuse_model: bool = False,
):
output_onnx = os.path.normpath(
os.path.join(
os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
)
os.path.join(self.temp_dir, str(time.time()), "model.onnx")
)

comfy.model_management.unload_all_models()
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
if not reuse_model:
comfy.model_management.unload_all_models()
comfy.model_management.load_models_gpu([model], force_patch_weights=True)
unet = model.model.diffusion_model

context_dim = model.model.model_config.unet_config.get("context_dim", None)
context_len = 77
context_len_min = context_len

if context_dim is None: #SD3
context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
if context_dim is None: # SD3
context_embedder_config = model.model.model_config.unet_config.get(
"context_embedder_config", None
)
if context_embedder_config is not None:
context_dim = context_embedder_config.get("params", {}).get("in_features", None)
context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
context_dim = context_embedder_config.get("params", {}).get(
"in_features", None
)
context_len = 154 # NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77

if context_dim is not None:
input_names = ["x", "timesteps", "context"]
Expand All @@ -179,7 +188,7 @@ def _convert(
"context": {0: "batch", 1: "num_embeds"},
}

transformer_options = model.model_options['transformer_options'].copy()
transformer_options = model.model_options["transformer_options"].copy()
if model.model.model_config.unet_config.get(
"use_temporal_resblock", False
): # SVD
Expand All @@ -205,7 +214,13 @@ def forward(self, x, timesteps, context, y):
unet = svd_unet
context_len_min = context_len = 1
else:

class UNET(torch.nn.Module):
def __init__(self, unet, opts):
super().__init__()
self.unet = unet
self.transformer_options = opts

def forward(self, x, timesteps, context, y=None):
return self.unet(
x,
Expand All @@ -214,10 +229,8 @@ def forward(self, x, timesteps, context, y=None):
y,
transformer_options=self.transformer_options,
)
_unet = UNET()
_unet.unet = unet
_unet.transformer_options = transformer_options
unet = _unet

unet = UNET(unet, transformer_options)

input_channels = model.model.model_config.unet_config.get("in_channels")

Expand Down Expand Up @@ -272,8 +285,9 @@ def forward(self, x, timesteps, context, y=None):
dynamic_axes=dynamic_axes,
)

comfy.model_management.unload_all_models()
comfy.model_management.soft_empty_cache()
if not reuse_model:
comfy.model_management.unload_all_models()
comfy.model_management.soft_empty_cache()

# TRT conversion starts here
logger = trt.Logger(trt.Logger.INFO)
Expand Down Expand Up @@ -304,7 +318,9 @@ def forward(self, x, timesteps, context, y=None):
profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

# Encode shapes to filename
encode = lambda a: ".".join(map(lambda x: str(x), a))
def encode(a):
return ".".join(map(str, a))

prefix_encode += "{}#{}#{}#{};".format(
input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
)
Expand Down Expand Up @@ -589,6 +605,7 @@ def INPUT_TYPES(s):
"step": 1,
},
),
"reuse_model": ("BOOLEAN", {"default": False}),
},
}

Expand All @@ -601,6 +618,7 @@ def convert(
width_opt,
context_opt,
num_video_frames,
reuse_model,
):
return super()._convert(
model,
Expand All @@ -619,6 +637,7 @@ def convert(
context_opt,
num_video_frames,
is_static=True,
reuse_model=reuse_model,
)


Expand Down

0 comments on commit cfcf467

Please sign in to comment.