diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index a04e9c4947..33751cb4d8 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -26,11 +26,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from copy import deepcopy - import numpy as np import qonnx.core.data_layout as DataLayout import warnings +from copy import deepcopy from onnx import TensorProto from onnx import helper as oh from qonnx.core.datatype import DataType @@ -721,7 +720,8 @@ def apply(self, model): elif producer is not None and producer.op_type == "Transpose": perms = list(get_by_name(producer.attribute, "perm").ints) if perms == [0, 3, 1, 2]: - # check if the producer is a fork node (need to move it past the fork before this transform) + # check if the producer is a fork node + # (need to move it past the fork before this transform) if model.is_fork_node(producer): model = model.transform(MoveTransposePastFork()) # topology modified, "ask" ModelWrapper to apply this transform again @@ -777,7 +777,8 @@ def apply(self, model): if producer is not None and producer.op_type == "Transpose": perms = list(get_by_name(producer.attribute, "perm").ints) if perms == [0, 3, 1, 2]: - # check if the producer is a fork node (need to move it past the fork before this transform) + # check if the producer is a fork node + # (need to move it past the fork before this transform) if model.is_fork_node(producer): model = model.transform(MoveTransposePastFork()) # topology modified, "ask" ModelWrapper to apply this transform again @@ -957,14 +958,18 @@ def apply(self, model): for split_output_idx, old_split_output in enumerate(split_outputs): new_mul_node = deepcopy(producer) new_split_output = model.make_new_valueinfo_name() - model.set_tensor_datatype(new_split_output, model.get_tensor_datatype(producer.input[0])) - # model.set_tensor_layout(new_split_output, model.get_tensor_layout(producer.input[0])) - model.set_tensor_shape(new_split_output, model.get_tensor_shape(old_split_output)) - + model.set_tensor_datatype( + new_split_output, model.get_tensor_datatype(producer.input[0]) + ) + + model.set_tensor_shape( + new_split_output, model.get_tensor_shape(old_split_output) + ) + n.output[split_output_idx] = new_split_output new_mul_node.input[0] = new_split_output new_mul_node.output[0] = old_split_output - + graph.node.insert(node_ind, new_mul_node) node_ind += 1 @@ -973,7 +978,6 @@ def apply(self, model): graph.node.remove(producer) graph_modified = True - if graph_modified: model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False) @@ -981,7 +985,6 @@ def apply(self, model): class MoveTransposePastSplit(Transformation): - def __init__(self): super().__init__() self.ops_to_move = ["Transpose"] @@ -1004,14 +1007,18 @@ def apply(self, model): new_trans_node = deepcopy(producer) new_split_output = model.make_new_valueinfo_name() old_split_output_shape = model.get_tensor_shape(old_split_output) - model.set_tensor_datatype(new_split_output, model.get_tensor_datatype(producer.input[0])) - # model.set_tensor_layout(new_split_output, model.get_tensor_layout(producer.input[0])) - model.set_tensor_shape(new_split_output, permute_shape(old_split_output_shape, reverse_perm)) - + model.set_tensor_datatype( + new_split_output, model.get_tensor_datatype(producer.input[0]) + ) + + model.set_tensor_shape( + new_split_output, permute_shape(old_split_output_shape, reverse_perm) + ) + n.output[split_output_idx] = new_split_output new_trans_node.input[0] = new_split_output new_trans_node.output[0] = old_split_output - + graph.node.insert(node_ind, new_trans_node) node_ind += 1 @@ -1342,7 +1349,7 @@ def move_node(self, model, n, producers): model.graph.node.remove(prod) return True - + def are_producers_identical(self, model, producers): """ Checks only op_types @@ -1428,10 +1435,10 @@ def are_producers_identical(self, model, producers): return True def move_node(self, model, n, producers): - ''' + """ We use the base move_node method to move the first producer past the join node (and delete the rest) - ''' + """ add_inits = [model.get_initializer(producer.input[1]) for producer in producers] new_init = np.sum(add_inits) model.set_initializer(producers[0].input[1], new_init) @@ -1439,7 +1446,7 @@ def move_node(self, model, n, producers): return True - + class MoveTransposePastJoinConcat(MoveIdenticalOpPastJoinOp): def __init__(self): super().__init__(["Transpose"], ["Concat"]) @@ -1461,7 +1468,7 @@ def move_node(self, model, n, producers): for i in range(len(n.input)): n.input[i] = trans_inputs[i] - new_concat_out = model.make_new_valueinfo_name() #reuse tensor + new_concat_out = model.make_new_valueinfo_name() # reuse tensor # reverse the permutation of the concat output transpose_perm = get_by_name(producers[0].attribute, "perm").ints reverse_perm = np.argsort(transpose_perm) @@ -1490,6 +1497,7 @@ class MoveAffinePastJoinConcat(MoveIdenticalOpPastJoinOp): """ Applies to scalar linear or channelwise affine ops with the same parameter value """ + def __init__(self, linear_ops=["Mul", "Add"]): super().__init__(linear_ops, ["Concat"]) @@ -1499,17 +1507,20 @@ def are_producers_identical_scalar_ops(self, model, producers): producer_param = model.get_initializer(producer.input[1]) if (first_param != producer_param).any() or np.prod(producer_param.shape) != 1: return False - + return True - + def are_producers_channelwise_ops(self, channel_dim, model, producers): for producer in producers: producer_input = producer.input[0] num_channels = model.get_tensor_shape(producer_input)[channel_dim] producer_param = model.get_initializer(producer.input[1]) - if len(producer_param.shape) < channel_dim or producer_param.shape[channel_dim] != num_channels: + if ( + len(producer_param.shape) < channel_dim + or producer_param.shape[channel_dim] != num_channels + ): return False - + return True def move_node(self, model, n, producers): @@ -1519,7 +1530,7 @@ def move_node(self, model, n, producers): if len(producer.input) != 2 or producer_init is None: warnings.warn("Producer found that is not single-input, skipping") return False - + # decide if producers are identical scalar ops or channelwise ops channelwise_op = False identical_scalar_op = self.are_producers_identical_scalar_ops(model, producers) @@ -1527,7 +1538,9 @@ def move_node(self, model, n, producers): channel_dim = get_by_name(n.attribute, "axis").i channelwise_op = self.are_producers_channelwise_ops(channel_dim, model, producers) if not channelwise_op: - warnings.warn("Producers are neither identical scalar ops nor channelwise ops, skipping") + warnings.warn( + "Producers are neither identical scalar ops nor channelwise ops, skipping" + ) return False # Rewire concat inputs @@ -1561,12 +1574,10 @@ def move_node(self, model, n, producers): class MoveMulPastJoinConcat(MoveAffinePastJoinConcat): - def __init__(self): super().__init__(["Mul"]) class MoveAddPastJoinConcat(MoveAffinePastJoinConcat): - def __init__(self): - super().__init__(["Add"]) \ No newline at end of file + super().__init__(["Add"]) diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_add.py b/tests/transformation/streamline/test_move_identical_op_past_join_add.py index 5ab4986333..7226d31589 100644 --- a/tests/transformation/streamline/test_move_identical_op_past_join_add.py +++ b/tests/transformation/streamline/test_move_identical_op_past_join_add.py @@ -26,8 +26,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest -import os -from os.path import join import numpy as np from onnx import TensorProto @@ -36,11 +34,14 @@ from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model import finn.core.onnx_exec as oxe -from finn.transformation.streamline.reorder import MoveTransposePastJoinAdd, MoveMulPastJoinAdd, MoveAddPastJoinAdd +from finn.transformation.streamline.reorder import ( + MoveAddPastJoinAdd, + MoveMulPastJoinAdd, + MoveTransposePastJoinAdd, +) def create_add_model(identical_op): - perm = None if "Transpose" in identical_op: perm = identical_op.split("_")[1] @@ -57,13 +58,9 @@ def create_add_model(identical_op): out_shape = in_shape op_value = 1.5 - op1_node = oh.make_node( - identical_op, inputs=["in1"], outputs=["op1_out"] - ) + op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"]) - op2_node = oh.make_node( - identical_op, inputs=["in2"], outputs=["op2_out"] - ) + op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"]) if identical_op == "Transpose": new_attr = oh.make_attribute("perm", perm) @@ -74,10 +71,8 @@ def create_add_model(identical_op): op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, [1]) op1_node.input.append(op1_init.name) op2_node.input.append(op2_init.name) - - add_node = oh.make_node( - "Add", inputs=["op1_out", "op2_out"], outputs=["out_join1"] - ) + + add_node = oh.make_node("Add", inputs=["op1_out", "op2_out"], outputs=["out_join1"]) in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape) in2 = oh.make_tensor_value_info("in2", TensorProto.FLOAT, in_shape) @@ -109,7 +104,7 @@ def create_add_model(identical_op): "Transpose_0231": MoveTransposePastJoinAdd(), "Transpose_0312": MoveTransposePastJoinAdd(), "Mul": MoveMulPastJoinAdd(), - "Add": MoveAddPastJoinAdd() + "Add": MoveAddPastJoinAdd(), } diff --git a/tests/transformation/streamline/test_move_identical_op_past_join_concat.py b/tests/transformation/streamline/test_move_identical_op_past_join_concat.py index 0739d6a807..2dcf90d10a 100644 --- a/tests/transformation/streamline/test_move_identical_op_past_join_concat.py +++ b/tests/transformation/streamline/test_move_identical_op_past_join_concat.py @@ -26,21 +26,24 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest -import os -from os.path import join -import numpy as np +import numpy as np +import os from onnx import TensorProto from onnx import helper as oh +from os.path import join from qonnx.core.modelwrapper import ModelWrapper from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model import finn.core.onnx_exec as oxe -from finn.transformation.streamline.reorder import MoveTransposePastJoinConcat, MoveMulPastJoinConcat, MoveAddPastJoinConcat +from finn.transformation.streamline.reorder import ( + MoveAddPastJoinConcat, + MoveMulPastJoinConcat, + MoveTransposePastJoinConcat, +) def create_concat_model(identical_op): - perm = None channelwise = False if "Transpose" in identical_op: @@ -81,14 +84,10 @@ def create_concat_model(identical_op): op2_param_shape = [1] op1_param = 1.5 op2_param = 1.5 - - op1_node = oh.make_node( - identical_op, inputs=["in1"], outputs=["op1_out"] - ) - op2_node = oh.make_node( - identical_op, inputs=["in2"], outputs=["op2_out"] - ) + op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"]) + + op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"]) if identical_op == "Transpose": new_attr = oh.make_attribute("perm", perm) @@ -99,7 +98,7 @@ def create_concat_model(identical_op): op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, op2_param_shape) op1_node.input.append(op1_init.name) op2_node.input.append(op2_init.name) - + concat_node = oh.make_node( "Concat", inputs=["op1_out", "op2_out"], outputs=["out_join1"], axis=concat_axis ) @@ -136,13 +135,16 @@ def create_concat_model(identical_op): "Mul": MoveMulPastJoinConcat(), "Mul_channelwise": MoveMulPastJoinConcat(), "Add": MoveAddPastJoinConcat(), - "Add_channelwise": MoveAddPastJoinConcat() + "Add_channelwise": MoveAddPastJoinConcat(), } @pytest.mark.streamline # Permutation of transpose node -@pytest.mark.parametrize("identical_op", ["Transpose_0231", "Transpose_0312", "Mul", "Add", "Mul_channelwise", "Add_channelwise"]) +@pytest.mark.parametrize( + "identical_op", + ["Transpose_0231", "Transpose_0312", "Mul", "Add", "Mul_channelwise", "Add_channelwise"], +) def test_move_identical_op_past_join_concat(identical_op): model = create_concat_model(identical_op) build_dir = os.environ["FINN_BUILD_DIR"] @@ -154,13 +156,17 @@ def test_move_identical_op_past_join_concat(identical_op): # Note: it is assumed that both tensors have the same shape and data type input_dict = {} - input_dict[input0_tensor_name] = gen_finn_dt_tensor(model.get_tensor_datatype(input0_tensor_name), - model.get_tensor_shape(input0_tensor_name)) - input_dict[input1_tensor_name] = gen_finn_dt_tensor(model.get_tensor_datatype(input1_tensor_name), - model.get_tensor_shape(input1_tensor_name)) + input_dict[input0_tensor_name] = gen_finn_dt_tensor( + model.get_tensor_datatype(input0_tensor_name), model.get_tensor_shape(input0_tensor_name) + ) + input_dict[input1_tensor_name] = gen_finn_dt_tensor( + model.get_tensor_datatype(input1_tensor_name), model.get_tensor_shape(input1_tensor_name) + ) model_transformed = model.transform(transform_dict[identical_op]) - model_transformed.save(join(build_dir, "concat_pytest_model_{}_trans.onnx".format(identical_op))) + model_transformed.save( + join(build_dir, "concat_pytest_model_{}_trans.onnx".format(identical_op)) + ) assert oxe.compare_execution(model, model_transformed, input_dict) diff --git a/tests/transformation/streamline/test_move_identical_op_past_split.py b/tests/transformation/streamline/test_move_identical_op_past_split.py index 6f9c13b3a7..a104f179be 100644 --- a/tests/transformation/streamline/test_move_identical_op_past_split.py +++ b/tests/transformation/streamline/test_move_identical_op_past_split.py @@ -26,23 +26,22 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import pytest -import os -from os.path import join -import numpy as np +import numpy as np from onnx import TensorProto from onnx import helper as oh from qonnx.core.modelwrapper import ModelWrapper -from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model from qonnx.transformation.general import GiveUniqueNodeNames - +from qonnx.util.basic import gen_finn_dt_tensor import finn.core.onnx_exec as oxe -from finn.transformation.streamline.reorder import MoveScalarLinearPastSplit, MoveTransposePastSplit +from finn.transformation.streamline.reorder import ( + MoveScalarLinearPastSplit, + MoveTransposePastSplit, +) def create_split_model(identical_op): - perm = None if "Transpose" in identical_op: perm = identical_op.split("_")[1] @@ -69,9 +68,7 @@ def create_split_model(identical_op): op_value = 1.5 split = [32, 64] - op_node = oh.make_node( - identical_op, inputs=["in1"], outputs=["op_out"] - ) + op_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op_out"]) if identical_op == "Transpose": new_attr = oh.make_attribute("perm", perm) @@ -79,18 +76,15 @@ def create_split_model(identical_op): elif identical_op == "Mul" or identical_op == "Add": op_init = oh.make_tensor_value_info("op_param", TensorProto.FLOAT, [1]) op_node.input.append(op_init.name) - + in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape) op_out = oh.make_tensor_value_info("op_out", TensorProto.FLOAT, out_shape) out1_split = oh.make_tensor_value_info("out1_split", TensorProto.FLOAT, out1_split_shape) out2_split = oh.make_tensor_value_info("out2_split", TensorProto.FLOAT, out2_split_shape) split_init = oh.make_tensor_value_info("split", TensorProto.INT64, [2]) - + split_node = oh.make_node( - "Split", - [op_out.name, split_init.name], - [out1_split.name, out2_split.name], - axis=split_axis + "Split", [op_out.name, split_init.name], [out1_split.name, out2_split.name], axis=split_axis ) graph = oh.make_graph( @@ -98,9 +92,7 @@ def create_split_model(identical_op): name="test_graph", inputs=[in1], outputs=[out1_split, out2_split], - value_info=[ - op_out - ], + value_info=[op_out], ) model = oh.make_model(graph) @@ -117,7 +109,7 @@ def create_split_model(identical_op): "Transpose_0231": MoveTransposePastSplit(), "Transpose_0312": MoveTransposePastSplit(), "Mul": MoveScalarLinearPastSplit(), - "Add": MoveScalarLinearPastSplit() + "Add": MoveScalarLinearPastSplit(), } @@ -134,11 +126,14 @@ def test_move_identical_op_past_join_concat(identical_op): # Note: it is assumed that both tensors have the same shape and data type input_dict = {} - input_dict[input0_tensor_name] = gen_finn_dt_tensor(model.get_tensor_datatype(input0_tensor_name), - model.get_tensor_shape(input0_tensor_name)) - + input_dict[input0_tensor_name] = gen_finn_dt_tensor( + model.get_tensor_datatype(input0_tensor_name), model.get_tensor_shape(input0_tensor_name) + ) + model_transformed = model.transform(transform_dict[identical_op]) - # model_transformed.save(join(build_dir, "split_pytest_model_{}_trans.onnx".format(identical_op))) + # model_transformed.save( + # join(build_dir, "split_pytest_model_{}_trans.onnx".format(identical_op)) + # ) assert oxe.compare_execution(model, model_transformed, input_dict)