From e884da5401d7ac04c6cf3fa5880010ec5d91a780 Mon Sep 17 00:00:00 2001 From: danielafrimi Date: Sun, 31 Dec 2023 12:02:24 +0200 Subject: [PATCH 1/4] add dataset for calibration --- examples/yolonas_example.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/yolonas_example.py b/examples/yolonas_example.py index 2cce6b5..b3f9c91 100644 --- a/examples/yolonas_example.py +++ b/examples/yolonas_example.py @@ -1,15 +1,20 @@ +import hydra +import numpy as np import onnx -import torch -from onnx2tflite.utils.builder import keras_builder +import super_gradients import tensorflow as tf -import numpy as np +import torch +from omegaconf import DictConfig +from super_gradients.training.datasets.detection_datasets import COCODetectionDataset +from src.onnx2tflite.utils.builder import keras_builder class RandomDatasetGenerator: - def __init__(self, input_size, num_samples: int = 1): + def __init__(self, input_size, num_samples: int = 1, dataset=None): self.input_size = input_size self.counter = 0 self.num_samples = num_samples + self.dataset = dataset def __iter__(self): self.counter = 0 @@ -20,18 +25,24 @@ def iterator(self): def __next__(self): self.counter += 1 - if self.counter <= self.num_samples: + if self.counter <= self.num_samples and self.counter <= len(self.dataset): + if self.dataset: + sample = self.dataset[self.counter] + return [np.expand_dims(sample[0], axis=0)] + # Random input return [np.random.rand(*self.input_size).astype(np.float32)] raise StopIteration() -if __name__ == '__main__': +@hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") +def run(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg) input_size = [1, 3, 640, 640] channel_last_size = [input_size[0], input_size[2], input_size[3], input_size[1]] x_nchw = torch.randn(*input_size) quantize = True - onnx_path = "/Users/liork/Downloads/yolonas_s_for_tflite.onnx" + onnx_path = "/Users/danielafrimi/Downloads/yolox_model_no_heads.onnx" tflite_path = onnx_path.replace(".onnx", "_quant.tflite" if quantize else ".tflite") onnx_model = onnx.load_model(onnx_path) @@ -40,10 +51,12 @@ def __next__(self): ) converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] #, tf.lite.OpsSet.SELECT_TF_OPS] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] if quantize: - converter.representative_dataset = RandomDatasetGenerator(input_size=channel_last_size, num_samples=5).iterator + dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) + converter.representative_dataset = RandomDatasetGenerator(input_size=channel_last_size, num_samples=5, + dataset=dataset_val).iterator converter._experimental_disable_per_channel = True converter.experimental_new_converter = True converter.experimental_new_quantizer = True @@ -92,5 +105,6 @@ def __next__(self): # np.testing.assert_allclose(output_data, torch_out, rtol=1e-4) - - +if __name__ == '__main__': + super_gradients.init_trainer() + run() \ No newline at end of file From 74849788bb5624c1398968567705124a85fbb033 Mon Sep 17 00:00:00 2001 From: danielafrimi Date: Mon, 1 Jan 2024 11:17:40 +0200 Subject: [PATCH 2/4] add dataset for calibration + eval on coco --- examples/yolonas_example.py | 161 +++++++++++++++++++++++++----------- 1 file changed, 113 insertions(+), 48 deletions(-) diff --git a/examples/yolonas_example.py b/examples/yolonas_example.py index b3f9c91..0ea340b 100644 --- a/examples/yolonas_example.py +++ b/examples/yolonas_example.py @@ -6,9 +6,31 @@ import torch from omegaconf import DictConfig from super_gradients.training.datasets.detection_datasets import COCODetectionDataset - +from super_gradients.training.metrics import DetectionMetrics +from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback +from super_gradients.training.utils import get_param +from tqdm import tqdm +from super_gradients.training import dataloaders from src.onnx2tflite.utils.builder import keras_builder + +def create_onnx_model_example(): + from torchvision.models import resnet50, ResNet50_Weights + model = resnet50(weights=ResNet50_Weights.DEFAULT) + model.eval() + + dummy_input = torch.zeros(1, 3, 224, 224) # BCHW + torch.onnx.export(model, + dummy_input, + 'resnet_50.onnx', + verbose=False, opset_version=15, + training=torch.onnx.TrainingMode.EVAL, + do_constant_folding=True, + input_names=['input'], + output_names=['output'], + dynamic_axes=None) + + class RandomDatasetGenerator: def __init__(self, input_size, num_samples: int = 1, dataset=None): self.input_size = input_size @@ -26,56 +48,18 @@ def iterator(self): def __next__(self): self.counter += 1 if self.counter <= self.num_samples and self.counter <= len(self.dataset): + print(f"num samples is {self.counter}") if self.dataset: sample = self.dataset[self.counter] - return [np.expand_dims(sample[0], axis=0)] + img = np.transpose(np.expand_dims(sample[0], axis=0), (0, 2, 3, 1)) + return [img] + # Random input return [np.random.rand(*self.input_size).astype(np.float32)] raise StopIteration() -@hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") -def run(cfg: DictConfig): - cfg = hydra.utils.instantiate(cfg) - input_size = [1, 3, 640, 640] - channel_last_size = [input_size[0], input_size[2], input_size[3], input_size[1]] - x_nchw = torch.randn(*input_size) - quantize = True - - onnx_path = "/Users/danielafrimi/Downloads/yolox_model_no_heads.onnx" - tflite_path = onnx_path.replace(".onnx", "_quant.tflite" if quantize else ".tflite") - - onnx_model = onnx.load_model(onnx_path) - keras_model = keras_builder( - onnx_model=onnx_model, native_groupconv=True, tflite_compat=True - ) - - converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] - - if quantize: - dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) - converter.representative_dataset = RandomDatasetGenerator(input_size=channel_last_size, num_samples=5, - dataset=dataset_val).iterator - converter._experimental_disable_per_channel = True - converter.experimental_new_converter = True - converter.experimental_new_quantizer = True - converter.optimizations = [tf.lite.Optimize.DEFAULT] - - # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] - converter.target_spec.supported_types = [] - - try: - tflite_model = converter.convert() - except Exception as e: - print("======================== Turn off `experimental_new_quantizer`, and try again") - print(e) - converter.experimental_new_quantizer = False - tflite_model = converter.convert() - - with open(tflite_path, "wb") as fp: - fp.write(tflite_model) - +def inference_tflite_model(tflite_path, x_nchw, dataset): tf.lite.experimental.Analyzer.analyze( model_path=tflite_path, model_content=None, gpu_compatibility=True ) @@ -91,13 +75,37 @@ def run(cfg: DictConfig): # Test the model on random input data input_shape = input_details[0]['shape'] interpreter.set_tensor(input_details[0]['index'], x_nchw.permute(0, 2, 3, 1).numpy()) - interpreter.invoke() # get_tensor() returns a copy of the tensor data # use tensor() in order to get a pointer to the tensor - output_data = interpreter.get_tensor(output_details[0]['index']) - print(output_data.shape) + output_data_bbox = interpreter.get_tensor(output_details[0]['index']) + output_data_cls = interpreter.get_tensor(output_details[1]['index']) + print(output_data_bbox.shape) + print(output_data_cls.shape) + + post_prediction_callback = PPYoloEPostPredictionCallback(score_threshold=0.1, max_predictions=300, nms_top_k=1000, + nms_threshold=0.7) + metric = DetectionMetrics(score_thres=0.1, top_k_predictions=300, num_cls=80, normalize_targets=True, + post_prediction_callback=post_prediction_callback) + + for i, data in tqdm(enumerate(dataset)): + label = torch.tensor(data[1]) + preds = [] + img = np.transpose(data[0], (0, 2, 3, 1)) + interpreter.set_tensor(input_details[0]['index'], img) + interpreter.invoke() + + output_data_bbox = torch.from_numpy(interpreter.get_tensor(output_details[0]['index'])).squeeze(dim=3) + output_data_cls = torch.from_numpy(interpreter.get_tensor(output_details[1]['index'])).permute(0, 2, 1) + + preds.append(output_data_bbox) + preds.append(output_data_cls) + + metric.update(preds=[preds], target=label, inputs=data[0], device='cuda') + + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") # with torch.no_grad(): # torch_out = module(x_nchw).numpy() @@ -105,6 +113,63 @@ def run(cfg: DictConfig): # np.testing.assert_allclose(output_data, torch_out, rtol=1e-4) +@hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") +def run(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg) + inference = True + compile = False + input_size = [1, 3, 640, 640] + channel_last_size = [input_size[0], input_size[2], input_size[3], input_size[1]] + x_nchw = torch.randn(*input_size) + quantize = True + + onnx_path = "/home/daniel.afrimi/pycharm_projects/dso_tflite/examples/yolonas_s_for_tflite.onnx" + tflite_path = onnx_path.replace(".onnx", "_quant.tflite" if quantize else ".tflite") + dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) + + if compile: + onnx_model = onnx.load_model(onnx_path) + keras_model = keras_builder( + onnx_model=onnx_model, native_groupconv=True, tflite_compat=True + ) + + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] + + if quantize: + # todo do we want samples from train/val? + converter.representative_dataset = RandomDatasetGenerator(input_size=channel_last_size, num_samples=500, + dataset=dataset_val).iterator + converter._experimental_disable_per_channel = True + converter.experimental_new_converter = True + converter.experimental_new_quantizer = True + converter.optimizations = [ + tf.lite.Optimize.DEFAULT] # in their code is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] + + # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.target_spec.supported_types = [] + + try: + tflite_model = converter.convert() + except Exception as e: + print( + "======================== Turn off `experimental_new_quantizer`, and try again ======================") + print(e) + converter.experimental_new_quantizer = False + tflite_model = converter.convert() + + with open(tflite_path, "wb") as fp: + fp.write(tflite_model) + + if inference: + val_dataloader = dataloaders.get( + name=get_param(cfg, "val_dataloader"), + dataset_params=cfg.dataset_params.val_dataset_params, + dataloader_params=cfg.dataset_params.val_dataloader_params, + ) + inference_tflite_model(tflite_path=tflite_path, x_nchw=x_nchw, dataset=val_dataloader) + + if __name__ == '__main__': super_gradients.init_trainer() - run() \ No newline at end of file + run() From cf0f1ff3fa7777e5cf63a81568cb4514dd037a38 Mon Sep 17 00:00:00 2001 From: danielafrimi Date: Wed, 3 Jan 2024 12:41:39 +0200 Subject: [PATCH 3/4] add dataset for calibration + eval on coco --- examples/yolonas_example.py | 202 +++++++++++++++++++++++++----- examples/yolonas_tflite_compat.py | 82 ++++++++++++ 2 files changed, 254 insertions(+), 30 deletions(-) create mode 100644 examples/yolonas_tflite_compat.py diff --git a/examples/yolonas_example.py b/examples/yolonas_example.py index 0ea340b..f904297 100644 --- a/examples/yolonas_example.py +++ b/examples/yolonas_example.py @@ -12,6 +12,9 @@ from tqdm import tqdm from super_gradients.training import dataloaders from src.onnx2tflite.utils.builder import keras_builder +from super_gradients.training import models +from super_gradients.training.models.conversion import onnx_simplify +from examples.yolonas_tflite_compat import YoloNAS_S_TFLite def create_onnx_model_example(): @@ -31,8 +34,8 @@ def create_onnx_model_example(): dynamic_axes=None) -class RandomDatasetGenerator: - def __init__(self, input_size, num_samples: int = 1, dataset=None): +class COCODatasetGenerator: + def __init__(self, dataset, input_size, num_samples: int = 1): self.input_size = input_size self.counter = 0 self.num_samples = num_samples @@ -49,17 +52,37 @@ def __next__(self): self.counter += 1 if self.counter <= self.num_samples and self.counter <= len(self.dataset): print(f"num samples is {self.counter}") - if self.dataset: - sample = self.dataset[self.counter] - img = np.transpose(np.expand_dims(sample[0], axis=0), (0, 2, 3, 1)) - return [img] + sample = self.dataset[self.counter] + img = np.transpose(np.expand_dims(sample[0], axis=0), (0, 2, 3, 1)) + return [img] + + raise StopIteration() + + +class RandomDatasetGenerator: + def __init__(self, input_size, num_samples: int = 1): + self.input_size = input_size + self.counter = 0 + self.num_samples = num_samples + + def __iter__(self): + self.counter = 0 + return self + + def iterator(self): + return self.__iter__() + def __next__(self): + self.counter += 1 + if self.counter <= self.num_samples : + print(f"num samples is {self.counter}") # Random input return [np.random.rand(*self.input_size).astype(np.float32)] raise StopIteration() -def inference_tflite_model(tflite_path, x_nchw, dataset): +def eval_tflite_model(tflite_path, x_nchw, dataset): + print(f'Start evaluate tflite model with path {tflite_path}') tf.lite.experimental.Analyzer.analyze( model_path=tflite_path, model_content=None, gpu_compatibility=True ) @@ -97,77 +120,196 @@ def inference_tflite_model(tflite_path, x_nchw, dataset): interpreter.invoke() output_data_bbox = torch.from_numpy(interpreter.get_tensor(output_details[0]['index'])).squeeze(dim=3) - output_data_cls = torch.from_numpy(interpreter.get_tensor(output_details[1]['index'])).permute(0, 2, 1) + output_data_cls = torch.from_numpy(interpreter.get_tensor(output_details[1]['index'])) preds.append(output_data_bbox) preds.append(output_data_cls) metric.update(preds=[preds], target=label, inputs=data[0], device='cuda') + if i % 1000 == 0: + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + detection_metric = metric.compute() print(f"detection_metric result:\t {detection_metric}") - # with torch.no_grad(): - # torch_out = module(x_nchw).numpy() - # - # np.testing.assert_allclose(output_data, torch_out, rtol=1e-4) + +def eval_torch_model(eval_dataset): + print(f'Start evaluate torch model') + torch_model = models.get(model_name="YoloNAS_S_TFLite", num_classes=80).eval() + # load state dict + yolo_nas_model_weights = models.get(model_name="yolo_nas_s", num_classes=80, pretrained_weights="coco").eval() + torch_model.load_state_dict(yolo_nas_model_weights.state_dict()) + + post_prediction_callback = PPYoloEPostPredictionCallback(score_threshold=0.1, max_predictions=300, nms_top_k=1000, + nms_threshold=0.7) + metric = DetectionMetrics(score_thres=0.1, top_k_predictions=300, num_cls=80, normalize_targets=True, + post_prediction_callback=post_prediction_callback) + + for i, data in tqdm(enumerate(eval_dataset)): + label = torch.tensor(data[1]) + preds = [] + + pred = torch_model(data[0]) + + preds.append(pred[0][0].squeeze(dim=0)) + preds.append(pred[0][1].permute(0, 2, 1)) + + metric.update(preds=[preds], target=label, inputs=data[0], device='cuda') + + if i % 1000 == 0: + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + +def get_yolonas_onnx(onnx_path): + torch_model = models.get(model_name="YoloNAS_S_TFLite", num_classes=80).eval() + # load state dict + yolo_nas_model_weights = models.get(model_name="yolo_nas_s", num_classes=80, pretrained_weights="coco").eval() + torch_model.load_state_dict(yolo_nas_model_weights.state_dict()) + + input_size = [1, 3, 640, 640] + x_nchw = torch.randn(*input_size) + + # state dict sanity check + with torch.no_grad(): + (x1, x2), _ = torch_model(x_nchw) + (y1, y2), _ = yolo_nas_model_weights(x_nchw) + print(f"{'=' * 20} DIFF sanity check") + diff_cls = torch.abs(y2 - x2.permute(0, 2, 1)) + print(f"DIFF cls preds: mean = {diff_cls.mean()}, max = {diff_cls.max()}") + diff_reg = torch.abs(y1 - x1.squeeze(1)) + print(f"DIFF cls preds: mean = {diff_reg.mean()}, max = {diff_reg.max()}") + print(f"{'=' * 20}") + + torch_model.prep_model_for_conversion([1, 3, 640, 640]) + + torch.onnx.export(torch_model, x_nchw, onnx_path, opset_version=13) + onnx_simplify(onnx_path, onnx_path) + + # Edit onnx model with custom ops + onnx_model = onnx.load_model(onnx_path) + + i = 0 + counter = 0 + value_dict = {attr.name: attr for attr in onnx_model.graph.value_info} + initializer_dict = {attr.name: attr for attr in onnx_model.graph.initializer} + while i < len(onnx_model.graph.node): + if "/heads/" in onnx_model.graph.node[i].name and onnx_model.graph.node[i].op_type == "Reshape" and \ + onnx_model.graph.node[i + 1].op_type == "Transpose" and onnx_model.graph.node[i + 2].op_type == "Softmax": + output_edge = onnx_model.graph.node[i + 1].output[0] + dims = [d.dim_value for d in value_dict[output_edge].type.tensor_type.shape.dim] + num_regs = dims[1] + anchor_size = dims[2] + new_node = onnx.helper.make_node( + inputs=list(onnx_model.graph.node[i].input), + outputs=list(onnx_model.graph.node[i + 1].output), + name=f"/DFL_Reshape{counter}", + op_type="DFLReshape", + num_regs=num_regs, + anchor_size=anchor_size + ) + onnx_model.graph.node.insert(i, new_node) + del onnx_model.graph.node[i + 2] + del onnx_model.graph.node[i + 1] + i += 1 + + onnx.save_model(onnx_model, onnx_path) + +def calc_quantization_stats(converter): + result_file = "result_tflite_quant.csv" + debugger = tf.lite.experimental.QuantizationDebugger( + converter=converter, debug_dataset=converter.representative_dataset) + debugger.run() + + with open(result_file, 'w') as f: + debugger.layer_statistics_dump(f) +def print_keras_model_layers(keras_mnodel): + for i, layer in enumerate(keras_mnodel.layers): + print(f"Layer {i}: {layer.name}") @hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") def run(cfg: DictConfig): cfg = hydra.utils.instantiate(cfg) - inference = True - compile = False + eval = True + compile = True + quantize = True input_size = [1, 3, 640, 640] channel_last_size = [input_size[0], input_size[2], input_size[3], input_size[1]] x_nchw = torch.randn(*input_size) - quantize = True - onnx_path = "/home/daniel.afrimi/pycharm_projects/dso_tflite/examples/yolonas_s_for_tflite.onnx" + onnx_path = "yolonas_s_for_tflite.onnx" tflite_path = onnx_path.replace(".onnx", "_quant.tflite" if quantize else ".tflite") - dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) if compile: + get_yolonas_onnx(onnx_path) onnx_model = onnx.load_model(onnx_path) keras_model = keras_builder( onnx_model=onnx_model, native_groupconv=True, tflite_compat=True ) + # print_keras_model_layers(keras_model) + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] if quantize: - # todo do we want samples from train/val? - converter.representative_dataset = RandomDatasetGenerator(input_size=channel_last_size, num_samples=500, + dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) + converter.representative_dataset = COCODatasetGenerator(input_size=channel_last_size, num_samples=500, dataset=dataset_val).iterator + converter._experimental_disable_per_channel = True converter.experimental_new_converter = True converter.experimental_new_quantizer = True - converter.optimizations = [ - tf.lite.Optimize.DEFAULT] # in their code is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] + converter.optimizations = [tf.lite.Optimize.DEFAULT] # in their code is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.target_spec.supported_types = [] try: - tflite_model = converter.convert() + + # calc_quantization_stats(converter=converter) + # tflite_model = converter.convert() + tflite_path = "yolonas_s_for_tflite_selctive_quant.tflite" + suspected_layers = ['StatefulPartitionedCall:0', 'model/tf.math.subtract/Sub', + 'model/tf.concat_15/concat','model/tf.__operators__.add_20/AddV2'] + debug_options = tf.lite.experimental.QuantizationDebugOptions( + denylisted_nodes=suspected_layers) + + debugger = tf.lite.experimental.QuantizationDebugger( + converter=converter, + debug_dataset=converter.representative_dataset, + debug_options=debug_options) + + print(f"Compile model to tflite with Selective Quantization, emitted layers: " + f" {suspected_layers}, saves in path {tflite_path}") + + tflite_model = debugger.get_nondebug_quantized_model() + except Exception as e: - print( - "======================== Turn off `experimental_new_quantizer`, and try again ======================") + print("======================== Turn off `experimental_new_quantizer`, and try again ======================") print(e) converter.experimental_new_quantizer = False tflite_model = converter.convert() + else: + tflite_model = converter.convert() - with open(tflite_path, "wb") as fp: - fp.write(tflite_model) + with open(tflite_path, "wb") as fp: + fp.write(tflite_model) - if inference: + if eval: val_dataloader = dataloaders.get( name=get_param(cfg, "val_dataloader"), dataset_params=cfg.dataset_params.val_dataset_params, - dataloader_params=cfg.dataset_params.val_dataloader_params, - ) - inference_tflite_model(tflite_path=tflite_path, x_nchw=x_nchw, dataset=val_dataloader) + dataloader_params=cfg.dataset_params.val_dataloader_params) + + # eval_torch_model(eval_dataset=val_dataloader) + eval_tflite_model(tflite_path=tflite_path, x_nchw=x_nchw, dataset=val_dataloader) if __name__ == '__main__': diff --git a/examples/yolonas_tflite_compat.py b/examples/yolonas_tflite_compat.py new file mode 100644 index 0000000..f0a5807 --- /dev/null +++ b/examples/yolonas_tflite_compat.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch +from torch import Tensor +import copy + +from super_gradients.training.models.detection_models.yolo_nas.yolo_nas_variants import YoloNAS +from super_gradients.common.registry import register_detection_module +from super_gradients.training.models.detection_models.yolo_nas.dfl_heads import NDFLHeads +from super_gradients.training.utils.bbox_utils import batch_distance2bbox +from super_gradients.training.models.detection_models.pp_yolo_e.pp_yolo_head import generate_anchors_for_grid_cell +from super_gradients.common.registry import register_model +from super_gradients.training.models import get_arch_params +from super_gradients.training.utils import HpmStruct, get_param + + +@register_model() +class YoloNAS_S_TFLite(YoloNAS): + def __init__(self, arch_params): + default_arch_params = get_arch_params("yolo_nas_s_arch_params") + default_arch_params["heads"]["NDFLHeadsTFlite"] = default_arch_params["heads"].pop("NDFLHeads") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + backbone=merged_arch_params.backbone, + neck=merged_arch_params.neck, + heads=merged_arch_params.heads, + num_classes=get_param(merged_arch_params, "num_classes", None), + in_channels=get_param(merged_arch_params, "in_channels", 3), + bn_momentum=get_param(merged_arch_params, "bn_momentum", None), + bn_eps=get_param(merged_arch_params, "bn_eps", None), + inplace_act=get_param(merged_arch_params, "inplace_act", None), + ) + + +@register_detection_module() +class NDFLHeadsTFlite(NDFLHeads): + def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]: + cls_score_list, reg_distri_list, reg_dist_reduced_list = [], [], [] + + for i, feat in enumerate(feats): + b, _, h, w = feat.shape + height_mul_width = h * w + reg_distri, cls_logit = getattr(self, f"head{i + 1}")(feat) + reg_distri_list.append(torch.permute(reg_distri.flatten(2), [0, 2, 1])) + + reg_dist_reduced = torch.permute(reg_distri.reshape([-1, 4, self.reg_max + 1, height_mul_width]), [0, 2, 3, 1]) + reg_dist_reduced = torch.nn.functional.conv2d(torch.nn.functional.softmax(reg_dist_reduced, dim=1), weight=self.proj_conv)#.squeeze(1) + + # cls and reg + cls_score_list.append(cls_logit.reshape([b, self.num_classes, height_mul_width])) + reg_dist_reduced_list.append(reg_dist_reduced) + + cls_score_list = torch.cat(cls_score_list, dim=-1) # [B, C, Anchors] + # cls_score_list = torch.permute(cls_score_list, [0, 2, 1]) # # [B, Anchors, C] + + reg_distri_list = torch.cat(reg_distri_list, dim=1) # [B, Anchors, 4 * (self.reg_max + 1)] + reg_dist_reduced_list = torch.cat(reg_dist_reduced_list, dim=2) # [B, 1, Anchors, 4] + + # Decode bboxes + # Note in eval mode, anchor_points_inference is different from anchor_points computed on train + if self.eval_size: + anchor_points_inference, stride_tensor = self.anchor_points, self.stride_tensor + else: + anchor_points_inference, stride_tensor = self._generate_anchors(feats) + + pred_scores = cls_score_list.sigmoid() + pred_bboxes = batch_distance2bbox(anchor_points_inference, reg_dist_reduced_list) * stride_tensor # [B, Anchors, 4] + + decoded_predictions = pred_bboxes, pred_scores + + if torch.jit.is_tracing(): + return decoded_predictions + + anchors, anchor_points, num_anchors_list, _ = generate_anchors_for_grid_cell(feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset) + + raw_predictions = cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor + return decoded_predictions, raw_predictions + + def _generate_anchors(self, feats=None, dtype=None, device=None): + anchor_points, stride_tensor = super()._generate_anchors(feats, dtype, device) + return anchor_points.unsqueeze(0).unsqueeze(0), stride_tensor.unsqueeze(0).unsqueeze(0) \ No newline at end of file From 609d89756cf3288d7dca43364c6238122ce2c0ee Mon Sep 17 00:00:00 2001 From: danielafrimi Date: Sun, 7 Jan 2024 14:04:25 +0200 Subject: [PATCH 4/4] expanded API to convert yolonas to tflite, it can use for various models than needs to compile to TFLite --- examples/onnx2tflite_example.py | 328 ++++++++++++++++++++++++++++++ examples/yolonas_example.py | 55 +++-- examples/yolonas_tflite_compat.py | 41 +++- 3 files changed, 403 insertions(+), 21 deletions(-) create mode 100644 examples/onnx2tflite_example.py diff --git a/examples/onnx2tflite_example.py b/examples/onnx2tflite_example.py new file mode 100644 index 0000000..017c326 --- /dev/null +++ b/examples/onnx2tflite_example.py @@ -0,0 +1,328 @@ +import argparse + +import hydra +import numpy as np +import onnx +import super_gradients +import tensorflow as tf +import torch +from omegaconf import DictConfig +from super_gradients.training import dataloaders +from super_gradients.training import models +from super_gradients.training.datasets.detection_datasets import COCODetectionDataset +from super_gradients.training.metrics import DetectionMetrics +from super_gradients.training.models.conversion import onnx_simplify +from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback +from super_gradients.training.utils import get_param +from tqdm import tqdm + +from src.onnx2tflite.utils.builder import keras_builder + + +# Calibration Datasets for tflite converter +class COCODatasetGenerator: + def __init__(self, dataset, input_size, num_samples: int = 1): + self.input_size = input_size + self.counter = 0 + self.num_samples = num_samples + self.dataset = dataset + + def __iter__(self): + self.counter = 0 + return self + + def iterator(self): + return self.__iter__() + + def __next__(self): + self.counter += 1 + if self.counter <= self.num_samples and self.counter <= len(self.dataset): + print(f"num samples is {self.counter}") + sample = self.dataset[self.counter] + img = np.transpose(np.expand_dims(sample[0], axis=0), (0, 2, 3, 1)) + return [img] + + raise StopIteration() + + +class RandomDatasetGenerator: + def __init__(self, input_size, num_samples: int = 1): + self.input_size = input_size + self.counter = 0 + self.num_samples = num_samples + + def __iter__(self): + self.counter = 0 + return self + + def iterator(self): + return self.__iter__() + + def __next__(self): + self.counter += 1 + if self.counter <= self.num_samples: + print(f"num samples is {self.counter}") + return [np.random.rand(*self.input_size).astype(np.float32)] + raise StopIteration() + + +def eval_tflite_model_coco_dataset(tflite_path, cfg): + """ + Evaluate a TFLite model using the COCO dataset. + + Parameters: + - tflite_path (str): Path to the TFLite model file. + - cfg (Config): Configuration object containing evaluation parameters. + + Returns: + None + + Raises: + - FileNotFoundError: If the specified TFLite model file is not found. + """ + + print(f'Start evaluate tflite model with path {tflite_path}') + val_dataloader = dataloaders.get( + name=get_param(cfg, "val_dataloader"), + dataset_params=cfg.dataset_params.val_dataset_params, + dataloader_params=cfg.dataset_params.val_dataloader_params) + + tf.lite.experimental.Analyzer.analyze( + model_path=tflite_path, model_content=None, gpu_compatibility=True + ) + + # Load the TFLite model and allocate tensors + interpreter = tf.lite.Interpreter(model_path=tflite_path) + interpreter.allocate_tensors() + + # # Get input and output tensors + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + post_prediction_callback = PPYoloEPostPredictionCallback(score_threshold=0.1, max_predictions=300, nms_top_k=1000, + nms_threshold=0.7) + metric = DetectionMetrics(score_thres=0.1, top_k_predictions=300, num_cls=80, normalize_targets=True, + post_prediction_callback=post_prediction_callback) + + for i, data in tqdm(enumerate(val_dataloader)): + label = torch.tensor(data[1]) + preds = [] + img = np.transpose(data[0], (0, 2, 3, 1)) + interpreter.set_tensor(input_details[0]['index'], img) + interpreter.invoke() + + output_data_bbox = torch.from_numpy(interpreter.get_tensor(output_details[0]['index'])).squeeze(dim=3) + output_data_cls = torch.from_numpy(interpreter.get_tensor(output_details[1]['index'])) + + preds.append(output_data_bbox) + preds.append(output_data_cls) + + metric.update(preds=[preds], target=label, inputs=data[0], device='cuda') + + if i % 1000 == 0: + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + +def eval_yolonas_torch_model_coco(eval_dataset): + print(f'Start evaluate torch model') + # todo maybe to change this? + torch_model = models.get(model_name="YoloNAS_S_TFLite", num_classes=80).eval() + # load state dict + yolo_nas_model_weights = models.get(model_name="yolo_nas_s", num_classes=80, pretrained_weights="coco").eval() + torch_model.load_state_dict(yolo_nas_model_weights.state_dict()) + + post_prediction_callback = PPYoloEPostPredictionCallback(score_threshold=0.1, max_predictions=300, nms_top_k=1000, + nms_threshold=0.7) + metric = DetectionMetrics(score_thres=0.1, top_k_predictions=300, num_cls=80, normalize_targets=True, + post_prediction_callback=post_prediction_callback) + + for i, data in tqdm(enumerate(eval_dataset)): + label = torch.tensor(data[1]) + preds = [] + + pred = torch_model(data[0]) + + preds.append(pred[0][0].squeeze(dim=0)) + preds.append(pred[0][1].permute(0, 2, 1)) + + metric.update(preds=[preds], target=label, inputs=data[0], device='cuda') + + if i % 1000 == 0: + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + detection_metric = metric.compute() + print(f"detection_metric result:\t {detection_metric}") + + +def create_yolonas_onnx(onnx_path, model_name="YoloNAS_L_TFLite", checkpoint_model_name="yolo_nas_l"): + torch_model = models.get(model_name=model_name, num_classes=80).eval() + # load state dict + yolo_nas_model_weights = models.get(model_name=checkpoint_model_name, num_classes=80, + pretrained_weights="coco").eval() + torch_model.load_state_dict(yolo_nas_model_weights.state_dict()) + + input_size = [1, 3, 640, 640] + x_nchw = torch.randn(*input_size) + + # state dict sanity check + with torch.no_grad(): + (x1, x2), _ = torch_model(x_nchw) + (y1, y2), _ = yolo_nas_model_weights(x_nchw) + print(f"{'=' * 20} DIFF sanity check") + diff_cls = torch.abs(y2 - x2.permute(0, 2, 1)) + print(f"DIFF cls preds: mean = {diff_cls.mean()}, max = {diff_cls.max()}") + diff_reg = torch.abs(y1 - x1.squeeze(1)) + print(f"DIFF cls preds: mean = {diff_reg.mean()}, max = {diff_reg.max()}") + print(f"{'=' * 20}") + + torch_model.prep_model_for_conversion([1, 3, 640, 640]) + + torch.onnx.export(torch_model, x_nchw, onnx_path, opset_version=13) + onnx_simplify(onnx_path, onnx_path) + + # Edit onnx model with custom ops + onnx_model = onnx.load_model(onnx_path) + + i = 0 + counter = 0 + value_dict = {attr.name: attr for attr in onnx_model.graph.value_info} + initializer_dict = {attr.name: attr for attr in onnx_model.graph.initializer} + while i < len(onnx_model.graph.node): + if "/heads/" in onnx_model.graph.node[i].name and onnx_model.graph.node[i].op_type == "Reshape" and \ + onnx_model.graph.node[i + 1].op_type == "Transpose" and onnx_model.graph.node[ + i + 2].op_type == "Softmax": + output_edge = onnx_model.graph.node[i + 1].output[0] + dims = [d.dim_value for d in value_dict[output_edge].type.tensor_type.shape.dim] + num_regs = dims[1] + anchor_size = dims[2] + new_node = onnx.helper.make_node( + inputs=list(onnx_model.graph.node[i].input), + outputs=list(onnx_model.graph.node[i + 1].output), + name=f"/DFL_Reshape{counter}", + op_type="DFLReshape", + num_regs=num_regs, + anchor_size=anchor_size + ) + onnx_model.graph.node.insert(i, new_node) + del onnx_model.graph.node[i + 2] + del onnx_model.graph.node[i + 1] + i += 1 + + onnx.save_model(onnx_model, onnx_path) + + +def calculate_quantization_stats(converter, result_file="result_tflite_quant.csv"): + debugger = tf.lite.experimental.QuantizationDebugger( + converter=converter, debug_dataset=converter.representative_dataset) + + debugger.run() + + with open(result_file, 'w') as f: + debugger.layer_statistics_dump(f) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Parser for Model Evaluation and Compilation') + parser.add_argument('eval_model', type=bool, default=False, + help='Perform model evaluation. Set to True to enable model evaluation.') + parser.add_argument('compile_model', type=bool, default=True, + help='Compile the model for deployment. Set to True to enable model compilation.') + parser.add_argument('quantize_model_int8', type=bool, default=True, + help='Quantize the model to INT8 format. Set to True to enable INT8 quantization.') + parser.add_argument('model_input_size', type=list, default=[1, 3, 640, 640], + help='Set the input size for the model. Provide a list of integers representing the input ' + 'size, e.g., [batch_size, channels, height, width].') + parser.add_argument('onnx_path', type=str, default="yolonas_l_for_tflite.onnx", + help='Path to the ONNX model file. Specify the ONNX model file for evaluation, compilation, ' + 'or quantization.') + + return parser + + +@hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") +def run(cfg: DictConfig): + cfg = hydra.utils.instantiate(cfg) + args = parse_args().parse_args() + + tflite_path = args.onnx_path.replace(".onnx", "_quant.tflite" if args.quantize_model_int8 else ".tflite") + + if args.compile_model: + # Create yolonas onnx model + create_yolonas_onnx(args.onnx_path) + onnx_model = onnx.load_model(args.onnx_path) + keras_model = keras_builder(onnx_model=onnx_model, native_groupconv=True, tflite_compat=True) + + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + # Sets the TensorFlow Lite operations supported by the converter + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] + + if args.quantize_model_int8: + channel_last_size = [args.model_input_size[0], args.model_input_size[2], args.model_input_size[3], + args.model_input_size[1]] + + dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) + converter.representative_dataset = COCODatasetGenerator(input_size=channel_last_size, num_samples=500, + dataset=dataset_val).iterator + # Disables per-channel quantization + converter._experimental_disable_per_channel = True + converter.experimental_new_converter = True + converter.experimental_new_quantizer = True + converter.optimizations = [ + tf.lite.Optimize.DEFAULT] # in their code (nxp) is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] + + # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.target_spec.supported_types = [] + try: + # Using the debugger you can watch (in the csv output) the error per layer/operator while doing + # quantization to int8. After a csv file been created you can choose the layers you want to + # exclude from quantization to int8, and preserve them as float16 (selective quantization) - + # insert them to a list like "suspected_layers" + calculate_quantization_stats(converter=converter) + suspected_layers = ['StatefulPartitionedCall:0', 'model/tf.concat_15/concat', + 'model/tf.math.subtract/Sub', 'model/tf.__operators__.add_28/AddV2'] + + debug_options = tf.lite.experimental.QuantizationDebugOptions(denylisted_nodes=suspected_layers) + + debugger = tf.lite.experimental.QuantizationDebugger( + converter=converter, + debug_dataset=converter.representative_dataset, + debug_options=debug_options) + + print( + f"Compile model to tflite with Selective Quantization, emitted layers: {suspected_layers}, saves in path {tflite_path}") + tflite_model = debugger.get_nondebug_quantized_model() + + except Exception as e: + print( + "======================== Turn off `experimental_new_quantizer`, and try again ======================") + print(e) + converter.experimental_new_quantizer = False + + suspected_layers = [] + debug_options = tf.lite.experimental.QuantizationDebugOptions(denylisted_nodes=suspected_layers) + debugger = tf.lite.experimental.QuantizationDebugger( + converter=converter, + debug_dataset=converter.representative_dataset, + debug_options=debug_options) + + tflite_model = debugger.get_nondebug_quantized_model() + + else: + # FP16 + tflite_model = converter.convert() + + with open(tflite_path, "wb") as fp: + fp.write(tflite_model) + + if args.eval_model: + eval_tflite_model_coco_dataset(tflite_path=tflite_path, cfg=cfg) + + +if __name__ == '__main__': + super_gradients.init_trainer() + run() diff --git a/examples/yolonas_example.py b/examples/yolonas_example.py index f904297..906d82e 100644 --- a/examples/yolonas_example.py +++ b/examples/yolonas_example.py @@ -1,3 +1,5 @@ +import argparse + import hydra import numpy as np import onnx @@ -74,7 +76,7 @@ def iterator(self): def __next__(self): self.counter += 1 - if self.counter <= self.num_samples : + if self.counter <= self.num_samples: print(f"num samples is {self.counter}") # Random input return [np.random.rand(*self.input_size).astype(np.float32)] @@ -166,10 +168,10 @@ def eval_torch_model(eval_dataset): print(f"detection_metric result:\t {detection_metric}") -def get_yolonas_onnx(onnx_path): - torch_model = models.get(model_name="YoloNAS_S_TFLite", num_classes=80).eval() +def create_yolonas_onnx(onnx_path): + torch_model = models.get(model_name="YoloNAS_L_TFLite", num_classes=80).eval() # load state dict - yolo_nas_model_weights = models.get(model_name="yolo_nas_s", num_classes=80, pretrained_weights="coco").eval() + yolo_nas_model_weights = models.get(model_name="yolo_nas_l", num_classes=80, pretrained_weights="coco").eval() torch_model.load_state_dict(yolo_nas_model_weights.state_dict()) input_size = [1, 3, 640, 640] @@ -200,7 +202,8 @@ def get_yolonas_onnx(onnx_path): initializer_dict = {attr.name: attr for attr in onnx_model.graph.initializer} while i < len(onnx_model.graph.node): if "/heads/" in onnx_model.graph.node[i].name and onnx_model.graph.node[i].op_type == "Reshape" and \ - onnx_model.graph.node[i + 1].op_type == "Transpose" and onnx_model.graph.node[i + 2].op_type == "Softmax": + onnx_model.graph.node[i + 1].op_type == "Transpose" and onnx_model.graph.node[ + i + 2].op_type == "Softmax": output_edge = onnx_model.graph.node[i + 1].output[0] dims = [d.dim_value for d in value_dict[output_edge].type.tensor_type.shape.dim] num_regs = dims[1] @@ -220,22 +223,30 @@ def get_yolonas_onnx(onnx_path): onnx.save_model(onnx_model, onnx_path) -def calc_quantization_stats(converter): - result_file = "result_tflite_quant.csv" + +def calc_quantization_stats(converter, result_file="result_tflite_quant.csv"): debugger = tf.lite.experimental.QuantizationDebugger( converter=converter, debug_dataset=converter.representative_dataset) debugger.run() with open(result_file, 'w') as f: debugger.layer_statistics_dump(f) + + def print_keras_model_layers(keras_mnodel): for i, layer in enumerate(keras_mnodel.layers): print(f"Layer {i}: {layer.name}") +def parse_args(): + parser = argparse.ArgumentParser(description='My Hydra Application') + parser.add_argument('--my_arg', type=int, help='An example argument') + return parser + @hydra.main(config_path="../recipes/", config_name="coco2017_yolo_nas", version_base="1.2.0") def run(cfg: DictConfig): cfg = hydra.utils.instantiate(cfg) + args = parse_args().parse_args() eval = True compile = True quantize = True @@ -243,17 +254,16 @@ def run(cfg: DictConfig): channel_last_size = [input_size[0], input_size[2], input_size[3], input_size[1]] x_nchw = torch.randn(*input_size) - onnx_path = "yolonas_s_for_tflite.onnx" + onnx_path = "yolonas_l_for_tflite.onnx" tflite_path = onnx_path.replace(".onnx", "_quant.tflite" if quantize else ".tflite") if compile: - get_yolonas_onnx(onnx_path) + create_yolonas_onnx(onnx_path) onnx_model = onnx.load_model(onnx_path) keras_model = keras_builder( onnx_model=onnx_model, native_groupconv=True, tflite_compat=True ) - # print_keras_model_layers(keras_model) converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] # , tf.lite.OpsSet.SELECT_TF_OPS] @@ -261,25 +271,31 @@ def run(cfg: DictConfig): if quantize: dataset_val = COCODetectionDataset(**cfg.dataset_params.val_dataset_params) converter.representative_dataset = COCODatasetGenerator(input_size=channel_last_size, num_samples=500, - dataset=dataset_val).iterator + dataset=dataset_val).iterator converter._experimental_disable_per_channel = True converter.experimental_new_converter = True converter.experimental_new_quantizer = True - converter.optimizations = [tf.lite.Optimize.DEFAULT] # in their code is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] + converter.optimizations = [ + tf.lite.Optimize.DEFAULT] # in their code is [tf.lite.Optimize.DEFAULT, tf.lite.OpsSet.SELECT_TF_OPS] # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.target_spec.supported_types = [] try: - # calc_quantization_stats(converter=converter) + # calc_quantization_stats(converter=converter, result_file="result_tflite_quant_yolonas_l.csv") + # stop = "here" + # tflite_model = converter.convert() - tflite_path = "yolonas_s_for_tflite_selctive_quant.tflite" - suspected_layers = ['StatefulPartitionedCall:0', 'model/tf.math.subtract/Sub', - 'model/tf.concat_15/concat','model/tf.__operators__.add_20/AddV2'] + + tflite_path = "yolonas_l_for_tflite_selctive_quant.tflite" + # suspected_layers = ['StatefulPartitionedCall:0', 'model/tf.math.subtract/Sub', + # 'model/tf.concat_15/concat','model/tf.__operators__.add_20/AddV2'] + suspected_layers_yolonas_l = ['StatefulPartitionedCall:0', 'model/tf.concat_15/concat', + 'model/tf.math.subtract/Sub', 'model/tf.__operators__.add_28/AddV2'] debug_options = tf.lite.experimental.QuantizationDebugOptions( - denylisted_nodes=suspected_layers) + denylisted_nodes=suspected_layers_yolonas_l) debugger = tf.lite.experimental.QuantizationDebugger( converter=converter, @@ -287,12 +303,13 @@ def run(cfg: DictConfig): debug_options=debug_options) print(f"Compile model to tflite with Selective Quantization, emitted layers: " - f" {suspected_layers}, saves in path {tflite_path}") + f" {suspected_layers_yolonas_l}, saves in path {tflite_path}") tflite_model = debugger.get_nondebug_quantized_model() except Exception as e: - print("======================== Turn off `experimental_new_quantizer`, and try again ======================") + print( + "======================== Turn off `experimental_new_quantizer`, and try again ======================") print(e) converter.experimental_new_quantizer = False tflite_model = converter.convert() diff --git a/examples/yolonas_tflite_compat.py b/examples/yolonas_tflite_compat.py index f0a5807..10bc666 100644 --- a/examples/yolonas_tflite_compat.py +++ b/examples/yolonas_tflite_compat.py @@ -13,7 +13,6 @@ from super_gradients.training.models import get_arch_params from super_gradients.training.utils import HpmStruct, get_param - @register_model() class YoloNAS_S_TFLite(YoloNAS): def __init__(self, arch_params): @@ -33,6 +32,42 @@ def __init__(self, arch_params): ) +@register_model() +class YoloNAS_M_TFLite(YoloNAS): + def __init__(self, arch_params): + default_arch_params = get_arch_params("yolo_nas_m_arch_params") + default_arch_params["heads"]["NDFLHeadsTFlite"] = default_arch_params["heads"].pop("NDFLHeads") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + backbone=merged_arch_params.backbone, + neck=merged_arch_params.neck, + heads=merged_arch_params.heads, + num_classes=get_param(merged_arch_params, "num_classes", None), + in_channels=get_param(merged_arch_params, "in_channels", 3), + bn_momentum=get_param(merged_arch_params, "bn_momentum", None), + bn_eps=get_param(merged_arch_params, "bn_eps", None), + inplace_act=get_param(merged_arch_params, "inplace_act", None), + ) + +@register_model() +class YoloNAS_L_TFLite(YoloNAS): + def __init__(self, arch_params): + default_arch_params = get_arch_params("yolo_nas_l_arch_params") + default_arch_params["heads"]["NDFLHeadsTFlite"] = default_arch_params["heads"].pop("NDFLHeads") + merged_arch_params = HpmStruct(**copy.deepcopy(default_arch_params)) + merged_arch_params.override(**arch_params.to_dict()) + super().__init__( + backbone=merged_arch_params.backbone, + neck=merged_arch_params.neck, + heads=merged_arch_params.heads, + num_classes=get_param(merged_arch_params, "num_classes", None), + in_channels=get_param(merged_arch_params, "in_channels", 3), + bn_momentum=get_param(merged_arch_params, "bn_momentum", None), + bn_eps=get_param(merged_arch_params, "bn_eps", None), + inplace_act=get_param(merged_arch_params, "inplace_act", None), + ) + @register_detection_module() class NDFLHeadsTFlite(NDFLHeads): def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]: @@ -79,4 +114,6 @@ def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor] def _generate_anchors(self, feats=None, dtype=None, device=None): anchor_points, stride_tensor = super()._generate_anchors(feats, dtype, device) - return anchor_points.unsqueeze(0).unsqueeze(0), stride_tensor.unsqueeze(0).unsqueeze(0) \ No newline at end of file + return anchor_points.unsqueeze(0).unsqueeze(0), stride_tensor.unsqueeze(0).unsqueeze(0) + +