Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
Boomerl committed Nov 10, 2023
1 parent 4c463ca commit 795d4cd
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2582,3 +2582,93 @@ def forward(self, x, param_preds, points, strides):
deploy_cfg=deploy_cfg)

assert rewrite_outputs is not None


def get_sparseinst():
"""SparseInst Config."""
test_cfg = Config(dict(score_thr=0.4, mask_thr_binary=0.45))
data_preprocessor = dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_mask=True,
pad_size_divisor=32)
backbone = Config(
dict(
type='ResNet',
depth=50,
out_indices=(1, 2, 3),
frozen_stages=0,
norm_cfg=dict(type='BN', requires_grad=False),
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')))

from projects.SparseInst.sparseinst import SparseInst
model = SparseInst(
data_preprocessor=data_preprocessor,
backbone=backbone,
encoder=dict(
type='InstanceContextEncoder', in_channels=[512, 1024, 2048]),
decoder=dict(
type='BaseIAMDecoder', in_channels=256 + 2, num_classes=80),
criterion=dict(
type='SparseInstCriterion',
num_classes=80,
assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)),
test_cfg=test_cfg,
init_cfg=dict(
type='Normal',
layer='Conv2d',
std=0.01,
override=dict(
type='Normal', name='conv_cls', std=0.01, bias_prob=0.01)))

model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_sparseinst_predict(backend_type):
"""Test predict rewrite of sparseinst."""
check_backend(backend_type)
sparseinst = get_sparseinst()
sparseinst.cpu().eval()

output_names = ['dets', 'labels', 'masks']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
export_postprocess_mask=False))))

img = torch.randn(1, 3, 320, 320)
from mmdet.structures import DetDataSample
data_sample = DetDataSample(metainfo=dict(img_shape=(320, 320, 3)))

# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(
sparseinst, 'predict', batch_data_samples=[data_sample])
rewrite_inputs = {'batch_inputs': img}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

if is_backend_output:
assert rewrite_outputs[0].shape[-1] == 5
assert rewrite_outputs[1] is not None
assert rewrite_outputs[2] is not None
else:
assert rewrite_outputs is not None

0 comments on commit 795d4cd

Please sign in to comment.