-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
speed.py
195 lines (153 loc) · 7.63 KB
/
speed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# coding: utf-8
"""
Benchmark the inference speed of each module in LivePortrait.
TODO: this script has a heavy GPT-generated style and needs to be refactored
"""
import torch
torch._dynamo.config.suppress_errors = True # Suppress errors and fall back to eager execution
import yaml
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig
def initialize_inputs(batch_size=1, device_id=0):
    """
    Create the random half-precision tensors used as benchmark inputs.

    Args:
        batch_size: size of the leading dimension of every sampled tensor.
        device_id: target handed to ``Tensor.to`` for every sampled tensor.

    Returns:
        dict with the keys 'feature_3d', 'kp_source', 'kp_driving',
        'source_image', 'generator_input', 'feat_stitching', 'feat_eye'
        and 'feat_lip'.
    """
    def random_half(*dims):
        # Sample with torch.randn, move to the target device, then cast to fp16
        return torch.randn(batch_size, *dims).to(device_id).half()

    inputs = {
        'feature_3d': random_half(32, 16, 64, 64),
        'kp_source': random_half(21, 3),
        'kp_driving': random_half(21, 3),
        'source_image': random_half(3, 256, 256),
        'generator_input': random_half(256, 64, 64),
    }

    # The close ratios are only consumed by concat_feat below; they are not returned
    eye_close_ratio = random_half(3)
    lip_close_ratio = random_half(2)

    source_kp = inputs['kp_source']
    inputs['feat_stitching'] = concat_feat(source_kp, inputs['kp_driving']).half()
    inputs['feat_eye'] = concat_feat(source_kp, eye_close_ratio).half()
    inputs['feat_lip'] = concat_feat(source_kp, lip_close_ratio).half()
    return inputs
def load_and_compile_models(cfg, model_config):
    """
    Load every module checkpoint and prepare it for half-precision benchmarking.

    Args:
        cfg: configuration object providing ``checkpoint_F``, ``checkpoint_M``,
            ``checkpoint_W``, ``checkpoint_G``, ``checkpoint_S`` and ``device_id``.
        model_config: model configuration forwarded unchanged to ``load_model``.

    Returns:
        Tuple ``(compiled_models, stitching_retargeting_module)``: a dict of the
        four main modules keyed by display name, and the stitching/retargeting
        container whose 'stitching', 'eye' and 'lip' entries were replaced in
        place by their compiled versions.
    """
    def prepare(net):
        # Cast to fp16, compile with max-autotune, then switch to evaluation mode
        net = torch.compile(net.half(), mode='max-autotune')
        net.eval()
        return net

    # Load everything first; compilation only starts once all checkpoints are in
    loaded = [
        ('Appearance Feature Extractor',
         load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')),
        ('Motion Extractor',
         load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')),
        ('Warping Network',
         load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')),
        ('SPADE Decoder',
         load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')),
    ]
    stitching_retargeting_module = load_model(
        cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module'
    )

    compiled_models = {label: prepare(net) for label, net in loaded}

    for part in ('stitching', 'eye', 'lip'):
        stitching_retargeting_module[part] = prepare(stitching_retargeting_module[part])

    return compiled_models, stitching_retargeting_module
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Run every module ten times without gradients before benchmarking.

    Args:
        compiled_models: dict of callables keyed by 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping with callable 'stitching', 'eye'
            and 'lip' entries.
        inputs: dict of input tensors keyed as produced by ``initialize_inputs``.
    """
    warmup_rounds = 10

    # (model name, keys of the inputs passed positionally) in execution order
    main_stages = (
        ('Appearance Feature Extractor', ('source_image',)),
        ('Motion Extractor', ('source_image',)),
        ('Warping Network', ('feature_3d', 'kp_driving', 'kp_source')),
        ('SPADE Decoder', ('generator_input',)),
    )
    # (retargeting part, key of its input feature) in execution order
    retargeting_stages = (
        ('stitching', 'feat_stitching'),
        ('eye', 'feat_eye'),
        ('lip', 'feat_lip'),
    )

    print("Warm up start!")
    with torch.no_grad():
        for _round in range(warmup_rounds):
            for stage_name, arg_keys in main_stages:
                net = compiled_models[stage_name]
                net(*[inputs[key] for key in arg_keys])
            for part, feat_key in retargeting_stages:
                head = stitching_retargeting_module[part]
                head(inputs[feat_key])
    print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs, num_runs=100):
    """
    Measure per-module and overall wall-clock inference times.

    Each run executes, in order, the appearance feature extractor, the motion
    extractor, the warping network, the SPADE decoder and the three
    stitching/retargeting parts (timed together as a single stage).

    Args:
        compiled_models: dict of callables keyed by 'Appearance Feature Extractor',
            'Motion Extractor', 'Warping Network' and 'SPADE Decoder'.
        stitching_retargeting_module: mapping with callable 'stitching', 'eye'
            and 'lip' entries.
        inputs: dict of input tensors keyed as produced by ``initialize_inputs``.
        num_runs: number of timed iterations (defaults to 100).

    Returns:
        Tuple ``(times, overall_times)``: ``times`` maps every key of
        ``compiled_models`` plus 'Stitching and Retargeting Modules' to a list
        of per-run durations in seconds; ``overall_times`` is the list of
        per-run durations, in seconds, of the whole sequence.
    """
    def _sync():
        # CUDA kernels are launched asynchronously, so wait for them to finish
        # before reading the clock. Skipped when CUDA is unavailable, where
        # torch.cuda.synchronize() would raise.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _run_retargeting():
        stitching_retargeting_module['stitching'](inputs['feat_stitching'])
        stitching_retargeting_module['eye'](inputs['feat_eye'])
        stitching_retargeting_module['lip'](inputs['feat_lip'])

    # (bucket name in `times`, zero-argument callable running that stage)
    stages = (
        ('Appearance Feature Extractor',
         lambda: compiled_models['Appearance Feature Extractor'](inputs['source_image'])),
        ('Motion Extractor',
         lambda: compiled_models['Motion Extractor'](inputs['source_image'])),
        ('Warping Network',
         lambda: compiled_models['Warping Network'](
             inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])),
        ('SPADE Decoder',
         lambda: compiled_models['SPADE Decoder'](inputs['generator_input'])),
        ('Stitching and Retargeting Modules', _run_retargeting),
    )

    times = {name: [] for name in compiled_models}
    times['Stitching and Retargeting Modules'] = []
    overall_times = []

    with torch.no_grad():
        for _ in range(num_runs):
            _sync()
            # perf_counter is monotonic and high-resolution, unlike time.time()
            overall_start = time.perf_counter()
            for name, run_stage in stages:
                start = time.perf_counter()
                run_stage()
                _sync()
                times[name].append(time.perf_counter() - start)
            overall_times.append(time.perf_counter() - overall_start)

    return times, overall_times
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print parameter counts and the mean / standard deviation of inference times.

    Args:
        compiled_models: dict mapping a display name to a module exposing
            ``parameters()``.
        stitching_retargeting_module: mapping from retargeting part name to a
            module exposing ``parameters()``.
        retargeting_models: names of the entries of
            ``stitching_retargeting_module`` to report, in output order.
        times: dict mapping a stage name to a list of durations in seconds.
        overall_times: list of whole-sequence durations in seconds; may be
            empty or None, in which case no overall line is printed.
    """
    def _millions_of_params(module):
        return sum(p.numel() for p in module.parameters()) / 1e6

    def _mean_std_ms(samples):
        # Durations are recorded in seconds; report them in milliseconds
        return np.mean(samples) * 1000, np.std(samples) * 1000

    for name, model in compiled_models.items():
        print(f"Number of parameters for {name}: {_millions_of_params(model):.2f} M")

    for index, retarget in enumerate(retargeting_models):
        num_params_in_millions = _millions_of_params(stitching_retargeting_module[retarget])
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")

    for name, samples in times.items():
        if len(samples) == 0:
            # Avoid numpy's NaN result and RuntimeWarning on empty input
            print(f"No inference times recorded for {name}")
            continue
        avg_time, std_time = _mean_std_ms(samples)
        # Report the real sample count rather than assuming 100 runs
        print(f"Average inference time for {name} over {len(samples)} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")

    # The overall per-run timings were previously accepted but never reported
    if overall_times is not None and len(overall_times) > 0:
        avg_time, std_time = _mean_std_ms(overall_times)
        print(f"Average overall inference time over {len(overall_times)} runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
def main():
    """
    Run the whole benchmark: configuration, inputs, models, warm-up, timing, report.
    """
    cfg = InferenceConfig()

    # Parse the YAML model configuration referenced by the inference config
    model_config_path = cfg.models_config
    with open(model_config_path, 'r') as config_file:
        model_config = yaml.safe_load(config_file)

    # Random input tensors placed on the configured device
    inputs = initialize_inputs(device_id=cfg.device_id)

    # Load the checkpoints and compile every module
    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

    # Untimed warm-up passes, then the timed runs
    warm_up_models(compiled_models, stitching_retargeting_module, inputs)
    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)

    # Report parameter counts and timing statistics
    retargeting_parts = ['stitching', 'eye', 'lip']
    print_benchmark_results(
        compiled_models,
        stitching_retargeting_module,
        retargeting_parts,
        times,
        overall_times,
    )


if __name__ == "__main__":
    main()