Skip to content

Commit

Permalink
optimize nn MUD parse labels strip space
Browse files Browse the repository at this point in the history
  • Loading branch information
Neutree committed Dec 12, 2024
1 parent d5cd634 commit b155703
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 17 deletions.
61 changes: 51 additions & 10 deletions docs/doc/en/pro/customize_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,65 @@ import numpy as np
def parse_str_values(value: str) -> list[float]:
return [float(v) for v in value.split(",")]

def load_labels(model_path, path_or_labels: str):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
labels0 = open(path, encoding="utf-8").readlines() if os.path.exists(path) else path_or_labels.split(",")
return [label.strip() for label in labels0]
def load_labels(model_path, path_or_labels : str):
path = ""
if not ("," in path_or_labels or " " in path_or_labels or "\n" in path_or_labels):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if path and os.path.exists(path):
with open(path, encoding = "utf-8") as f:
labels0 = f.readlines()
else:
labels0 = path_or_labels.split(",")
labels = []
for label in labels0:
labels.append(label.strip())
return labels

class My_Classifier:
def __init__(self, model: str):
self.model = nn.NN(model, dual_buff=False)
def __init__(self, model : str):
self.model = nn.NN(model, dual_buff = False)
self.extra_info = self.model.extra_info()
self.mean = parse_str_values(self.extra_info["mean"])
self.scale = parse_str_values(self.extra_info["scale"])
self.labels = self.model.extra_info_labels()
# self.labels = load_labels(model, self.extra_info["labels"]) # same as self.model.extra_info_labels()

def classify(self, img : image.Image):
outs = self.model.forward_image(img, self.mean, self.scale, copy_result = False)
# 后处理, 以分类模型为例
for k in outs.keys():
out = nn.F.softmax(outs[k], replace=True)
out = tensor.tensor_to_numpy_float32(out, copy = False).flatten()
max_idx = out.argmax()
return self.labels[max_idx], out[max_idx]def load_labels(model_path, path_or_labels : str):
path = ""
if not ("," in path_or_labels or " " in path_or_labels or "\n" in path_or_labels):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if path and os.path.exists(path):
with open(path, encoding = "utf-8") as f:
labels0 = f.readlines()
else:
labels0 = path_or_labels.split(",")
labels = []
for label in labels0:
labels.append(label.strip())
return labels

class My_Classifier:
def __init__(self, model : str):
self.model = nn.NN(model, dual_buff = False)
self.extra_info = self.model.extra_info()
self.mean = parse_str_values(self.extra_info["mean"])
self.scale = parse_str_values(self.extra_info["scale"])
self.labels = load_labels(model, self.extra_info["labels"])
self.labels = self.model.extra_info_labels()
# self.labels = load_labels(model, self.extra_info["labels"]) # same as self.model.extra_info_labels()

def classify(self, img: image.Image):
outs = self.model.forward_image(img, self.mean, self.scale, copy_result=False)
def classify(self, img : image.Image):
outs = self.model.forward_image(img, self.mean, self.scale, copy_result = False)
# 后处理, 以分类模型为例
for k in outs.keys():
out = nn.F.softmax(outs[k], replace=True)
out = tensor.tensor_to_numpy_float32(out, copy=False).flatten()
out = tensor.tensor_to_numpy_float32(out, copy = False).flatten()
max_idx = out.argmax()
return self.labels[max_idx], out[max_idx]

Expand Down
9 changes: 6 additions & 3 deletions docs/doc/zh/pro/customize_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def parse_str_values(value : str) -> list[float]:
return [float(value)]

def load_labels(model_path, path_or_labels : str):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if os.path.exists(path):
path = ""
if not ("," in path_or_labels or " " in path_or_labels or "\n" in path_or_labels):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if path and os.path.exists(path):
with open(path, encoding = "utf-8") as f:
labels0 = f.readlines()
else:
Expand All @@ -97,7 +99,8 @@ class My_Classifier:
self.extra_info = self.model.extra_info()
self.mean = parse_str_values(self.extra_info["mean"])
self.scale = parse_str_values(self.extra_info["scale"])
self.labels = load_labels(model, self.extra_info["labels"])
self.labels = self.model.extra_info_labels()
# self.labels = load_labels(model, self.extra_info["labels"]) # same as self.model.extra_info_labels()

def classify(self, img : image.Image):
outs = self.model.forward_image(img, self.mean, self.scale, copy_result = False)
Expand Down
9 changes: 6 additions & 3 deletions examples/vision/ai_vision/nn_custom_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ def parse_str_values(value : str) -> list[float]:
return [float(value)]

def load_labels(model_path, path_or_labels : str):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if os.path.exists(path):
path = ""
if not ("," in path_or_labels or " " in path_or_labels or "\n" in path_or_labels):
path = os.path.join(os.path.dirname(model_path), path_or_labels)
if path and os.path.exists(path):
with open(path, encoding = "utf-8") as f:
labels0 = f.readlines()
else:
Expand All @@ -30,7 +32,8 @@ def __init__(self, model : str):
self.extra_info = self.model.extra_info()
self.mean = parse_str_values(self.extra_info["mean"])
self.scale = parse_str_values(self.extra_info["scale"])
self.labels = load_labels(model, self.extra_info["labels"])
self.labels = self.model.extra_info_labels()
# self.labels = load_labels(model, self.extra_info["labels"]) # same as self.model.extra_info_labels()

def classify(self, img : image.Image):
outs = self.model.forward_image(img, self.mean, self.scale, copy_result = False)
Expand Down
2 changes: 1 addition & 1 deletion examples/vision/ai_vision/nn_yolo11_seg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from maix import camera, display, image, nn, app, time

detector = nn.YOLO11(model="/root/models/yolov11n_seg.mud", dual_buff = True)
detector = nn.YOLO11(model="/root/models/yolo11n_seg.mud", dual_buff = True)

cam = camera.Camera(detector.input_width(), detector.input_height(), detector.input_format())
disp = display.Display()
Expand Down

0 comments on commit b155703

Please sign in to comment.