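"""Export a Depth Anything (ViT-S/B/L) model to ONNX.

Example invocation (flags defined in parse_args below; the default output
path is derived from the model size):

    python export_onnx.py --model s
    python export_onnx.py --model l --output weights/depth_anything_vitl14.onnx
"""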
import argparse
import torch
from onnx import load_model, save_model
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from depth_anything.dpt import DPT_DINOv2
from depth_anything.util.transform import load_image


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
choices=["s", "b", "l"],
required=True,
help="Model size variant. Available options: 's', 'b', 'l'.",
)
parser.add_argument(
"--output",
type=str,
default=None,
required=False,
help="Path to save the ONNX model.",
)
    return parser.parse_args()


def export_onnx(model: str, output: str | None = None):
    # Fill in defaults (for callers that bypass the CLI parser)
    if model is None:
        model = "s"
    if output is None:
        output = f"weights/depth_anything_vit{model}14.onnx"
    # Device for tracing. This script only converts the model (no ORT
    # inference happens here), so pin the device to CPU.
    device = "cpu"
    print("Using device:", device)
    # Sample image for tracing. With dynamic_axes disabled below, the exported
    # graph is fixed to this image's height and width.
image, _ = load_image("assets/sacre_coeur1.jpg")
image = torch.from_numpy(image).to(device)
# Load model params
if model == "s":
depth_anything = DPT_DINOv2(
encoder="vits", features=64, out_channels=[48, 96, 192, 384]
)
elif model == "b":
depth_anything = DPT_DINOv2(
encoder="vitb", features=128, out_channels=[96, 192, 384, 768]
)
else: # model == "l"
depth_anything = DPT_DINOv2(
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]
)
depth_anything.to(device).load_state_dict(
torch.hub.load_state_dict_from_url(
f"https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vit{model}14.pth",
map_location="cpu",
),
strict=True,
)
depth_anything.eval()
torch.onnx.export(
depth_anything,
image,
output,
input_names=["image"],
output_names=["depth"],
opset_version=17,
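        # Uncomment to export with dynamic spatial dimensions instead of the
        # fixed size of the traced sample image: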
# dynamic_axes={
# "image": {2: "height", 3: "width"},
# "depth": {2: "height", 3: "width"},
# },
)
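    # Annotate the exported graph with inferred tensor shapes using ONNX
    # Runtime's symbolic shape inference; auto_merge=True lets it merge
    # conflicting symbolic dimensions automatically.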
save_model(
SymbolicShapeInference.infer_shapes(load_model(output), auto_merge=True),
output,
)
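

# Optional sanity check: a minimal sketch (not part of the original export
# flow) that runs the exported model once with onnxruntime and prints the
# output shape. Assumes onnxruntime is installed and that load_image() yields
# the same float32 NCHW array used for tracing above.
def verify_onnx(output: str):
    import onnxruntime as ort

    sess = ort.InferenceSession(output, providers=["CPUExecutionProvider"])
    image, _ = load_image("assets/sacre_coeur1.jpg")
    # Input/output names match those set in export_onnx above.
    (depth,) = sess.run(["depth"], {"image": image})
    print("ONNX depth output:", depth.shape, depth.dtype)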


if __name__ == "__main__":
args = parse_args()
export_onnx(**vars(args))