
pre-commit applied
Michal committed Sep 5, 2024
1 parent 36cb589 commit 4c5680d
Showing 4 changed files with 96 additions and 89 deletions.
71 changes: 41 additions & 30 deletions src/finn/transformation/streamline/reorder.py
@@ -26,11 +26,10 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-from copy import deepcopy
-
import numpy as np
import qonnx.core.data_layout as DataLayout
import warnings
+from copy import deepcopy
from onnx import TensorProto
from onnx import helper as oh
from qonnx.core.datatype import DataType
@@ -721,7 +720,8 @@ def apply(self, model):
elif producer is not None and producer.op_type == "Transpose":
perms = list(get_by_name(producer.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
-# check if the producer is a fork node (need to move it past the fork before this transform)
+# check if the producer is a fork node
+# (need to move it past the fork before this transform)
if model.is_fork_node(producer):
model = model.transform(MoveTransposePastFork())
# topology modified, "ask" ModelWrapper to apply this transform again
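(Aside, illustration only and not part of the commit: the comment above reflects the fixed-point pattern QONNX/FINN transformations use. When a transform first has to rewire topology, here via MoveTransposePastFork, it returns graph_modified=True so the driver re-applies it on the updated graph. A minimal sketch of such a driver loop, assuming only the apply(model) -> (model, bool) convention:)

# Sketch of a fixed-point driver; assumes only the QONNX-style convention
# that apply(model) returns (possibly_new_model, graph_was_modified).
def apply_until_stable(model, transformation, max_iters=100):
    for _ in range(max_iters):
        model, modified = transformation.apply(model)
        if not modified:
            return model  # fixed point reached, no rewrites left
    raise RuntimeError("transformation did not converge")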
@@ -777,7 +777,8 @@ def apply(self, model):
if producer is not None and producer.op_type == "Transpose":
perms = list(get_by_name(producer.attribute, "perm").ints)
if perms == [0, 3, 1, 2]:
-# check if the producer is a fork node (need to move it past the fork before this transform)
+# check if the producer is a fork node
+# (need to move it past the fork before this transform)
if model.is_fork_node(producer):
model = model.transform(MoveTransposePastFork())
# topology modified, "ask" ModelWrapper to apply this transform again
@@ -957,14 +958,18 @@ def apply(self, model):
for split_output_idx, old_split_output in enumerate(split_outputs):
new_mul_node = deepcopy(producer)
new_split_output = model.make_new_valueinfo_name()
-model.set_tensor_datatype(new_split_output, model.get_tensor_datatype(producer.input[0]))
-# model.set_tensor_layout(new_split_output, model.get_tensor_layout(producer.input[0]))
-model.set_tensor_shape(new_split_output, model.get_tensor_shape(old_split_output))
+model.set_tensor_datatype(
+    new_split_output, model.get_tensor_datatype(producer.input[0])
+)
+model.set_tensor_shape(
+    new_split_output, model.get_tensor_shape(old_split_output)
+)

n.output[split_output_idx] = new_split_output
new_mul_node.input[0] = new_split_output
new_mul_node.output[0] = old_split_output

graph.node.insert(node_ind, new_mul_node)
node_ind += 1

@@ -973,15 +978,13 @@ def apply(self, model):
graph.node.remove(producer)
graph_modified = True


if graph_modified:
model = model.transform(SortGraph(), make_deepcopy=False, cleanup=False)

return (model, graph_modified)
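A standalone numpy check of why this rewrite preserves semantics (illustration only, not the FINN API): elementwise scaling commutes with Split, so scaling the whole tensor before the split equals splitting first and scaling each output.

import numpy as np

x = np.random.rand(1, 6, 4).astype(np.float32)
scale = np.float32(1.5)

# Mul then Split vs. Split then a per-output Mul (split axis 1 into 2 + 4)
before = np.split(x * scale, [2], axis=1)
after = [part * scale for part in np.split(x, [2], axis=1)]
assert all(np.allclose(b, a) for b, a in zip(before, after))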


class MoveTransposePastSplit(Transformation):

def __init__(self):
super().__init__()
self.ops_to_move = ["Transpose"]
@@ -1004,14 +1007,18 @@ def apply(self, model):
new_trans_node = deepcopy(producer)
new_split_output = model.make_new_valueinfo_name()
old_split_output_shape = model.get_tensor_shape(old_split_output)
-model.set_tensor_datatype(new_split_output, model.get_tensor_datatype(producer.input[0]))
-# model.set_tensor_layout(new_split_output, model.get_tensor_layout(producer.input[0]))
-model.set_tensor_shape(new_split_output, permute_shape(old_split_output_shape, reverse_perm))
+model.set_tensor_datatype(
+    new_split_output, model.get_tensor_datatype(producer.input[0])
+)
+model.set_tensor_shape(
+    new_split_output, permute_shape(old_split_output_shape, reverse_perm)
+)

n.output[split_output_idx] = new_split_output
new_trans_node.input[0] = new_split_output
new_trans_node.output[0] = old_split_output

graph.node.insert(node_ind, new_trans_node)
node_ind += 1

@@ -1342,7 +1349,7 @@ def move_node(self, model, n, producers):
model.graph.node.remove(prod)

return True

def are_producers_identical(self, model, producers):
"""
Checks only op_types
@@ -1428,18 +1435,18 @@ def are_producers_identical(self, model, producers):
return True

def move_node(self, model, n, producers):
-'''
+"""
We use the base move_node method to move the first producer
past the join node (and delete the rest)
-'''
+"""
add_inits = [model.get_initializer(producer.input[1]) for producer in producers]
new_init = np.sum(add_inits)
model.set_initializer(producers[0].input[1], new_init)
super().move_node(model, n, producers)

return True
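A quick sanity check of the initializer summing above (plain numpy, my illustration): because addition is associative and commutative, two Add producers feeding a join Add collapse into a single Add with the summed constant.

import numpy as np

x = np.random.rand(1, 4).astype(np.float32)
y = np.random.rand(1, 4).astype(np.float32)
a, b = np.float32(1.5), np.float32(-0.5)

before = (x + a) + (y + b)  # two Add producers, then the join Add
after = (x + y) + (a + b)   # join Add first, one fused Add after it
assert np.allclose(before, after)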


class MoveTransposePastJoinConcat(MoveIdenticalOpPastJoinOp):
def __init__(self):
super().__init__(["Transpose"], ["Concat"])
@@ -1461,7 +1468,7 @@ def move_node(self, model, n, producers):
for i in range(len(n.input)):
n.input[i] = trans_inputs[i]

-new_concat_out = model.make_new_valueinfo_name() #reuse tensor
+new_concat_out = model.make_new_valueinfo_name()  # reuse tensor
# reverse the permutation of the concat output
transpose_perm = get_by_name(producers[0].attribute, "perm").ints
reverse_perm = np.argsort(transpose_perm)
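Why np.argsort yields the reverse permutation (standalone check, my addition): for a permutation p, argsort(p)[i] is the position where value i sits in p, which is exactly the inverse mapping.

import numpy as np

perm = [0, 2, 3, 1]
reverse_perm = np.argsort(perm)  # -> [0, 3, 1, 2], the inverse of perm

shape = (1, 8, 4, 4)
permuted = tuple(shape[i] for i in perm)
restored = tuple(permuted[i] for i in reverse_perm)
assert restored == shape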
@@ -1490,6 +1497,7 @@ class MoveAffinePastJoinConcat(MoveIdenticalOpPastJoinOp):
"""
Applies to scalar linear or channelwise affine ops with the same parameter value
"""

def __init__(self, linear_ops=["Mul", "Add"]):
super().__init__(linear_ops, ["Concat"])

@@ -1499,17 +1507,20 @@ def are_producers_identical_scalar_ops(self, model, producers):
producer_param = model.get_initializer(producer.input[1])
if (first_param != producer_param).any() or np.prod(producer_param.shape) != 1:
return False

return True

def are_producers_channelwise_ops(self, channel_dim, model, producers):
for producer in producers:
producer_input = producer.input[0]
num_channels = model.get_tensor_shape(producer_input)[channel_dim]
producer_param = model.get_initializer(producer.input[1])
-if len(producer_param.shape) < channel_dim or producer_param.shape[channel_dim] != num_channels:
+if (
+    len(producer_param.shape) < channel_dim
+    or producer_param.shape[channel_dim] != num_channels
+):
return False

return True
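A numpy illustration of the channelwise case this check guards (my addition, not the FINN API): per-branch channelwise parameters can move past a Concat if they are concatenated along the same axis, while scalar parameters only need to be identical across branches.

import numpy as np

# Two branches with channelwise Mul params on axis 1, then Concat on axis 1
x1, x2 = np.random.rand(1, 3, 4), np.random.rand(1, 5, 4)
p1, p2 = np.random.rand(1, 3, 1), np.random.rand(1, 5, 1)

before = np.concatenate([x1 * p1, x2 * p2], axis=1)
after = np.concatenate([x1, x2], axis=1) * np.concatenate([p1, p2], axis=1)
assert np.allclose(before, after)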

def move_node(self, model, n, producers):
@@ -1519,15 +1530,17 @@ def move_node(self, model, n, producers):
if len(producer.input) != 2 or producer_init is None:
warnings.warn("Producer found that is not single-input, skipping")
return False

# decide if producers are identical scalar ops or channelwise ops
channelwise_op = False
identical_scalar_op = self.are_producers_identical_scalar_ops(model, producers)
if not identical_scalar_op:
channel_dim = get_by_name(n.attribute, "axis").i
channelwise_op = self.are_producers_channelwise_ops(channel_dim, model, producers)
if not channelwise_op:
warnings.warn("Producers are neither identical scalar ops nor channelwise ops, skipping")
warnings.warn(
"Producers are neither identical scalar ops nor channelwise ops, skipping"
)
return False

# Rewire concat inputs
@@ -1561,12 +1574,10 @@ def move_node(self, model, n, producers):


class MoveMulPastJoinConcat(MoveAffinePastJoinConcat):
def __init__(self):
super().__init__(["Mul"])


class MoveAddPastJoinConcat(MoveAffinePastJoinConcat):
def __init__(self):
super().__init__(["Add"])
@@ -26,8 +26,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
-import os
-from os.path import join

import numpy as np
from onnx import TensorProto
@@ -36,11 +34,14 @@
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model

import finn.core.onnx_exec as oxe
-from finn.transformation.streamline.reorder import MoveTransposePastJoinAdd, MoveMulPastJoinAdd, MoveAddPastJoinAdd
+from finn.transformation.streamline.reorder import (
+    MoveAddPastJoinAdd,
+    MoveMulPastJoinAdd,
+    MoveTransposePastJoinAdd,
+)


def create_add_model(identical_op):

perm = None
if "Transpose" in identical_op:
perm = identical_op.split("_")[1]
@@ -57,13 +58,9 @@ def create_add_model(identical_op):
out_shape = in_shape
op_value = 1.5

-op1_node = oh.make_node(
-    identical_op, inputs=["in1"], outputs=["op1_out"]
-)
+op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"])

-op2_node = oh.make_node(
-    identical_op, inputs=["in2"], outputs=["op2_out"]
-)
+op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"])

if identical_op == "Transpose":
new_attr = oh.make_attribute("perm", perm)
@@ -74,10 +71,8 @@ def create_add_model(identical_op):
op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, [1])
op1_node.input.append(op1_init.name)
op2_node.input.append(op2_init.name)

-add_node = oh.make_node(
-    "Add", inputs=["op1_out", "op2_out"], outputs=["out_join1"]
-)
+add_node = oh.make_node("Add", inputs=["op1_out", "op2_out"], outputs=["out_join1"])

in1 = oh.make_tensor_value_info("in1", TensorProto.FLOAT, in_shape)
in2 = oh.make_tensor_value_info("in2", TensorProto.FLOAT, in_shape)
@@ -109,7 +104,7 @@ def create_add_model(identical_op):
"Transpose_0231": MoveTransposePastJoinAdd(),
"Transpose_0312": MoveTransposePastJoinAdd(),
"Mul": MoveMulPastJoinAdd(),
"Add": MoveAddPastJoinAdd()
"Add": MoveAddPastJoinAdd(),
}


@@ -26,21 +26,24 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest
-import os
-from os.path import join
-import numpy as np

+import numpy as np
+import os
from onnx import TensorProto
from onnx import helper as oh
+from os.path import join
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model

import finn.core.onnx_exec as oxe
-from finn.transformation.streamline.reorder import MoveTransposePastJoinConcat, MoveMulPastJoinConcat, MoveAddPastJoinConcat
+from finn.transformation.streamline.reorder import (
+    MoveAddPastJoinConcat,
+    MoveMulPastJoinConcat,
+    MoveTransposePastJoinConcat,
+)


def create_concat_model(identical_op):

perm = None
channelwise = False
if "Transpose" in identical_op:
@@ -81,14 +84,10 @@ def create_concat_model(identical_op):
op2_param_shape = [1]
op1_param = 1.5
op2_param = 1.5

-op1_node = oh.make_node(
-    identical_op, inputs=["in1"], outputs=["op1_out"]
-)
-
-op2_node = oh.make_node(
-    identical_op, inputs=["in2"], outputs=["op2_out"]
-)
+op1_node = oh.make_node(identical_op, inputs=["in1"], outputs=["op1_out"])
+
+op2_node = oh.make_node(identical_op, inputs=["in2"], outputs=["op2_out"])

if identical_op == "Transpose":
new_attr = oh.make_attribute("perm", perm)
@@ -99,7 +98,7 @@ def create_concat_model(identical_op):
op2_init = oh.make_tensor_value_info("op2_param", TensorProto.FLOAT, op2_param_shape)
op1_node.input.append(op1_init.name)
op2_node.input.append(op2_init.name)

concat_node = oh.make_node(
"Concat", inputs=["op1_out", "op2_out"], outputs=["out_join1"], axis=concat_axis
)
@@ -136,13 +135,16 @@ def create_concat_model(identical_op):
"Mul": MoveMulPastJoinConcat(),
"Mul_channelwise": MoveMulPastJoinConcat(),
"Add": MoveAddPastJoinConcat(),
"Add_channelwise": MoveAddPastJoinConcat()
"Add_channelwise": MoveAddPastJoinConcat(),
}


@pytest.mark.streamline
# Permutation of transpose node
@pytest.mark.parametrize("identical_op", ["Transpose_0231", "Transpose_0312", "Mul", "Add", "Mul_channelwise", "Add_channelwise"])
@pytest.mark.parametrize(
"identical_op",
["Transpose_0231", "Transpose_0312", "Mul", "Add", "Mul_channelwise", "Add_channelwise"],
)
def test_move_identical_op_past_join_concat(identical_op):
model = create_concat_model(identical_op)
build_dir = os.environ["FINN_BUILD_DIR"]
@@ -154,13 +156,17 @@ def test_move_identical_op_past_join_concat(identical_op):

# Note: it is assumed that both tensors have the same shape and data type
input_dict = {}
-input_dict[input0_tensor_name] = gen_finn_dt_tensor(model.get_tensor_datatype(input0_tensor_name),
-                                                    model.get_tensor_shape(input0_tensor_name))
-input_dict[input1_tensor_name] = gen_finn_dt_tensor(model.get_tensor_datatype(input1_tensor_name),
-                                                    model.get_tensor_shape(input1_tensor_name))
+input_dict[input0_tensor_name] = gen_finn_dt_tensor(
+    model.get_tensor_datatype(input0_tensor_name), model.get_tensor_shape(input0_tensor_name)
+)
+input_dict[input1_tensor_name] = gen_finn_dt_tensor(
+    model.get_tensor_datatype(input1_tensor_name), model.get_tensor_shape(input1_tensor_name)
+)

model_transformed = model.transform(transform_dict[identical_op])
-model_transformed.save(join(build_dir, "concat_pytest_model_{}_trans.onnx".format(identical_op)))
+model_transformed.save(
+    join(build_dir, "concat_pytest_model_{}_trans.onnx".format(identical_op))
+)

assert oxe.compare_execution(model, model_transformed, input_dict)
