diff --git a/src/omlt/io/onnx_parser.py b/src/omlt/io/onnx_parser.py index 1d1a3d6e..d11e90e0 100644 --- a/src/omlt/io/onnx_parser.py +++ b/src/omlt/io/onnx_parser.py @@ -176,13 +176,15 @@ def _visit_node(self, node, next_nodes): def _consume_dense_nodes(self, node, next_nodes): """Starting from a MatMul node, consume nodes to form a dense Ax + b node.""" + # This should only be called when we know we have a starting MatMul node. This + # error indicates a bug in the function calling this one. if node.op_type != "MatMul": raise ValueError( - f"{node.name} is a {node.op_type} node, only MatMul nodes can be used as starting points for consumption." + f"{node.name} is a {node.op_type} node, but the method for parsing MatMul nodes was invoked." ) if len(node.input) != 2: raise ValueError( - f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption." + f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 input dimensions." ) [in_0, in_1] = list(node.input) @@ -200,7 +202,7 @@ def _consume_dense_nodes(self, node, next_nodes): raise TypeError(f"Expected a node next, got a {type_} instead.") if node.op_type != "Add": raise ValueError( - f"The first node to be consumed, {node.name}, is a {node.op_type} node. Only Add nodes are supported." + f"The next node to be parsed, {node.name}, is a {node.op_type} node. Only Add nodes are supported." ) # extract biases @@ -255,11 +257,11 @@ def _consume_gemm_dense_nodes(self, node, next_nodes): """Starting from a Gemm node, consume nodes to form a dense aAB + bC node.""" if node.op_type != "Gemm": raise ValueError( - f"{node.name} is a {node.op_type} node, only Gemm nodes can be used as starting points for consumption." + f"{node.name} is a {node.op_type} node, but the method for parsing Gemm nodes was invoked." ) if len(node.input) != 3: raise ValueError( - f"{node.name} input has {len(node.input)} dimensions, only nodes with 3 input dimensions can be used as starting points for consumption." + f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 3 input dimensions." ) attr = _collect_attributes(node) @@ -310,11 +312,11 @@ def _consume_conv_nodes(self, node, next_nodes): """ if node.op_type != "Conv": raise ValueError( - f"{node.name} is a {node.op_type} node, only Conv nodes can be used as starting points for consumption." + f"{node.name} is a {node.op_type} node, but the method for parsing Conv nodes was invoked." ) if len(node.input) not in [2, 3]: raise ValueError( - f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 or 3 input dimensions can be used as starting points for consumption." + f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 or 3 input dimensions." ) if len(node.input) == 2: @@ -422,11 +424,11 @@ def _consume_reshape_nodes(self, node, next_nodes): """Parse a Reshape node.""" if node.op_type != "Reshape": raise ValueError( - f"{node.name} is a {node.op_type} node, only Reshape nodes can be used as starting points for consumption." + f"{node.name} is a {node.op_type} node, but the method for parsing Reshape nodes was invoked." ) if len(node.input) != 2: raise ValueError( - f"{node.name} input has {len(node.input)} dimensions, only nodes with 2 input dimensions can be used as starting points for consumption." + f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 2 input dimensions." ) [in_0, in_1] = list(node.input) input_layer = self._node_map[in_0] @@ -443,7 +445,7 @@ def _consume_pool_nodes(self, node, next_nodes): """ if node.op_type not in _POOLING_OP_TYPES: raise ValueError( - f"{node.name} is a {node.op_type} node, only MaxPool nodes can be used as starting points for consumption." + f"{node.name} is a {node.op_type} node, but the method for parsing MaxPool nodes was invoked." ) pool_func_name = "max" @@ -454,7 +456,7 @@ def _consume_pool_nodes(self, node, next_nodes): ) if len(node.input) != 1: raise ValueError( - f"{node.name} input has {len(node.input)} dimensions, only nodes with 1 input dimension can be used as starting points for consumption." + f"{node.name} input has {len(node.input)} dimensions, but the parser requires the starting node to have 1 input dimension." ) input_layer, transformer = self._node_input_and_transformer(node.input[0]) diff --git a/tests/io/test_onnx_parser.py b/tests/io/test_onnx_parser.py index 763b282c..afbf2a89 100644 --- a/tests/io/test_onnx_parser.py +++ b/tests/io/test_onnx_parser.py @@ -1,6 +1,8 @@ +from pathlib import Path import pytest from omlt.dependencies import onnx, onnx_available +from tests.conftest import _Datadir if onnx_available: from omlt.io.onnx import load_onnx_neural_network @@ -8,7 +10,7 @@ @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_linear_131(datadir): +def test_linear_131(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -20,7 +22,7 @@ def test_linear_131(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_linear_131_relu(datadir): +def test_linear_131_relu(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131_relu.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -32,7 +34,7 @@ def test_linear_131_relu(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_linear_131_sigmoid(datadir): +def test_linear_131_sigmoid(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131_sigmoid.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -44,7 +46,7 @@ def test_linear_131_sigmoid(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_gemm(datadir): +def test_gemm(datadir: _Datadir): model = onnx.load(datadir.file("gemm.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -58,7 +60,7 @@ def test_gemm(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_gemm_transB(datadir): +def test_gemm_transB(datadir: _Datadir): model = onnx.load(datadir.file("gemm_not_transB.onnx")) model_transB = onnx.load(datadir.file("gemm_transB.onnx")) net = load_onnx_neural_network(model) @@ -74,7 +76,7 @@ def test_gemm_transB(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_conv(datadir): +def test_conv(datadir: _Datadir): model = onnx.load(datadir.file("convx1_gemmx1.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -85,9 +87,24 @@ def test_conv(datadir): assert layers[1].strides == [1, 1] assert layers[1].kernel_shape == (2, 2) +@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") +def test_conv_dilations(datadir: _Datadir): + model = onnx.load(datadir.file("convx1_gemmx1.onnx")) + for attr in model.graph.node[0].attribute: + if attr.name == "dilations": + attr.ints.clear() + attr.ints.extend([2,2]) + if attr.name == "pads": + attr.ints.clear() + attr.ints.extend([1,2,1,0]) + model.graph.node[1].attribute[0].t.raw_data = numpy_helper.from_array(np.array([-1,128])).raw_data + net = load_onnx_neural_network(model) + layers = list(net.layers) + assert layers[1].dilations == [2,2] + assert layers[1].pads == [1,2,1,0] @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_maxpool(datadir): +def test_maxpool(datadir: _Datadir): model = onnx.load(datadir.file("maxpool_2d.onnx")) net = load_onnx_neural_network(model) layers = list(net.layers) @@ -109,7 +126,7 @@ def test_maxpool(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_input_tensor_invalid_dims(datadir): +def test_input_tensor_invalid_dims(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) model.graph.input[0].type.tensor_type.shape.dim[1].dim_value = 0 parser = NetworkParser() @@ -120,7 +137,7 @@ def test_input_tensor_invalid_dims(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_no_input_layers(datadir): +def test_no_input_layers(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) model.graph.input.remove(model.graph.input[0]) parser = NetworkParser() @@ -131,7 +148,7 @@ def test_no_input_layers(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_node_no_inputs(datadir): +def test_node_no_inputs(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) while len(model.graph.node[0].input) > 0: model.graph.node[0].input.pop() @@ -143,7 +160,7 @@ def test_node_no_inputs(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_wrong_node_type(datadir): +def test_consume_wrong_node_type(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None) @@ -190,7 +207,7 @@ def test_consume_wrong_node_type(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_dense_wrong_dims(datadir): +def test_consume_dense_wrong_dims(datadir: _Datadir): model = onnx.load(datadir.file("keras_linear_131.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None) @@ -208,7 +225,7 @@ def test_consume_dense_wrong_dims(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_gemm_wrong_dims(datadir): +def test_consume_gemm_wrong_dims(datadir: _Datadir): model = onnx.load(datadir.file("gemm.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None) @@ -222,7 +239,7 @@ def test_consume_gemm_wrong_dims(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_conv_wrong_dims(datadir): +def test_consume_conv_wrong_dims(datadir: _Datadir): model = onnx.load(datadir.file("convx1_gemmx1.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None) @@ -236,7 +253,7 @@ def test_consume_conv_wrong_dims(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_reshape_wrong_dims(datadir): +def test_consume_reshape_wrong_dims(datadir: _Datadir): model = onnx.load(datadir.file("convx1_gemmx1.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None) @@ -250,7 +267,7 @@ def test_consume_reshape_wrong_dims(datadir): @pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test") -def test_consume_maxpool_wrong_dims(datadir): +def test_consume_maxpool_wrong_dims(datadir: _Datadir): model = onnx.load(datadir.file("maxpool_2d.onnx")) parser = NetworkParser() parser.parse_network(model.graph, None, None)