Skip to content

Commit

Permalink
tests for padding parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
jezsadler committed Jan 8, 2024
1 parent 0c7a1ce commit f47907c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
24 changes: 13 additions & 11 deletions src/omlt/io/onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,15 @@ def _visit_node(self, node, next_nodes):

def _consume_dense_nodes(self, node, next_nodes):
"""Starting from a MatMul node, consume nodes to form a dense Ax + b node."""
# This should only be called when we know we have a starting MatMul node. This
# error indicates a bug in the function calling this one.
if node.op_type != "MatMul":
raise ValueError(
f"{node.name} is a {node.op_type} node, only MatMul nodes can be used as starting points for consumption."
f"{node.name} is a {node.op_type} node, but the method for parsing MatMul nodes was invoked."
)
if len(node.input) != 2:
raise ValueError(
f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 input dimensions."
)

[in_0, in_1] = list(node.input)
Expand All @@ -200,7 +202,7 @@ def _consume_dense_nodes(self, node, next_nodes):
raise TypeError(f"Expected a node next, got a {type_} instead.")
if node.op_type != "Add":
raise ValueError(
f"The first node to be consumed, {node.name}, is a {node.op_type} node. Only Add nodes are supported."
f"The next node to be parsed, {node.name}, is a {node.op_type} node. Only Add nodes are supported."
)

# extract biases
Expand Down Expand Up @@ -255,11 +257,11 @@ def _consume_gemm_dense_nodes(self, node, next_nodes):
"""Starting from a Gemm node, consume nodes to form a dense aAB + bC node."""
if node.op_type != "Gemm":
raise ValueError(
f"{node.name} is a {node.op_type} node, only Gemm nodes can be used as starting points for consumption."
f"{node.name} is a {node.op_type} node, but the method for parsing Gemm nodes was invoked."
)
if len(node.input) != 3:
raise ValueError(
f"{node.name} input has {len(node.input)} dimensions, only nodes with 3 input dimensions can be used as starting points for consumption."
f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 3 input dimensions."
)

attr = _collect_attributes(node)
Expand Down Expand Up @@ -310,11 +312,11 @@ def _consume_conv_nodes(self, node, next_nodes):
"""
if node.op_type != "Conv":
raise ValueError(
f"{node.name} is a {node.op_type} node, only Conv nodes can be used as starting points for consumption."
f"{node.name} is a {node.op_type} node, but the method for parsing Conv nodes was invoked."
)
if len(node.input) not in [2, 3]:
raise ValueError(
f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 or 3 input dimensions can be used as starting points for consumption."
f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 or 3 input dimensions."
)

if len(node.input) == 2:
Expand Down Expand Up @@ -422,11 +424,11 @@ def _consume_reshape_nodes(self, node, next_nodes):
"""Parse a Reshape node."""
if node.op_type != "Reshape":
raise ValueError(
f"{node.name} is a {node.op_type} node, only Reshape nodes can be used as starting points for consumption."
f"{node.name} is a {node.op_type} node, but the method for parsing Reshape nodes was invoked."
)
if len(node.input) != 2:
raise ValueError(
f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption."
f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 input dimensions."
)
[in_0, in_1] = list(node.input)
input_layer = self._node_map[in_0]
Expand All @@ -443,7 +445,7 @@ def _consume_pool_nodes(self, node, next_nodes):
"""
if node.op_type not in _POOLING_OP_TYPES:
raise ValueError(
f"{node.name} is a {node.op_type} node, only MaxPool nodes can be used as starting points for consumption."
f"{node.name} is a {node.op_type} node, but the method for parsing MaxPool nodes was invoked."
)
pool_func_name = "max"

Expand All @@ -454,7 +456,7 @@ def _consume_pool_nodes(self, node, next_nodes):
)
if len(node.input) != 1:
raise ValueError(
f"{node.name} input has {len(node.input)} dimensions, only nodes with 1 input dimension can be used as starting points for consumption."
f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 1 input dimension."
)

input_layer, transformer = self._node_input_and_transformer(node.input[0])
Expand Down
49 changes: 33 additions & 16 deletions tests/io/test_onnx_parser.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from pathlib import Path
import pytest

from omlt.dependencies import onnx, onnx_available
from tests.conftest import _Datadir

if onnx_available:
from omlt.io.onnx import load_onnx_neural_network
from omlt.io.onnx_parser import NetworkParser


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_linear_131(datadir):
def test_linear_131(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -20,7 +22,7 @@ def test_linear_131(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_linear_131_relu(datadir):
def test_linear_131_relu(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131_relu.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -32,7 +34,7 @@ def test_linear_131_relu(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_linear_131_sigmoid(datadir):
def test_linear_131_sigmoid(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131_sigmoid.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -44,7 +46,7 @@ def test_linear_131_sigmoid(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_gemm(datadir):
def test_gemm(datadir: _Datadir):
model = onnx.load(datadir.file("gemm.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -58,7 +60,7 @@ def test_gemm(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_gemm_transB(datadir):
def test_gemm_transB(datadir: _Datadir):
model = onnx.load(datadir.file("gemm_not_transB.onnx"))
model_transB = onnx.load(datadir.file("gemm_transB.onnx"))
net = load_onnx_neural_network(model)
Expand All @@ -74,7 +76,7 @@ def test_gemm_transB(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_conv(datadir):
def test_conv(datadir: _Datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -85,9 +87,24 @@ def test_conv(datadir):
assert layers[1].strides == [1, 1]
assert layers[1].kernel_shape == (2, 2)

@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_conv_dilations(datadir: _Datadir):
    """Verify the parser preserves a Conv node's ``dilations`` and ``pads``.

    Loads a known Conv+Gemm model, rewrites the first graph node's
    ``dilations`` and ``pads`` attributes in place, then checks that the
    parsed layer reports exactly the values that were written.
    """
    # Imported locally: these names are used only by this test and are not
    # in the module-level imports above (as written, the test would raise
    # NameError on `numpy_helper` / `np`). The lazy import also keeps the
    # module importable when onnx is absent; the skipif guards execution.
    import numpy as np
    from onnx import numpy_helper

    model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
    for attr in model.graph.node[0].attribute:
        if attr.name == "dilations":
            attr.ints.clear()
            attr.ints.extend([2, 2])
        if attr.name == "pads":
            attr.ints.clear()
            attr.ints.extend([1, 2, 1, 0])
    # Padding changes the flattened activation size downstream, so the
    # shape constant carried by the next node is patched to stay consistent.
    # NOTE(review): assumes model.graph.node[1] holds a tensor attribute
    # (e.g. a Reshape target shape) — confirm against the test fixture.
    model.graph.node[1].attribute[0].t.raw_data = numpy_helper.from_array(
        np.array([-1, 128])
    ).raw_data
    net = load_onnx_neural_network(model)
    layers = list(net.layers)
    assert layers[1].dilations == [2, 2]
    assert layers[1].pads == [1, 2, 1, 0]

@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_maxpool(datadir):
def test_maxpool(datadir: _Datadir):
model = onnx.load(datadir.file("maxpool_2d.onnx"))
net = load_onnx_neural_network(model)
layers = list(net.layers)
Expand All @@ -109,7 +126,7 @@ def test_maxpool(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_input_tensor_invalid_dims(datadir):
def test_input_tensor_invalid_dims(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input[0].type.tensor_type.shape.dim[1].dim_value = 0
parser = NetworkParser()
Expand All @@ -120,7 +137,7 @@ def test_input_tensor_invalid_dims(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_no_input_layers(datadir):
def test_no_input_layers(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
model.graph.input.remove(model.graph.input[0])
parser = NetworkParser()
Expand All @@ -131,7 +148,7 @@ def test_no_input_layers(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_node_no_inputs(datadir):
def test_node_no_inputs(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
while len(model.graph.node[0].input) > 0:
model.graph.node[0].input.pop()
Expand All @@ -143,7 +160,7 @@ def test_node_no_inputs(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_wrong_node_type(datadir):
def test_consume_wrong_node_type(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand Down Expand Up @@ -190,7 +207,7 @@ def test_consume_wrong_node_type(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_dense_wrong_dims(datadir):
def test_consume_dense_wrong_dims(datadir: _Datadir):
model = onnx.load(datadir.file("keras_linear_131.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand All @@ -208,7 +225,7 @@ def test_consume_dense_wrong_dims(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_gemm_wrong_dims(datadir):
def test_consume_gemm_wrong_dims(datadir: _Datadir):
model = onnx.load(datadir.file("gemm.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand All @@ -222,7 +239,7 @@ def test_consume_gemm_wrong_dims(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_conv_wrong_dims(datadir):
def test_consume_conv_wrong_dims(datadir: _Datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand All @@ -236,7 +253,7 @@ def test_consume_conv_wrong_dims(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_reshape_wrong_dims(datadir):
def test_consume_reshape_wrong_dims(datadir: _Datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand All @@ -250,7 +267,7 @@ def test_consume_reshape_wrong_dims(datadir):


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_consume_maxpool_wrong_dims(datadir):
def test_consume_maxpool_wrong_dims(datadir: _Datadir):
model = onnx.load(datadir.file("maxpool_2d.onnx"))
parser = NetworkParser()
parser.parse_network(model.graph, None, None)
Expand Down

0 comments on commit f47907c

Please sign in to comment.