diff --git a/.gitignore b/.gitignore index 07118a90..8725962b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.out typings +outputs # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/main_stream_co copy.py b/main_stream_co copy.py new file mode 100644 index 00000000..dff081c1 --- /dev/null +++ b/main_stream_co copy.py @@ -0,0 +1,69 @@ +import logging as _logging +import re + +from stream.api import optimize_allocation_co +from stream.utils import CostModelEvaluationLUT +from stream.visualization.memory_usage import plot_memory_usage +from stream.visualization.schedule import ( + visualize_timeline_plotly, +) + +_logging_level = _logging.INFO +_logging_format = "%(asctime)s - %(name)s.%(funcName)s +%(lineno)s - %(levelname)s - %(message)s" +_logging.basicConfig(level=_logging_level, format=_logging_format) + +############################################INPUTS############################################ +accelerator = "stream/inputs/examples/hardware/tpu_like_quad_core.yaml" +workload_path = "outputs/custom_ssm.onnx" +mapping_path = "stream/inputs/examples/mapping/tpu_like_quad_core copy.yaml" +mode = "fused" +layer_stacks = [tuple(range(0, 11)), tuple(range(11, 22))] + list((i,) for i in range(22, 49)) +############################################################################################## + +################################PARSING############################### +hw_name = accelerator.split("/")[-1].split(".")[0] +wl_name = re.split(r"/|\.", workload_path)[-1] +if wl_name == "onnx": + wl_name = re.split(r"/|\.", workload_path)[-2] +experiment_id = f"{hw_name}-{wl_name}-{mode}-constraint_optimization" +###################################################################### + +scme = optimize_allocation_co( + hardware=accelerator, + workload=workload_path, + mapping=mapping_path, + mode=mode, + layer_stacks=layer_stacks, + experiment_id=experiment_id, + output_path="outputs", + skip_if_exists=False, +) + +############PLOTTING############# +plot_full_schedule = True +draw_dependencies = True +plot_data_transfer = True +section_start_percent = (0,) +percent_shown = (100,) +################################# + +#########################PLOTTING PATHS############################## +timeline_fig_path_plotly = f"outputs/{experiment_id}/schedule.html" +memory_fig_path = f"outputs/{experiment_id}/memory.png" +##################################################################### + +#####################CostModelEvaluationLUT LOAD############################# +cost_lut_path = f"outputs/{experiment_id}/cost_lut_post_co.pickle" +cost_lut = CostModelEvaluationLUT(cost_lut_path) +############################################################################# + +# Plotting schedule timeline of best SCME +visualize_timeline_plotly( + scme, + draw_dependencies=draw_dependencies, + draw_communication=plot_data_transfer, + fig_path=timeline_fig_path_plotly, + cost_lut=cost_lut, +) +# Plotting memory usage of best SCME +plot_memory_usage(scme, section_start_percent, percent_shown, fig_path=memory_fig_path) diff --git a/outputs/custom_ssm.onnx b/outputs/custom_ssm.onnx new file mode 100644 index 00000000..97da8bdb Binary files /dev/null and b/outputs/custom_ssm.onnx differ diff --git a/stream/inputs/examples/mapping/tpu_like_quad_core copy.yaml b/stream/inputs/examples/mapping/tpu_like_quad_core copy.yaml new file mode 100644 index 00000000..c44ae57e --- /dev/null +++ b/stream/inputs/examples/mapping/tpu_like_quad_core copy.yaml @@ -0,0 +1,55 @@ +- name: default + core_allocation: [0, 
1, 2, 3] + intra_core_tiling: + - D, 1 + inter_core_tiling: + - K, * + +- name: Conv + core_allocation: [0, 1, 2, 3] + intra_core_tiling: + - OY, 1 + inter_core_tiling: + - K, * + +- name: Gemm + core_allocation: [0, 1, 2, 3] + intra_core_tiling: + - D, 1 + inter_core_tiling: + - H, * + +- name: Pool + core_allocation: [4] + intra_core_tiling: + - OY, 1 + inter_core_tiling: + - K, * + +- name: MaxPool + core_allocation: [4] + intra_core_tiling: + - OY, 1 + inter_core_tiling: + - K, * + +- name: AveragePool + core_allocation: [4] + intra_core_tiling: + - OY, 1 + inter_core_tiling: + - K, * + +- name: GlobalAveragePool + core_allocation: [4] + intra_core_tiling: + - OY, 1 + inter_core_tiling: + - K, * + +- name: Add + core_allocation: [5] + intra_core_tiling: + - D, 1 + inter_core_tiling: + - H, * diff --git a/stream/node_tensor.py b/stream/node_tensor.py index 76b5e58b..26dbb22e 100644 --- a/stream/node_tensor.py +++ b/stream/node_tensor.py @@ -47,7 +47,7 @@ def _get_and_increment_pointer(self): @property def shape(self) -> None: # type: ignore """Protect the original shape attribute to prevent errors""" - raise ValueError("The numpy shape of NodeTensor is hidden in an abstraction layer") + raise ValueError("The numpy shape of NodeTensor is hidden in an abstraction layer. Call `tensor_shape` instead") @property def full_shape(self): @@ -125,6 +125,25 @@ def gather(self, gather_indices: int | list[int], axis: int) -> "NodeTensor": axis = axis - 1 if axis < 0 else axis return (np.take(self.as_ndarray(), gather_indices, axis=axis)).view(NodeTensor) + def split(self, split_indices: list[int], axis: int) -> "list[NodeTensor]": + axis = axis - 1 if axis < 0 else axis + return [t.view(NodeTensor) for t in np.split(self.as_ndarray(), split_indices, axis=axis)] + + def slice(self, starts: int, ends: int, axis: int, steps: int) -> "NodeTensor": + assert starts != 1 and ends != -1 + axis = len(self.tensor_shape) - 1 if axis < 0 else axis + match axis: + case 0: + return self.as_ndarray()[starts:ends:steps, ...].view(NodeTensor) + case 1: + return self.as_ndarray()[:, starts:ends:steps, ...].view(NodeTensor) + case 2: + return self.as_ndarray()[:, :, starts:ends:steps, ...].view(NodeTensor) + case 3: + return self.as_ndarray()[:, :, :, starts:ends:steps, ...].view(NodeTensor) + case _: + raise NotImplementedError + def concat_with_empty(self, shape: tuple[int, ...], axis: int, variable_input_first: bool): empty_shape = self.convert_to_full_shape(shape) empty_tensor = np.zeros(empty_shape, dtype=object) diff --git a/stream/onnx_utils.py b/stream/onnx_utils.py new file mode 100644 index 00000000..dfa9b434 --- /dev/null +++ b/stream/onnx_utils.py @@ -0,0 +1,107 @@ +import numpy as np +from onnx import AttributeProto, ModelProto, NodeProto, numpy_helper +from zigzag.parser.onnx.utils import get_onnx_tensor_type + + +def get_attribute_as_ints( + node: NodeProto, attribute_name: str, default: list[int] | int | None = None +) -> list[int] | int: + """! Return the value of an attribute of given name from the given attributes + If name does not exist in attrs, the default provided by the caller is used. + If the caller doesn't supply a default, an error is thrown. 
+ + """ + attrs = node.attribute + attrs_names = [attr.name for attr in attrs] + try: + name_idx = attrs_names.index(attribute_name) + value = attrs[name_idx] + attr_type = value.type + if attr_type == AttributeProto.AttributeType.INT: # type: ignore + return int(value.i) + elif attr_type == AttributeProto.AttributeType.INTS: # type: ignore + return list(value.ints) + elif attr_type == AttributeProto.AttributeType.TENSOR: # type: ignore + return list(numpy_helper.to_array(value.t).tolist()) # type: ignore + else: + raise NotImplementedError(f"Attribute extraction of type {attr_type} not supported.") + except ValueError as exc: + if default is not None: + return default + else: + raise ValueError( + f"Node {node.name} has no attribute called {attribute_name} and no default was given. Attributes = {attrs_names}." + ) from exc + + +def get_onnx_input_shapes(node: NodeProto, onnx_model: ModelProto) -> list[list[int]]: + """Return the shape of each input operand""" + input_names = node.input + input_shapes = [get_onnx_tensor_type(name, onnx_model).shape for name in input_names] + return input_shapes + + +def get_onnx_output_shapes(node: NodeProto, onnx_model: ModelProto) -> list[list[int]]: + """Return the shape of each output operand""" + + output_names = node.output + output_shapes = [get_onnx_tensor_type(name, onnx_model).shape for name in output_names] + return output_shapes + + +def has_asymmetric_input_data(node: NodeProto, onnx_model: ModelProto): + """Return true iff the node has two inputs and the input nodes have a different shape""" + if len(node.input) != 2: + return False + + input_shape1, input_shape2 = get_onnx_input_shapes(node, onnx_model) + return input_shape1 != input_shape2 + + +def get_constant_tensor_int(onnx_model: ModelProto, constant_output_name: str): + """In some cases, the constants to a node (e.g. slice and split indices) are saved as tensors within a constant + node. The output name of the constant nodes corresponds to the input name of the node that uses this constant + tensor.""" + + for node in onnx_model.graph.node: + if node.op_type == "Constant" and node.output[0] == constant_output_name: + for attr in node.attribute: + if attr.name == "value": + tensor = attr.t # This is an ONNX TensorProto + # Decode tensor to a numpy array + array = np.frombuffer(tensor.raw_data, dtype=int) + array = array.reshape([dim for dim in tensor.dims]) + + return [int(i) for i in array] + + raise ValueError(f"Cannot find {constant_output_name}") + + +def get_axis_attribute(node: NodeProto): + """Find the value of the axis associated with this ONNX node""" + ATTR_NAME = "axis" + + value = get_attribute_as_ints(node, ATTR_NAME) + if not isinstance(value, int): + raise ValueError(f"{ATTR_NAME} attribute as list of ints not supported") + return value + + +def get_split_attribute(node: NodeProto, onnx_model: ModelProto): + output_name = next(n for n in node.input if "split" in n.lower()) + return get_constant_tensor_int(onnx_model, output_name) + + +def get_slice_attributes(node: NodeProto, onnx_model: ModelProto): + """Get the `starts`, `ends`, `axes` and `steps` tensors for a slice node. 
+ NOTE: this assumes that the attributes are given as inputs in this order""" + if len(node.input) != 5: + raise NotImplementedError("Unsure how to get slice attributes from Node") + + starts_output_name, ends_output_name, axes_output_name, steps_output_name = node.input[1:5] + + starts_value = get_constant_tensor_int(onnx_model, starts_output_name) + ends_value = get_constant_tensor_int(onnx_model, ends_output_name) + axes_value = get_constant_tensor_int(onnx_model, axes_output_name) + steps_value = get_constant_tensor_int(onnx_model, steps_output_name) + return starts_value, ends_value, axes_value, steps_value diff --git a/stream/parser/onnx/asymmetric_simd.py b/stream/parser/onnx/asymmetric_simd.py index a5ca54e7..3bedfa15 100644 --- a/stream/parser/onnx/asymmetric_simd.py +++ b/stream/parser/onnx/asymmetric_simd.py @@ -1,12 +1,9 @@ from typing import Any -from zigzag.parser.onnx.utils import ( - get_node_input_output_dimension_shapes, -) from zigzag.parser.workload_factory import LayerNodeFactory +from stream.onnx_utils import get_onnx_input_shapes, get_onnx_output_shapes from stream.parser.onnx.operator_parser import OnnxComputeOperatorParser -from stream.utils import get_onnx_input_shapes from stream.workload.computation.computation_node import ComputationNode @@ -30,7 +27,7 @@ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[ data["name"] = self.node.name data["operator_type"] = self.node.op_type data["operand_source"] = self.get_operand_source_input_format() - data["operand_precision"] = self.get_operand_precision_input_format() + data["operand_precision"] = self.get_operand_precision_user_format() data["dimension_relations"] = [] data["loop_sizes"] = output_shape @@ -41,8 +38,15 @@ def generate_node(self): # Get the input and output activation shapes - input_shape1, input_shape2 = get_onnx_input_shapes(self.node, self.onnx_model) - _, output_shape = get_node_input_output_dimension_shapes(self.node, self.onnx_model) + input_shapes = get_onnx_input_shapes(self.node, self.onnx_model) + if len(input_shapes) != 2: + raise NotImplementedError("Only SIMD nodes with input length 2 are supported") + input_shape1, input_shape2 = input_shapes + + output_shapes = get_onnx_output_shapes(self.node, self.onnx_model) + if len(output_shapes) != 1: + raise NotImplementedError("Only SIMD nodes with a single output are supported") + output_shape = output_shapes.pop() if input_shape1 == output_shape: non_batched_input_shape = input_shape2 @@ -57,6 +61,7 @@ def generate_node(self): node_factory = LayerNodeFactory(node_data, mapping_data=None) node_attrs = node_factory.create_node_attr() mapping = self.get_mapping_this_node() + input_names = list(self.node.input) return ComputationNode( node_id=self.node_id, @@ -64,4 +69,5 @@ node_attr=node_attrs, mapping_attr=mapping, op_type=self.node.op_type, + input_names=input_names, ) diff --git a/stream/parser/onnx/concat.py b/stream/parser/onnx/concat.py index 3c63643e..229db1b9 100644 --- a/stream/parser/onnx/concat.py +++ b/stream/parser/onnx/concat.py @@ -7,10 +7,21 @@ class ConcatParser(OnnxOperatorParser): """Parses an onnx gather operator into a ConcatNode.""" + def get_axis_value(self): + """Find the value of the axis associated with this concat node in ONNX""" + AXIS_ATTR = "axis" + + # `axis` is an attribute of the node + try: + axis_attr = next(filter(lambda x: x.name == AXIS_ATTR, self.node.attribute)) + return
axis_attr.i + except StopIteration: + raise ValueError("Axis attribute not found in ONNX node") + def generate_node(self): predecessors = self.get_node_predecessors() - axis = self.get_axis_value() + input_names = list(self.node.input) input_1, input_2 = self.node.input[0], self.node.input[1] @@ -36,13 +47,5 @@ def generate_node(self): axis=axis, constant_shape=constant_shape, variable_input_first=variable_input_first, + input_names=input_names, ) - - def get_axis_value(self): - """Find the value of the axis associated with this concat node in ONNX""" - # `axis` is an attribute of the node - try: - axis_attr = next(filter(lambda x: x.name == "axis", self.node.attribute)) - return axis_attr.i - except StopIteration: - raise ValueError("Axis attribute not found in ONNX node") diff --git a/stream/parser/onnx/conv.py b/stream/parser/onnx/conv.py index 939a8556..1181c176 100644 --- a/stream/parser/onnx/conv.py +++ b/stream/parser/onnx/conv.py @@ -19,118 +19,112 @@ class ConvParser(OnnxComputeOperatorParser): OP_TYPE = "conv" - def get_layer_node_user_format( # type: ignore + def get_layer_node_user_format( self, - kernel_shape: list[int], - strides: list[int], - dilations: list[int], - group_size: int, - padding: list[int], - ia_shape: list[int], - oa_shape: list[int], + input_shape: list[int], + output_shape: list[int], ) -> dict[str, Any]: """ Generate the necessary dictionary items required for the LayerNode creation. + + """ - # convert the data types to precisions based on the onnx definition + predecessors = self.get_node_predecessors() + + # Extract extra attributes + attrs = self.node.attribute + kernel_shape: list[int] = get_attribute_ints_with_name("kernel_shape", attrs, default=None) # type:ignore + strides: list[int] = get_attribute_ints_with_name("strides", attrs, default=[1, 1]) # type:ignore + dilations: list[int] = get_attribute_ints_with_name("dilations", attrs, default=[1, 1]) # type:ignore + group_size: int = get_attribute_ints_with_name("group", attrs, default=1) # type:ignore + padding: list[int] = get_attribute_ints_with_name("pads", attrs, default=[0, 0, 0, 0]) # type:ignore - # Equation data: dict[str, Any] = {} data["id"] = self.node_id - data["name"] = f"Layer{self.node_id}" + data["name"] = self.node.name data["operator_type"] = ConvParser.OP_TYPE - # IMPORTANT: If any of the input loops require padding, they should be defined as the rightmost dimensions in - # the equation. This is because we construct the dimensionality order and then add the padding to those last - # dimensions in the order - if group_size > 1: - data["equation"] = "O[b][g][k][oy][ox]+=W[g][c][fy][fx]*I[b][g][c][iy][ix]" - else: - data["equation"] = "O[b][g][k][oy][ox]+=W[k][c][fy][fx]*I[b][g][c][iy][ix]" + data["operand_precision"] = self.get_operand_precision_user_format() + data["operand_source"] = self.get_operand_source_user_format(predecessors) + + # 1D Conv case: append dimensions of size 1 so equation holds. Conv in FY dimension + is_1d_conv = len(kernel_shape) == 1 # Get dimension sizes from input parameters - assert ia_shape[0] == oa_shape[0], "Batch size is different for input and output activations." - B = oa_shape[0] + assert input_shape[0] == output_shape[0], "Batch size is different for input and output activations." 
+ B = output_shape[0] G = group_size - K = ceil(oa_shape[1] / G) - OX = oa_shape[3] - OY = oa_shape[2] - C = ceil(ia_shape[1] / G) - IX = ia_shape[3] - IY = ia_shape[2] + K = ceil(output_shape[1] / G) + C = ceil(input_shape[1] / G) FX = kernel_shape[0] - FY = kernel_shape[1] - data["loop_dims"] = ["B", "K", "G", "OX", "OY", "C", "FX", "FY"] - data["loop_sizes"] = [B, K, G, OX, OY, C, FX, FY] - - data["pr_loop_dims"] = ["IX", "IY"] - data["pr_loop_sizes"] = [IX, IY] - data["dimension_relations"] = [ - f"ix={strides[0]}*ox+{dilations[0]}*fx", - f"iy={strides[1]}*oy+{dilations[1]}*fy", - ] - data["operand_precision"] = {"O": 16, "O_final": 8, "W": 8, "I": 8} - - # Add information wrt how this conv node's input/output tensors - # are represented in the onnx model vs how they are represented in the equation above. - # Because onnx doesn't actually encode the group dimension in a separate dimension - # but instead keeps it as a "groups" parameter. - # Concretely, this entry contains for the I and O operand how the G + C/K should be converted - # to a single "CH" (channel) dimension. - - # Add padding information - data["padding"] = [ - [padding[0], padding[2]], - [padding[1], padding[3]], - ] - - # Find the previous layer(s) that should be this node's parent(s) - node_inputs = self.node.input - assert len(node_inputs) >= 2, f"Conv should have at least two input names, but has: {node_inputs}." - (first_input_name, second_input_name) = node_inputs[:2] - - source_list_I = [ - src for (src, src_output_names) in self.nodes_outputs.items() if first_input_name in src_output_names - ] - source_list_W = [ - src for (src, src_output_names) in self.nodes_outputs.items() if second_input_name in src_output_names - ] - assert len(source_list_I) <= 1 - assert len(source_list_W) <= 1 - - source_I = source_list_I[0] if len(source_list_I) == 1 else self.node_id - source_W = source_list_W[0] if len(source_list_W) == 1 else self.node_id - - data["operand_source"] = { - "I": source_I, - "W": source_W, - } + IX = input_shape[2] + OX = output_shape[2] + + weight_dim = "g" if group_size > 1 else "k" + + # IMPORTANT: If any of the input loops require padding, they should be defined as the rightmost dimensions in + # the equation. This is because we construct the dimensionality order and then add the padding to those last + # dimensions in the order. + # Add information wrt how this conv node's input/output tensors are represented in the onnx model vs how they + # are represented in the equation. Because onnx doesn't actually encode the group dimension in a separate + # dimension but instead keeps it as a "groups" parameter. Concretely, this entry contains for the I and O + # operand how the G + C/K should be converted to a single "CH" (channel) dimension. + + if is_1d_conv: + # No FY, OY, IY + loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX} + data["equation"] = f"O[b][g][k][ox]+=W[{weight_dim}][c][fx]*I[b][g][c][ix]" + data["pr_loop_dims"] = ["IX"] + data["pr_loop_sizes"] = [IX] + data["dimension_relations"] = [ + f"ix={strides[0]}*ox+{dilations[0]}*fx", + ] + data["padding"] = [ + [padding[0], padding[1]], + ] + else: + assert len(input_shape) == 4 and len(output_shape) == 4 and len(padding) == 4 and len(strides) == 2 + FY = kernel_shape[1] # TODO is kernel_shape in (FX, FY) format or (FY, FX)? 
(I assumed the former) + IY = input_shape[3] + OY = output_shape[3] + loop_size_dict = {"B": B, "K": K, "G": G, "OX": OX, "C": C, "FX": FX, "OY": OY, "FY": FY} + data["equation"] = f"O[b][g][k][oy][ox]+=W[{weight_dim}][c][fy][fx]*I[b][g][c][iy][ix]" + data["pr_loop_dims"] = ["IX", "IY"] + data["pr_loop_sizes"] = [IX, IY] + data["dimension_relations"] = [ + f"ix={strides[0]}*ox+{dilations[0]}*fx", + f"iy={strides[1]}*oy+{dilations[1]}*fy", + ] + data["padding"] = [ + [padding[0], padding[2]], + [padding[1], padding[3]], + ] + + # Remove C/K if they have size 1 + for dim in ["C", "K"]: + if loop_size_dict[dim] == 1: + del loop_size_dict[dim] + # Remove from equation + data["equation"] = data["equation"].replace(f"[{dim.lower()}]", "") + + data["loop_dims"] = list(loop_size_dict.keys()) + data["loop_sizes"] = list(loop_size_dict.values()) return data def generate_node(self): - attrs = self.node.attribute - kernel_shape: list[int] = get_attribute_ints_with_name("kernel_shape", attrs, default=None) # type:ignore - strides: list[int] = get_attribute_ints_with_name("strides", attrs, default=[1, 1]) # type:ignore - dilations: list[int] = get_attribute_ints_with_name("dilations", attrs, default=[1, 1]) # type:ignore - group_size: int = get_attribute_ints_with_name("group", attrs, default=1) # type:ignore - padding: list[int] = get_attribute_ints_with_name("pads", attrs, default=[0, 0, 0, 0]) # type:ignore # Get the input and output activation shapes - ia_dimension_shape, oa_dimension_shape = get_node_input_output_dimension_shapes(self.node, self.onnx_model) + input_shape, output_shape = get_node_input_output_dimension_shapes(self.node, self.onnx_model) node_data: dict[str, Any] = self.get_layer_node_user_format( - kernel_shape, - strides, - dilations, - group_size, - padding, - ia_dimension_shape, - oa_dimension_shape, + input_shape, + output_shape, ) node_factory = LayerNodeFactory(node_data, mapping_data=None) node_attrs = node_factory.create_node_attr() mapping = self.get_mapping_this_node() + input_names = list(self.node.input) return ComputationNode( node_id=self.node_id, @@ -139,4 +133,5 @@ def generate_node(self): mapping_attr=mapping, op_type=ConvParser.OP_TYPE, operand_tensor_reshape=None, + input_names=input_names, ) diff --git a/stream/parser/onnx/default.py b/stream/parser/onnx/default.py index 8bdd3f99..645fc88a 100644 --- a/stream/parser/onnx/default.py +++ b/stream/parser/onnx/default.py @@ -7,10 +7,12 @@ class DefaultNodeParser(OnnxOperatorParser): def generate_node(self): predecessors = self.get_node_predecessors() + input_names = list(self.node.input) return DummyNode( node_id=self.node_id, node_name=self.node.name, predecessors=predecessors, op_type=self.node.op_type.lower(), + input_names=input_names, ) diff --git a/stream/parser/onnx/einsum.py b/stream/parser/onnx/einsum.py new file mode 100644 index 00000000..8c79f124 --- /dev/null +++ b/stream/parser/onnx/einsum.py @@ -0,0 +1,96 @@ +import logging +import re +from typing import Any + +from stream.onnx_utils import get_onnx_input_shapes, get_onnx_output_shapes +from stream.parser.onnx.operator_parser import OnnxComputeOperatorParser + +logger = logging.getLogger(__name__) + + +class EinsumParser(OnnxComputeOperatorParser): + + def get_einsum_equation(self): + ATTR_NAME = "equation" + + attrs_names = [attr.name for attr in self.node.attribute] + name_idx = attrs_names.index(ATTR_NAME) + attr_proto = self.node.attribute[name_idx] + value = attr_proto.s.decode("utf-8") + return value + + def get_layer_dims_per_op(self): + 
einsum_equation = self.get_einsum_equation() + + return re.split(",|->", einsum_equation) + + def get_layer_equation(self, layer_dims_per_op: list[str]): + def put_in_brackets(s: str): + """e.g. `abc` -> `[a][b][c]""" + if s == "": + return "[]" + return "".join([f"[{char}]" for char in s]) + + match len(layer_dims_per_op): + case 2: + dims_I, dims_O = layer_dims_per_op + dims_W = "" + case 3: + dims_I, dims_W, dims_O = layer_dims_per_op + case _: + raise NotImplementedError + + equation = f"O{put_in_brackets(dims_O)}+=I{put_in_brackets(dims_I)}*W{put_in_brackets(dims_W)}" + return equation + + def get_layer_dim_sizes_dict(self, layer_dims_per_op: list[str]): + input_shapes = get_onnx_input_shapes(self.node, self.onnx_model) + output_shapes = get_onnx_output_shapes(self.node, self.onnx_model) + + if len(output_shapes) != 1: + raise ValueError("Einsum with more than one output not supported") + + shapes = input_shapes + output_shapes + + if len(layer_dims_per_op) != len(shapes): + raise ValueError("Einsum equation has more parts than node inputs") + + sizes_dict: dict[str, int] = {} + for layer_dims, sizes in zip(layer_dims_per_op, shapes): + if len(layer_dims) != len(sizes): + # TODO is the order of the equation guaranteed to be the same as the input order? + raise ValueError(f"Einsum equation part {layer_dims} and operand input shape {sizes} do not match") + for layer_dim, size in zip(layer_dims.upper(), sizes): + if layer_dim not in sizes_dict: + sizes_dict[layer_dim] = size + else: + if sizes_dict[layer_dim] != size: + raise ValueError(f"Not clear what the size of {layer_dim} is in Einsum") + + return sizes_dict + + def get_layer_node_user_format( + self, + input_shape: list[int], # Argument required because of a caller function in superclass + output_shape: list[int], # TODO put shape logic in this method for all `OnnxComputeOperatorParser` subclasses + ) -> dict[str, Any]: + """Generate layer data in user input format for Einsum.""" + predecessors = self.get_node_predecessors() + + data: dict[str, Any] = {} + data["id"] = self.node_id + data["name"] = self.node.name + data["operator_type"] = self.node.op_type + data["dimension_relations"] = [] + data["operand_source"] = self.get_operand_source_user_format(predecessors) + data["operand_precision"] = self.get_operand_precision_user_format() + + # + layer_dims_per_op = self.get_layer_dims_per_op() + sizes_dict = self.get_layer_dim_sizes_dict(layer_dims_per_op) + + data["loop_dims"] = list(sizes_dict.keys()) + data["loop_sizes"] = list(sizes_dict.values()) + data["equation"] = self.get_layer_equation(layer_dims_per_op) + + return data diff --git a/stream/parser/onnx/elementwise.py b/stream/parser/onnx/elementwise.py index d7b68a55..55e035d8 100644 --- a/stream/parser/onnx/elementwise.py +++ b/stream/parser/onnx/elementwise.py @@ -14,6 +14,8 @@ def __init__(self, node_id, node, nodes_outputs, mapping, onnx_model) -> None: self.name = node.name def generate_node(self): + input_names = list(self.node.input) + # Get the predecessors of this node predecessors = [] for node_input in self.node.input: @@ -28,5 +30,6 @@ def generate_node(self): node_id=self.node_id, node_name=self.name, predecessor=predecessors, + input_names=input_names, ) return node_obj diff --git a/stream/parser/onnx/flatten.py b/stream/parser/onnx/flatten.py index 215f6676..35e4f0c0 100644 --- a/stream/parser/onnx/flatten.py +++ b/stream/parser/onnx/flatten.py @@ -1,5 +1,3 @@ -from zigzag.parser.onnx.utils import get_attribute_ints_with_name - from 
stream.parser.onnx.operator_parser import OnnxOperatorParser from stream.workload.dependency_propagation.flatten_node import FlattenNode @@ -12,12 +10,13 @@ def generate_node(self): assert len(predecessors) == 1 predecessor = predecessors[0] - attrs = self.node.attribute - # Get the axis which indicates how to flatten the input tensor - axis: int | None = get_attribute_ints_with_name("axis", attrs, default=None) # type: ignore + input_names = list(self.node.input) + axis = self.get_axis_attribute() + return FlattenNode( node_id=self.node_id, node_name=self.node.name, predecessor=predecessor, axis=axis, + input_names=input_names, ) diff --git a/stream/parser/onnx/gather.py b/stream/parser/onnx/gather.py index e7a32cc9..b9c3fde2 100644 --- a/stream/parser/onnx/gather.py +++ b/stream/parser/onnx/gather.py @@ -9,8 +9,9 @@ class GatherParser(OnnxOperatorParser): def generate_node(self): predecessors = self.get_node_predecessors() - axis = self.get_axis_value() + axis = self.get_axis_attribute() indices = self.get_indices_value() + input_names = list(self.node.input) return GatherNode( node_id=self.node_id, @@ -18,6 +19,7 @@ def generate_node(self): predecessors=predecessors, gather_axis=axis, gather_indices=indices, + input_names=input_names, ) def get_indices_value(self): @@ -39,13 +41,3 @@ def get_indices_value(self): indices = DEFAULT return indices - - def get_axis_value(self): - """Find the value of the axis associated with this gather node in ONNX""" - # `axis` is an attribute of the node - try: - axis_attr = next(filter(lambda x: x.name == "axis", self.node.attribute)) - axis = axis_attr.i - except StopIteration: - axis = 0 - return axis diff --git a/stream/parser/onnx/lpnormalization.py b/stream/parser/onnx/lpnormalization.py index 0ca2569f..6f4ddc5b 100644 --- a/stream/parser/onnx/lpnormalization.py +++ b/stream/parser/onnx/lpnormalization.py @@ -11,6 +11,8 @@ def __init__(self, node_id, node, nodes_outputs, mapping, onnx_model) -> None: super().__init__(node_id, node, nodes_outputs, mapping, onnx_model) def generate_node(self): + input_names = list(self.node.input) + # Get the predecessors of this node # TODO use superclass' `get_node_predecessors` predecessors = [] @@ -23,5 +25,6 @@ def generate_node(self): node_id=self.node_id, node_name=self.node_name, predecessor=self.predecessor, + input_names=input_names, ) return node_obj diff --git a/stream/parser/onnx/model.py b/stream/parser/onnx/model.py index 3de76808..7109471a 100644 --- a/stream/parser/onnx/model.py +++ b/stream/parser/onnx/model.py @@ -5,22 +5,25 @@ from zigzag.parser.onnx.utils import parse_onnx_model_from_path from stream.hardware.architecture.accelerator import Accelerator -from stream.parser.onnx.asymmetric_simd import AsymmetricSimdParser from stream.parser.onnx.concat import ConcatParser from stream.parser.onnx.conv import ConvParser from stream.parser.onnx.default import DefaultNodeParser +from stream.parser.onnx.einsum import EinsumParser from stream.parser.onnx.flatten import FlattenParser from stream.parser.onnx.gather import GatherParser from stream.parser.onnx.gemm import GemmParser from stream.parser.onnx.lpnormalization import LpNormalizationParser from stream.parser.onnx.matmul import MatMulParser +from stream.parser.onnx.mul import MulParser from stream.parser.onnx.operator_parser import OnnxOperatorParser from stream.parser.onnx.pooling import PoolingParser +from stream.parser.onnx.reduce_1d import Reduce1DParser from stream.parser.onnx.reshape import ReshapeParser from stream.parser.onnx.simd import 
SimdParser +from stream.parser.onnx.slice import SliceParser from stream.parser.onnx.softmax import SoftmaxParser +from stream.parser.onnx.split import SplitParser from stream.parser.onnx.transpose import TransposeParser -from stream.utils import get_onnx_input_shapes, has_asymmetric_input_data from stream.workload.mapping import InterCoreMappingAttributes from stream.workload.onnx_workload import ONNXWorkload @@ -32,26 +35,37 @@ class ONNXModelParser: # Map the node's op_type to the corresponding Parser class OP_TYPE_TO_PARSER: dict[str, Type[OnnxOperatorParser]] = { + # General "QLinearConv": ConvParser, "Conv": ConvParser, "MatMul": MatMulParser, "Gemm": GemmParser, + "Einsum": EinsumParser, "MaxPool": PoolingParser, "AveragePool": PoolingParser, "GlobalMaxPool": PoolingParser, "GlobalAveragePool": PoolingParser, - "Add": SimdParser, - "Mul": SimdParser, + "Add": MulParser, + "Mul": MulParser, "Softmax": SoftmaxParser, + # Single-input element-wise + "ReduceMean": Reduce1DParser, "Relu": SimdParser, "Gelu": SimdParser, "Silu": SimdParser, + "Sqrt": SimdParser, + "Div": SimdParser, + "Pow": SimdParser, + "Reciprocal": SimdParser, # Div with 1 as numerator + # Dependency propagation "LpNormalization": LpNormalizationParser, "Gather": GatherParser, "Transpose": TransposeParser, "Reshape": ReshapeParser, "Flatten": FlattenParser, "Concat": ConcatParser, + "Split": SplitParser, + "Slice": SliceParser, } def __init__( @@ -71,15 +85,14 @@ def run(self): self.workload = self.parse_workload() def get_parser_class(self, node: NodeProto): - # A temporary fix an element-wise Add or Mul which has asymmetric input data -> treat it as a DummyNode. - # TODO support node with asymmetric input data. - if node.op_type in ["Add", "Mul"] and has_asymmetric_input_data(node, self.onnx_model): - in_shape_1, in_shape_2 = get_onnx_input_shapes(node, self.onnx_model) - # In case only the batch dimension is missing. Other cases are not supported for now - if abs(len(in_shape_1) - len(in_shape_2)) == 1: - return AsymmetricSimdParser - else: - return DefaultNodeParser + # # A temporary fix an element-wise Add which has asymmetric input data -> treat it as a DummyNode. + # if node.op_type in ["Add", "Mul"] and has_asymmetric_input_data(node, self.onnx_model): + # in_shape_1, in_shape_2 = get_onnx_input_shapes(node, self.onnx_model) + # # In case only the batch dimension is missing. Other cases are not supported for now + # if abs(len(in_shape_1) - len(in_shape_2)) == 1: + # return AsymmetricSimdParser + # else: + # return DefaultNodeParser parser_class = ONNXModelParser.OP_TYPE_TO_PARSER.get(node.op_type) if not parser_class: diff --git a/stream/parser/onnx/mul.py b/stream/parser/onnx/mul.py new file mode 100644 index 00000000..612a9406 --- /dev/null +++ b/stream/parser/onnx/mul.py @@ -0,0 +1,103 @@ +from typing import Any + +from stream.onnx_utils import get_onnx_input_shapes, get_onnx_output_shapes +from stream.parser.onnx.operator_parser import OnnxComputeOperatorParser + + +class MulParser(OnnxComputeOperatorParser): + """Parses an ONNX operator representing an elementwise operation (Mul) into a ComputationNode.""" + + def get_common_and_broadcast_shape(self): + """This node assumes that the ONNX node has 2 inputs and 1 output. One input shape is identical to the output + shape, and the other shape can broadcast in dimensions. 
+ Returns the common shape (in and out) and the broadcast shape""" + input_shapes = get_onnx_input_shapes(self.node, self.onnx_model) + output_shapes = get_onnx_output_shapes(self.node, self.onnx_model) + + if len(input_shapes) != 2 or len(output_shapes) != 1: + raise NotImplementedError + + output_shape = output_shapes.pop() + if not any(shape == output_shape for shape in input_shapes): + raise NotImplementedError + + input_shape = output_shape + input_shapes.remove(output_shape) + broadcast_shape = input_shapes.pop() + + # e.g. (3,5) * (8,3,5) is ok (broadcast over dim 0), but (3,2) * (8,3,5) is unclear + for broadcast_size, in_size in zip(reversed(broadcast_shape), reversed(input_shape)): + if broadcast_size != in_size and broadcast_size != 1: + raise ValueError + + return input_shape, broadcast_shape + + def get_operand_source_input_format(self, shape_of_w: list[int]): + """This method needs more care in this subclass, since the equation assumes that the input with 'broadcast' + shape is always at `W`""" + predecessors = self.get_node_predecessors() + match len(predecessors): + case 0: + # e.g. first node of graph + return {"W": self.node_id, "I": self.node_id} + case 1: + # One source operand, one constant + return {"W": self.node_id, "I": predecessors[0]} + case 2: + # Two source operands, none are constant + # Name of the input that corresponds to the W shape + broadcast_input = self.node.input[get_onnx_input_shapes(self.node, self.onnx_model).index(shape_of_w)] + try: + node_id_W = next( + node_id + for node_id, outputs in self.nodes_outputs.items() + if broadcast_input in outputs and node_id in predecessors + ) + node_id_I = ( + node_id_W + if predecessors[0] == predecessors[1] + else next(i for i in predecessors if i != node_id_W) + ) + return {"W": node_id_W, "I": node_id_I} + except StopIteration: + raise ValueError(f"Cannot find correct inputs of {self.node.name}") + case _: + raise ValueError("No more than 2 layer predecessors expected") + + def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[int]): + """ + Generate the necessary dictionary items required for the LayerNode creation.
+ """ + common_shape, broadcast_shape = self.get_common_and_broadcast_shape() + + data: dict[str, Any] = {} + data["id"] = self.node_id + data["name"] = self.node.name + data["operator_type"] = self.node.op_type + data["operand_source"] = self.get_operand_source_input_format(shape_of_w=broadcast_shape) + data["operand_precision"] = self.get_operand_precision_user_format() + data["dimension_relations"] = [] + data["loop_sizes"] = common_shape + + match len(common_shape): + case 1: + loop_dims = ["K"] + case 2: + loop_dims = ["D", "K"] + case 3: + loop_dims = ["B", "D", "K"] + case 4: + loop_dims = ["B", "H", "D", "K"] + case _: + raise NotImplementedError + + loop_dims_broadcast = reversed([dim for dim, _ in zip(reversed(loop_dims), reversed(broadcast_shape))]) + + equation_dims_common = "".join([f"[{dim.lower()}]" for dim in loop_dims]) + equation_dims_broadcast = "".join([f"[{dim.lower()}]" for dim in loop_dims_broadcast]) + equation = f"O{equation_dims_common}+=I{equation_dims_common}*W{equation_dims_broadcast}" + + data["loop_dims"] = loop_dims + data["equation"] = equation + + return data diff --git a/stream/parser/onnx/operator_parser.py b/stream/parser/onnx/operator_parser.py index 343b2665..e288d18d 100644 --- a/stream/parser/onnx/operator_parser.py +++ b/stream/parser/onnx/operator_parser.py @@ -7,6 +7,7 @@ from zigzag.parser.workload_factory import LayerNodeFactory from stream.hardware.architecture.accelerator import Accelerator +from stream.onnx_utils import get_axis_attribute from stream.workload.computation.computation_node import ComputationNode from stream.workload.mapping import InterCoreMappingAttributes from stream.workload.node import Node @@ -39,6 +40,9 @@ def generate_node(self) -> Node: ... def get_operand_source_input_format(self): predecessors = self.get_node_predecessors() match len(predecessors): + case 0: + # e.g. first node of graph + return {"W": self.node_id, "I": self.node_id} case 1: # One source operand, one constant return {"W": self.node_id, "I": predecessors[0]} @@ -49,6 +53,9 @@ def get_operand_source_input_format(self): case _: raise ValueError("No more than 2 layer predecessors expected") + def get_axis_attribute(self): + return get_axis_attribute(self.node) + class OnnxComputeOperatorParser(OnnxOperatorParser, metaclass=ABCMeta): @@ -58,12 +65,20 @@ def run(self) -> Generator[ComputationNode, None, None]: @abstractmethod def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[int]) -> dict[str, Any]: ... - def get_operand_precision_input_format(self) -> dict[str, int]: - act_precision = self.get_activation_precision() - weight_precision = self.get_weight_precision() - intermediate_output_precision = self.get_intermediate_output_precision() + def get_operand_precision_user_format(self) -> dict[str, int]: + act_precision: int = self.get_activation_precision() + weight_precision: int = self.get_weight_precision() + intermediate_output_precision: int = self.get_intermediate_output_precision() predecessors = self.get_node_predecessors() match len(predecessors): + case 0: + # e.g. 
the first node in the network -> assume only one variable input + return { + "W": weight_precision, + "I": act_precision, + "O_final": act_precision, + "O": intermediate_output_precision, + } case 1: # One source operand, one constant return { @@ -120,8 +135,8 @@ def generate_node(self): node_data = self.get_layer_node_user_format(input_shape, output_shape) node_factory = LayerNodeFactory(node_data, mapping_data=[]) node_attrs = node_factory.create_node_attr() - mapping = self.get_mapping_this_node() + input_names = list(self.node.input) return ComputationNode( node_id=self.node_id, @@ -129,4 +144,5 @@ op_type=self.node.op_type, node_attr=node_attrs, mapping_attr=mapping, + input_names=input_names, ) diff --git a/stream/parser/onnx/pooling.py b/stream/parser/onnx/pooling.py index ff120f9f..780efbec 100644 --- a/stream/parser/onnx/pooling.py +++ b/stream/parser/onnx/pooling.py @@ -117,10 +117,12 @@ def generate_node(self): node_factory = LayerNodeFactory(node_data, None) node_attrs = node_factory.create_node_attr() mapping = self.get_mapping_this_node() + input_names = list(self.node.input) return PoolingNode( node_id=self.node_id, node_name=self.node.name, node_attr=node_attrs, mapping_attr=mapping, + input_names=input_names, ) diff --git a/stream/parser/onnx/reduce_1d.py b/stream/parser/onnx/reduce_1d.py index be4ecc1e..26f8d7ff 100644 --- a/stream/parser/onnx/reduce_1d.py +++ b/stream/parser/onnx/reduce_1d.py @@ -8,33 +8,66 @@ class Reduce1DParser(OnnxComputeOperatorParser): e.g. sum over one row or max of a single row """ + def get_reduction_dim(self, input_shape: list[int], output_shape: list[int]): + """Returns the axis in which the dimension is reduced""" + + # The case that keepdim=True: the reduced dimension is kept with size 1 + if len(input_shape) == len(output_shape): + different_size = [a != b for a, b in zip(input_shape, output_shape)] + if sum(different_size) != 1: + raise ValueError(f"Input and output shapes {input_shape}, {output_shape} should only differ in one dim") + reduction_dim = different_size.index(True) + if output_shape[reduction_dim] != 1: + raise ValueError(f"The reduced dimension at axis {reduction_dim} in {output_shape} is larger than 1") + return reduction_dim + + # Other: assume that the reduction is at axis=-1 + if not all(a == b for a, b in zip(input_shape, output_shape)): + raise NotImplementedError("Reduce node with reduction axis other than -1 not implemented yet.") + reduction_dim = len(input_shape) - 1 # Last dimension + return reduction_dim + def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[int]): """ Generate the necessary dictionary items required for the LayerNode creation. """ - # TODO check the output shape as well?
- assert len(self.get_node_predecessors()) == 1 + if len(self.get_node_predecessors()) != 1: + raise NotImplementedError + + if self.get_reduction_dim(input_shape, output_shape) != len(input_shape) - 1: + raise NotImplementedError("Only reduction in axis=-1 is supported") + + # This is a ONNX node property but can be inferred from the shapes + keep_dim = len(input_shape) == len(output_shape) data: dict[str, Any] = {} data["id"] = self.node_id data["name"] = self.node.name data["operator_type"] = self.node.op_type data["operand_source"] = self.get_operand_source_input_format() - data["operand_precision"] = self.get_operand_precision_input_format() + data["operand_precision"] = self.get_operand_precision_user_format() data["dimension_relations"] = [] data["loop_sizes"] = input_shape + # C is always the reduction dim + # If keep_dim: add an arbitrary dim of size 1 + reduced_dim_output = "CR" # C reduced to 1 + eq_part_CR = f"[{reduced_dim_output}]" if keep_dim else "" match len(input_shape): case 2: - data["equation"] = "O[k]+=I[k][c]*W[]" + data["equation"] = f"O[k]{eq_part_CR}+=I[k][c]*W[]" data["loop_dims"] = ["K", "C"] case 3: - data["equation"] = "O[b][k]+=I[b][k][c]*W[]" + data["equation"] = f"O[b][k]{eq_part_CR}+=I[b][k][c]*W[]" data["loop_dims"] = ["B", "K", "C"] case 4: - data["equation"] = "O[b][h][k]+=I[b][h][k][c]*W[]" + data["equation"] = f"O[b][h][k]{eq_part_CR}+=I[b][h][k][c]*W[]" data["loop_dims"] = ["B", "H", "K", "C"] case _: raise NotImplementedError + if keep_dim: + data["loop_dims"] += [reduced_dim_output] + data["loop_sizes"] += [1] + return data diff --git a/stream/parser/onnx/reshape.py b/stream/parser/onnx/reshape.py index 325eb378..1ed9c193 100644 --- a/stream/parser/onnx/reshape.py +++ b/stream/parser/onnx/reshape.py @@ -14,10 +14,12 @@ def generate_node(self): # The operator shape is saved as the second input, so we need to get the input's dimension shape shape = tuple(get_node_input_output_dimension_shapes(self.node, self.onnx_model)[1]) + input_names = list(self.node.input) return ReshapeNode( node_id=self.node_id, node_name=self.node.name, predecessor=predecessor, shape=shape, + input_names=input_names, ) diff --git a/stream/parser/onnx/simd.py b/stream/parser/onnx/simd.py index 2c5ae21d..32d83b1e 100644 --- a/stream/parser/onnx/simd.py +++ b/stream/parser/onnx/simd.py @@ -6,6 +6,7 @@ class SimdParser(OnnxComputeOperatorParser): """Parses an ONNX operator representing an elementwise operation (simd) into a ComputationNode. e.g. Add, etc. 
+ # TODO this functionality is exactly the same as Mul but without support for broadcast (asymmetric) shapes """ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[int]): @@ -22,7 +23,7 @@ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[ data["name"] = self.node.name data["operator_type"] = self.node.op_type data["operand_source"] = self.get_operand_source_input_format() - data["operand_precision"] = self.get_operand_precision_input_format() + data["operand_precision"] = self.get_operand_precision_user_format() data["dimension_relations"] = [] data["loop_sizes"] = output_shape diff --git a/stream/parser/onnx/slice.py b/stream/parser/onnx/slice.py new file mode 100644 index 00000000..113b5d8e --- /dev/null +++ b/stream/parser/onnx/slice.py @@ -0,0 +1,33 @@ +from stream.onnx_utils import get_slice_attributes +from stream.parser.onnx.operator_parser import OnnxOperatorParser +from stream.workload.dependency_propagation.slice_node import SliceNode + + +class SliceParser(OnnxOperatorParser): + """Parses an onnx gather operator into a SliceNode.""" + + def generate_node(self): + if len(self.node.output) > 1: + raise NotImplementedError("Slice node with multiple output slices not yet supported.") + + # Single predecessor + predecessors = self.get_node_predecessors() + if len(predecessors) > 1: + raise ValueError("Slice node should not have more than one input") + predecessor = predecessors.pop() + + starts_value, ends_value, axes_value, steps_value = get_slice_attributes(self.node, self.onnx_model) + input_names = list(self.node.input) + output_names = list(self.node.output) + + return SliceNode( + node_id=self.node_id, + node_name=self.node.name, + predecessor=predecessor, + starts=starts_value, + ends=ends_value, + axes=axes_value, + steps=steps_value, + input_names=input_names, + output_names=output_names, + ) diff --git a/stream/parser/onnx/softmax.py b/stream/parser/onnx/softmax.py index 3f0a0506..25703b26 100644 --- a/stream/parser/onnx/softmax.py +++ b/stream/parser/onnx/softmax.py @@ -93,7 +93,7 @@ def get_layer_node_user_format(self, input_shape: list[int], output_shape: list[ data["name"] = self.node.name data["operator_type"] = self.node.op_type data["operand_source"] = self.get_operand_source_input_format() - data["operand_precision"] = self.get_operand_precision_input_format() + data["operand_precision"] = self.get_operand_precision_user_format() data["dimension_relations"] = [] data["loop_sizes"] = input_shape diff --git a/stream/parser/onnx/split.py b/stream/parser/onnx/split.py new file mode 100644 index 00000000..95d1967b --- /dev/null +++ b/stream/parser/onnx/split.py @@ -0,0 +1,32 @@ +from stream.onnx_utils import get_split_attribute +from stream.parser.onnx.operator_parser import OnnxOperatorParser +from stream.workload.dependency_propagation.split_node import SplitNode + + +class SplitParser(OnnxOperatorParser): + """Parses an onnx gather operator into a SplitNode.""" + + def generate_node(self): + # Single predecessor + predecessors = self.get_node_predecessors() + if len(predecessors) > 1: + raise ValueError("Split node should not have more than one input") + predecessor = predecessors.pop() + + axis = self.get_axis_attribute() + splits = get_split_attribute(self.node, self.onnx_model) + input_names = list(self.node.input) + output_names = list(self.node.output) + + if len(splits) != len(output_names): + raise ValueError + + return SplitNode( + node_id=self.node_id, + node_name=self.node.name, + 
predecessor=predecessor, + axis=axis, + splits=splits, + input_names=input_names, + output_names=output_names, + ) diff --git a/stream/parser/onnx/transpose.py b/stream/parser/onnx/transpose.py index ba6dae2a..0b3bcb7a 100644 --- a/stream/parser/onnx/transpose.py +++ b/stream/parser/onnx/transpose.py @@ -11,12 +11,14 @@ def generate_node(self): predecessor = predecessors.pop() permute_axes = self.get_permute_indices() + input_names = list(self.node.input) return TransposeNode( node_id=self.node_id, node_name=self.node.name, predecessor=predecessor, permute_axes=permute_axes, + input_names=input_names, ) def get_permute_indices(self): diff --git a/stream/stages/generation/tiled_workload_generation.py b/stream/stages/generation/tiled_workload_generation.py index df5b4953..5bb772f3 100644 --- a/stream/stages/generation/tiled_workload_generation.py +++ b/stream/stages/generation/tiled_workload_generation.py @@ -19,13 +19,7 @@ from stream.workload.computation.computation_node import ComputationNode, LoopRanges from stream.workload.dependency_propagation.concat_node import ConcatNode from stream.workload.dependency_propagation.dummy_node import DummyNode -from stream.workload.dependency_propagation.elementwise_node import ElementwiseNode -from stream.workload.dependency_propagation.flatten_node import FlattenNode -from stream.workload.dependency_propagation.gather_node import GatherNode -from stream.workload.dependency_propagation.lpnormalization_node import LpNormalizationNode -from stream.workload.dependency_propagation.reshape_node import ReshapeNode -from stream.workload.dependency_propagation.transpose_node import TransposeNode -from stream.workload.dnn_workload import DNNWorkloadStream +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node from stream.workload.onnx_workload import ComputationNodeWorkload, ONNXWorkload from stream.workload.tensor import Tensor @@ -128,7 +122,7 @@ def get_scheduling_order(workload: ComputationNodeWorkload): return sorted(((n.id, n.sub_id) for n in workload.node_list), reverse=True) @staticmethod - def get_all_node_pairs(G: DNNWorkloadStream) -> tuple[tuple[ComputationNode, ComputationNode, bool], ...]: + def get_all_node_pairs(G: ONNXWorkload) -> tuple[tuple[ComputationNode, ComputationNode, bool], ...]: pairs: list[tuple[ComputationNode, ComputationNode, bool]] = [] for node in G.topological_sort(): if not isinstance(node, ComputationNode): @@ -382,6 +376,7 @@ def get_bounding_box_dimensions( # where the onnx tensors are always flattened back to 4D (merging the G+C or G+K into one channel dimension) dimensions, loop_ranges = self.flatten_grouped_convolution_ranges(producer, consumer, dimensions, loop_ranges) bounding_box = [loop_ranges[dim] for dim in dimensions] + # TODO can bounding box have size 1? 
Will probably crash if so if not interleaved: bounding_box_flat = tuple([item for sublist in bounding_box for item in sublist]) @@ -401,6 +396,12 @@ def bounding_box_generator( inclusive_ranges = self.convert_to_inclusive_data_range(node.loop_ranges) dimensions = node.operand_dimensionality_order[operand] bounds = self.get_bounding_box_dimensions(producer, consumer, dimensions, inclusive_ranges) + + # TODO this is a whacky fix + # RTree doesn't accept bound of one dimension + if len(bounds) == 2: + bounds = (0, 0) + bounds + yield (i, bounds, None) def get_nb_input_dimensions(self, node: ComputationNode, operand: LayerOperand): @@ -422,7 +423,7 @@ def build_rtree( """ props = index.Property() # We assume all nodes in 'nodes' have identical dimensions - props.dimension = self.get_nb_input_dimensions(nodes[0], operand) + props.dimension = max(self.get_nb_input_dimensions(nodes[0], operand), 2) rtree = index.Index(self.bounding_box_generator(producer, consumer, nodes, operand), properties=props) return rtree @@ -585,6 +586,7 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand) assert ( len(paths_between) > 0 ), "No paths between producer and consumer found without ComputationNode in intermediates." + for path_between in paths_between: # First node in the path is a ComputationNode, of which we extract the output operand dependency tensor first_node = path_between[0] @@ -592,10 +594,10 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand) tensor = get_tensor_cn_for_op(first_node, dependent_operand=Constants.OUTPUT_LAYER_OP) # Propagate through intermediate, non-computation nodes - for _, node in enumerate(path_between[1:-1], start=1): - if isinstance(node, ComputationNode): - raise ValueError("Intermediate nodes should not be of type ComputationNode.") - tensor = self.propagate_cn_production_for_non_cn(node, tensor) + for i, node in enumerate(path_between[1:-1], start=1): + assert isinstance(node, PropagationNode), "Intermediate nodes should not be of type ComputationNode" + next_node = path_between[i + 1] + tensor = node.propagate(tensor, next_node) # Final node: Computation node final_node: ComputationNode = path_between[-1] # type: ignore @@ -607,7 +609,7 @@ def get_tensor_cn_for_op(node: ComputationNode, dependent_operand: LayerOperand) ) # Error handling of shape mismatches in tensor propagation - def get_final_tensor_alt_operand(): + def _get_final_tensor_alt_operand(): """Error handling case 1: sources for `W` and `I` operand are swapped for this node -> try the other one""" try: @@ -617,7 +619,7 @@ def get_final_tensor_alt_operand(): raise TensorDimensionMismatchException return get_tensor_cn_for_op(final_node, alt_operand) - def get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeTensor): + def _get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeTensor): """Error handling case 2: dimensions of ComputationNode (`final_tensor`) were altered by stream (e.g. to be properly divisible) but this is not reflected in `ConcatNode` with constant shape. 
-> manually fix shape""" @@ -644,17 +646,17 @@ def get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeT inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor) except TensorDimensionMismatchException: try: # Error case 1 - final_tensor = get_final_tensor_alt_operand() + final_tensor = _get_final_tensor_alt_operand() inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor) except TensorDimensionMismatchException: try: # Error case 2 final_tensor = get_tensor_cn_for_op(final_node, dependent_operand) - tensor = get_shape_inferred_propagated_tensor(tensor, final_tensor) + tensor = _get_shape_inferred_propagated_tensor(tensor, final_tensor) inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor) except TensorDimensionMismatchException: # Error case 1 and 2 combined - final_tensor = get_final_tensor_alt_operand() - tensor = get_shape_inferred_propagated_tensor(tensor, final_tensor) + final_tensor = _get_final_tensor_alt_operand() + tensor = _get_shape_inferred_propagated_tensor(tensor, final_tensor) inter_edges = self.get_inter_edges_tensor_based(tensor, final_tensor) for producer, cons in inter_edges: @@ -670,27 +672,6 @@ def get_shape_inferred_propagated_tensor(tensor: NodeTensor, final_tensor: NodeT ) return all_inter_edges - def propagate_cn_production_for_non_cn(self, node: Node, input_tensor: NodeTensor) -> NodeTensor: - match node: - case ReshapeNode(): - return node.reshape_operand_tensor(input_tensor) - case TransposeNode(): - return node.transpose(input_tensor) - case LpNormalizationNode(): - return node.lpnormalization_operand_tensor(input_tensor) - case FlattenNode(): - return node.flatten(input_tensor) - case ElementwiseNode(): - return input_tensor.copy() - case GatherNode(): - return node.gather_operand_tensor(input_tensor) - case ConcatNode(): - return node.concat(input_tensor) - case DummyNode(): - return input_tensor - case _: - raise NotImplementedError(f"Tensor propagation not implemented for node {node.name}.") - @staticmethod def get_inter_edges_tensor_based(producer_output_tensor: NodeTensor, consumer_input_tensor: NodeTensor): """This method obtains the edges between a producer and consumer. diff --git a/stream/stages/generation/tiling_generation.py b/stream/stages/generation/tiling_generation.py index 829c8b13..70b3191d 100644 --- a/stream/stages/generation/tiling_generation.py +++ b/stream/stages/generation/tiling_generation.py @@ -1,4 +1,5 @@ import logging +from collections import defaultdict from typing import Any import numpy as np @@ -16,6 +17,31 @@ class TilingGenerationStage(Stage): + # Split the node in this dimension to enable fusion within core + FUSION_PARTITION_DIM_DEFAULT: defaultdict[str, LayerDim] = defaultdict( + lambda: LayerDim("K"), + { + "conv": LayerDim("OY"), + "matmul": LayerDim("D"), + "gemm": LayerDim("D"), + "pooling": LayerDim("OY"), + "add": LayerDim("D"), + "mul": LayerDim("D"), + "softmax": LayerDim("K"), + "max": LayerDim("K"), + "div": LayerDim("K"), + "exp": LayerDim("K"), + "sum": LayerDim("K"), + "relu": LayerDim("K"), + "gelu": LayerDim("K"), + "silu": LayerDim("K"), + }, + ) + FUSION_PARTITION_SIZE_DEFAULT = 2 + + # Split node in this dimension to partition layer over cores. 
NOTE this list is ordered + INTER_CORE_PARTITION_DIM_DEFAULT = [LayerDim("G"), LayerDim("H"), LayerDim("K")] + def __init__( self, list_of_callables: list[StageCallable], @@ -109,11 +135,22 @@ def remove_invalid_entries_from_intra_core_tiling(self, node: ComputationNode): node.intra_core_tiling = valid_tiling - def generate_intra_core_tiling(self, node: ComputationNode) -> TILING_T: - partition_dim = node.fusion_partition_dims[0] + def get_fusion_partition_dim(self, node: ComputationNode) -> LayerDim: + partition_dim = TilingGenerationStage.FUSION_PARTITION_DIM_DEFAULT[node.type] + + # Default partition dim is not present in this node -> take some arbitrary other dim if partition_dim not in node.layer_dim_sizes: - raise ValueError(f"Suggested partition dimension {partition_dim} for {node} is not part of this node") - return [(node.fusion_partition_dims[0], node.layer_dim_sizes[partition_dim])] + partition_dim: LayerDim = next( + dim for dim in node.layer_dim_sizes if dim != LayerDim("B") and dim != LayerDim("G") + ) + + return partition_dim + + def generate_intra_core_tiling(self, node: ComputationNode) -> TILING_T: + partition_dim = self.get_fusion_partition_dim(node) + size = min(TilingGenerationStage.FUSION_PARTITION_SIZE_DEFAULT, node.layer_dim_sizes[partition_dim]) + tiling = [(partition_dim, size)] + return tiling def remove_invalid_entries_from_inter_core_tiling(self, node: ComputationNode): """Check wether this node's inter core tiling has invalid entries: non-existent layer dimension for this node @@ -143,14 +180,12 @@ def remove_invalid_entries_from_inter_core_tiling(self, node: ComputationNode): node.inter_core_tiling = valid_tiling def generate_inter_core_tiling(self, node: ComputationNode) -> TILING_T: - if node.layer_dim_sizes.data.get(LayerDim("G"), 1) > 1: - loop_dim = LayerDim("G") - elif node.layer_dim_sizes.data.get(LayerDim("K"), 1) > 1: - loop_dim = LayerDim("K") - else: - raise ValueError("Unknown what loop dim to split across cores") - - return [(loop_dim, "*")] + for dim in TilingGenerationStage.INTER_CORE_PARTITION_DIM_DEFAULT: + if dim in node.layer_dim_sizes and node.layer_dim_sizes[dim] > 1: + return [(dim, "*")] + + # No valid dim found -> just take someting + return [(next(iter(node.layer_dim_sizes)), "*")] @staticmethod def split_operator(model: ModelProto, node_name: str, num_splits: int): diff --git a/stream/utils.py b/stream/utils.py index 67f93f7a..06328b57 100644 --- a/stream/utils.py +++ b/stream/utils.py @@ -4,11 +4,9 @@ from typing import TYPE_CHECKING, Any, TypeAlias from numpy.typing import NDArray -from onnx import ModelProto, NodeProto from zigzag.cost_model.cost_model import CostModelEvaluation from zigzag.datatypes import MemoryOperand from zigzag.mapping.data_movement import FourWayDataMoving -from zigzag.parser.onnx.utils import get_onnx_tensor_type from stream.hardware.architecture.core import Core from stream.workload.mapping import TILING_T @@ -21,25 +19,6 @@ ARRAY_T: TypeAlias = NDArray[Any] -def get_onnx_input_shapes(node: NodeProto, onnx_model: ModelProto) -> tuple[list[int], list[int]]: - if len(node.input) != 2: - raise ValueError(f"Node {node.name} does not have two inputs") - input_name1 = node.input[0] - input_name2 = node.input[1] - input_shape1 = get_onnx_tensor_type(input_name1, onnx_model).shape - input_shape2 = get_onnx_tensor_type(input_name2, onnx_model).shape - return input_shape1, input_shape2 - - -def has_asymmetric_input_data(node: NodeProto, onnx_model: ModelProto): - """Return true iff the node has two inputs and the 
input nodes have a different shape""" - if len(node.input) != 2: - return False - - input_shape1, input_shape2 = get_onnx_input_shapes(node, onnx_model) - return input_shape1 != input_shape2 - - def get_too_large_operands(cme: CostModelEvaluation, accelerator: "Accelerator", core_id: int) -> list[MemoryOperand]: """Create a list of memory operands for which an extra memory level (i.e. offchip) was added. diff --git a/stream/workload/computation/computation_node.py b/stream/workload/computation/computation_node.py index cad237c5..2cdc51af 100644 --- a/stream/workload/computation/computation_node.py +++ b/stream/workload/computation/computation_node.py @@ -32,24 +32,6 @@ class ComputationNode(LayerNode, Node): too_large_operands: list[MemoryOperand] - # Map the node's op_type to the corresponding layer dimension to split on for fusion - FUSION_DIM_MAPPING: dict[str, list[LayerDim]] = { - "conv": [LayerDim("OY")], - "matmul": [LayerDim("D")], - "gemm": [LayerDim("D")], - "pooling": [LayerDim("OY")], - "add": [LayerDim("D")], - "mul": [LayerDim("D")], - "softmax": [LayerDim("K")], - "max": [LayerDim("K")], - "div": [LayerDim("K")], - "exp": [LayerDim("K")], - "sum": [LayerDim("K")], - "relu": [LayerDim("K")], - "gelu": [LayerDim("K")], - "silu": [LayerDim("K")], - } # TODO default to "K" ? - def __init__( self, node_id: int, @@ -61,6 +43,7 @@ def __init__( produces_final_output: bool = False, group_id: int = 0, sub_id: int = -1, # To distinguish alternative versions of this node + input_names: list[str] = [], ): op_type = op_type.lower() @@ -76,6 +59,7 @@ def __init__( offchip_energy=0, runtime=0, possible_core_allocation=mapping_attr.core_allocation, + input_names=input_names, ) # Overwrite default spatial mapping with given one @@ -111,11 +95,6 @@ def __init__( self.nb_real_predecessors = None self._static_hash_value = self.__compute_static_hash() - try: - self.fusion_partition_dims = ComputationNode.FUSION_DIM_MAPPING[op_type] - except KeyError: - raise NotImplementedError(f"Fusion partitioning dimensions not defined for {op_type}") - # Each ComputationNode will save a tensor for all its defined operands. # For example, a conv layer will have an I tensor, W tensor and O tensor. 
self.operand_tensors: dict[LayerOperand, Tensor] = {} diff --git a/stream/workload/computation/pooling_node.py b/stream/workload/computation/pooling_node.py index 0c4151c5..74a27169 100644 --- a/stream/workload/computation/pooling_node.py +++ b/stream/workload/computation/pooling_node.py @@ -5,12 +5,15 @@ class PoolingNode(ComputationNode): + """TODO this node can be replaced by instantiating ComputationNode directly""" + def __init__( self, node_id: int, node_name: str, node_attr: LayerNodeAttributes, mapping_attr: InterCoreMappingAttributes, + input_names: list[str] = [], ): super().__init__( node_id=node_id, @@ -18,4 +21,5 @@ def __init__( node_attr=node_attr, mapping_attr=mapping_attr, op_type="pooling", + input_names=input_names, ) diff --git a/stream/workload/dependency_propagation/concat_node.py b/stream/workload/dependency_propagation/concat_node.py index 113aba48..acd956a5 100644 --- a/stream/workload/dependency_propagation/concat_node.py +++ b/stream/workload/dependency_propagation/concat_node.py @@ -1,11 +1,11 @@ from zigzag.datatypes import LayerOperand -from zigzag.workload.layer_node_abc import LayerNodeABC from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class ConcatNode(Node, LayerNodeABC): +class ConcatNode(PropagationNode): """Class that represents an onnx Concat node with one constant input.""" def __init__( @@ -16,6 +16,7 @@ def __init__( axis: int, constant_shape: tuple[int, ...], variable_input_first: bool, + input_names: list[str] = [], ) -> None: """Initialize the ConcatNode @@ -26,17 +27,8 @@ def __init__( variable_input_first: Wether the result is `concat(input, constant_tensor)` or `concat(constant_tensor, input)` """ - Node.__init__( - self, - node_id=node_id, - node_name=node_name, - type="gather", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) - LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) + op_type = "concat" + super().__init__(node_id, node_name, op_type, input_names) self.axis = axis self.constant_shape = constant_shape @@ -53,7 +45,7 @@ def __init__( case _: raise ValueError("More than two inputs for ConcatNode") - def concat(self, tensor: NodeTensor) -> NodeTensor: + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: """Perform gather operation on the tensor.""" return tensor.concat_with_empty( shape=self.constant_shape, axis=self.axis, variable_input_first=self.variable_input_first diff --git a/stream/workload/dependency_propagation/dummy_node.py b/stream/workload/dependency_propagation/dummy_node.py index e24dc0bd..9e26f04e 100644 --- a/stream/workload/dependency_propagation/dummy_node.py +++ b/stream/workload/dependency_propagation/dummy_node.py @@ -1,9 +1,11 @@ from zigzag.workload.dummy_node import DummyNode as DummyNodeZigZag +from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class DummyNode(DummyNodeZigZag, Node): +class DummyNode(DummyNodeZigZag, PropagationNode): """DummyNode of an onnx operator that is not import for finer graph generation or for cost estimation, but plays a role because of the passing of the input and output tensors. 
""" @@ -14,7 +16,9 @@ def __init__( node_name: str, predecessors: list[int], op_type: str = "dummy", + input_names: list[str] = [], ) -> None: + PropagationNode.__init__(self, node_id, node_name, op_type, input_names) DummyNodeZigZag.__init__( self, node_id=node_id, @@ -22,13 +26,6 @@ def __init__( node_type=op_type, node_name=node_name, ) - Node.__init__( - self, - node_id=node_id, - node_name=node_name, - type=op_type, - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) + + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: + return tensor diff --git a/stream/workload/dependency_propagation/elementwise_node.py b/stream/workload/dependency_propagation/elementwise_node.py index fbb507b2..47d2fa66 100644 --- a/stream/workload/dependency_propagation/elementwise_node.py +++ b/stream/workload/dependency_propagation/elementwise_node.py @@ -1,25 +1,21 @@ from zigzag.datatypes import LayerOperand +from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class ElementwiseNode(Node): +class ElementwiseNode(PropagationNode): def __init__( self, node_id: int, node_name: str, predecessor: int, + input_names: list[str], ) -> None: - super().__init__( - node_id=node_id, - node_name=node_name, - type="elementwise", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) + op_type = "elementwise" + super().__init__(node_id, node_name, op_type, input_names) self.input_operand_source = {LayerOperand("I"): predecessor} def join(self, tensor1, tensor2): @@ -30,3 +26,6 @@ def join(self, tensor1, tensor2): tensor2 (np.ndarray): The second input tensor """ return tensor1 | tensor2 + + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: + return tensor diff --git a/stream/workload/dependency_propagation/flatten_node.py b/stream/workload/dependency_propagation/flatten_node.py index dbe48577..cd82be9e 100644 --- a/stream/workload/dependency_propagation/flatten_node.py +++ b/stream/workload/dependency_propagation/flatten_node.py @@ -1,12 +1,12 @@ import numpy as np from zigzag.datatypes import LayerOperand -from zigzag.workload.layer_node_abc import LayerNodeABC from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class FlattenNode(Node, LayerNodeABC): +class FlattenNode(PropagationNode): """Class that represents an onnx Flatten node.""" def __init__( @@ -15,32 +15,23 @@ def __init__( node_name: str, predecessor: int | None, axis: int | None, + input_names: list[str], ) -> None: """Initialize the FlattenNode Args: - shape (list): The output tensor's shape. + shape: The output tensor's shape. 
""" - super().__init__( - node_id=node_id, - node_name=node_name, - type="flatten", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) + op_type = "flatten" + super().__init__(node_id, node_name, op_type, input_names) + self.axis = axis if predecessor is not None: self.input_operand_source = {LayerOperand("I"): predecessor} - def flatten(self, input_tensor: NodeTensor) -> NodeTensor: - """Reshape an input tensor - - Args: - input_tensor (np.ndarray): The input tensor - """ - shape = input_tensor.tensor_shape + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: + """Reshape an input tensor""" + shape = tensor.tensor_shape # taken from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-51 new_shape = (1, -1) if self.axis == 0 else (np.prod(shape[0 : self.axis]).astype(int), -1) - return input_tensor.reshape(new_shape) + return tensor.reshape(new_shape) diff --git a/stream/workload/dependency_propagation/gather_node.py b/stream/workload/dependency_propagation/gather_node.py index 1967ac06..6d584072 100644 --- a/stream/workload/dependency_propagation/gather_node.py +++ b/stream/workload/dependency_propagation/gather_node.py @@ -1,11 +1,11 @@ from zigzag.datatypes import LayerOperand -from zigzag.workload.layer_node_abc import LayerNodeABC from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class GatherNode(Node, LayerNodeABC): +class GatherNode(PropagationNode): """Class that represents an onnx Reshape node.""" def __init__( @@ -15,6 +15,7 @@ def __init__( predecessors: list[int], gather_axis: int, gather_indices: int | list[int], + input_names: list[str] = [], ) -> None: """Initialize the GatherNode @@ -23,17 +24,8 @@ def __init__( gather_axis: Which axis to gather on. gather_indices: Indices of elements to be gathered. 
""" - Node.__init__( - self, - node_id=node_id, - node_name=node_name, - type="gather", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) - LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) + op_type = "gather" + super().__init__(node_id, node_name, op_type, input_names) self.gather_axis = gather_axis self.gather_indices = gather_indices @@ -48,6 +40,6 @@ def __init__( case _: raise ValueError("More than two inputs for GatherNode") - def gather_operand_tensor(self, tensor: NodeTensor) -> NodeTensor: + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: """Perform gather operation on the tensor.""" return tensor.gather(self.gather_indices, axis=self.gather_axis) diff --git a/stream/workload/dependency_propagation/propagation_node.py b/stream/workload/dependency_propagation/propagation_node.py new file mode 100644 index 00000000..401a3242 --- /dev/null +++ b/stream/workload/dependency_propagation/propagation_node.py @@ -0,0 +1,28 @@ +from abc import abstractmethod + +from zigzag.workload.layer_node_abc import LayerNodeABC + +from stream.node_tensor import NodeTensor +from stream.workload.node import Node + + +class PropagationNode(Node, LayerNodeABC): + """Stream node that does not perform computations and is not mapped on hardware, but propagates dependencies + between nodes""" + + def __init__(self, node_id: int, node_name: str, op_type: str, input_names: list[str]): + Node.__init__( + self, + node_id=node_id, + node_name=node_name, + type=op_type, + onchip_energy=0, + offchip_energy=0, + runtime=0, + possible_core_allocation=[-1], + input_names=input_names, + ) + LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) + + @abstractmethod + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: ... diff --git a/stream/workload/dependency_propagation/reshape_node.py b/stream/workload/dependency_propagation/reshape_node.py index c1223240..33ef3537 100644 --- a/stream/workload/dependency_propagation/reshape_node.py +++ b/stream/workload/dependency_propagation/reshape_node.py @@ -1,11 +1,11 @@ +from yaml import Node from zigzag.datatypes import Constants -from zigzag.workload.layer_node_abc import LayerNodeABC from stream.node_tensor import NodeTensor -from stream.workload.node import Node +from stream.workload.dependency_propagation.propagation_node import PropagationNode -class ReshapeNode(Node, LayerNodeABC): +class ReshapeNode(PropagationNode): """Class that represents an onnx Reshape node.""" def __init__( @@ -15,6 +15,7 @@ def __init__( predecessor: int, shape: tuple[int, ...], allow_zero: bool = False, + input_names: list[str] = [], ) -> None: """Initialize the ReshapeNode @@ -23,23 +24,14 @@ def __init__( shape: The output tensor's shape. allow_zero: wether the output shape can be 0 at some dimensions. 
Iff True, shape `[2,0,3]` becomes `[2,3]` """ - Node.__init__( - self, - node_id=node_id, - node_name=node_name, - type="reshape", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) - LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) + op_type = "reshape" + super().__init__(node_id, node_name, op_type, input_names) self.allow_zero = allow_zero self.shape = shape self.input_operand_source = {Constants.LAYER_OP_I: predecessor} - def reshape_operand_tensor(self, tensor: NodeTensor): + def propagate(self, tensor: NodeTensor, next_node: Node) -> NodeTensor: """Reshape the tensor back to the representation needed for producer/consumer.""" new_shape = self.shape if not new_shape: diff --git a/stream/workload/dependency_propagation/slice_node.py b/stream/workload/dependency_propagation/slice_node.py new file mode 100644 index 00000000..49de89d5 --- /dev/null +++ b/stream/workload/dependency_propagation/slice_node.py @@ -0,0 +1,47 @@ +from zigzag.datatypes import Constants + +from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode +from stream.workload.node import Node + + +class SliceNode(PropagationNode): + """Class that represents an onnx Slice node.""" + + def __init__( + self, + node_id: int, + node_name: str, + predecessor: int, + starts: list[int], + ends: list[int], + axes: list[int], + steps: list[int], + output_names: list[str], + input_names: list[str] = [], + ) -> None: + """Initialize the SliceNode + Slice the tensor along the axes in `axes`, from `starts` to `ends` with step sizes `steps`. + + Args: + predecessor: The id of this node's parent. + starts: start indices of the slice, one per axis in `axes` + ends: end indices of the slice, one per axis in `axes` + axes: the axes along which to slice + steps: step sizes of the slice, one per axis in `axes` + output_names: the node names that correspond to the slice outputs + """ + op_type = "slice" + super().__init__(node_id, node_name, op_type, input_names) + + self.starts = starts + self.ends = ends + self.axes = axes + self.steps = steps + self.input_operand_source = {Constants.LAYER_OP_I: predecessor} + self.output_names = output_names + + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: + """Slice the tensor. + Currently assumes only one slice is created.""" + return tensor.slice(starts=self.starts[0], ends=self.ends[0], axis=self.axes[0], steps=self.steps[0]) diff --git a/stream/workload/dependency_propagation/split_node.py b/stream/workload/dependency_propagation/split_node.py new file mode 100644 index 00000000..631be9f0 --- /dev/null +++ b/stream/workload/dependency_propagation/split_node.py @@ -0,0 +1,56 @@ +import numpy as np +from zigzag.datatypes import Constants + +from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode +from stream.workload.node import Node + + +class SplitNode(PropagationNode): + """Class that represents an onnx Split node.""" + + def __init__( + self, + node_id: int, + node_name: str, + predecessor: int, + axis: int, + splits: list[int], + output_names: list[str], + input_names: list[str] = [], + ) -> None: + """Initialize the SplitNode + Split the tensor at axis `axis`. The sizes are given by `splits`. `len(splits)` is the number of output nodes. + + Args: + predecessor: The id of this node's parent.
+ axis: axis in which to split + splits: sizes of the output splits in the given axis + output_names: the node names that correspond to the splits + """ + assert len(splits) == len(output_names) + op_type = "split" + super().__init__(node_id, node_name, op_type, input_names) + + self.axis = axis + self.splits = splits + self.input_operand_source = {Constants.LAYER_OP_I: predecessor} + self.output_names = output_names + + def propagate(self, tensor: NodeTensor, next_node: Node) -> NodeTensor: + """Split the tensor back to the representation needed for producer/consumer.""" + + # Numpy requires the indices where to split instead of the sizes of the resulting splits + split_indices = list(np.cumsum(self.splits)[:-1]) + output_tensors = tensor.split(split_indices, axis=self.axis) + + # Find which split part corresponds to the input of the next node + try: + index = next(i for i, output_name in enumerate(self.output_names) if output_name in next_node.input_names) + except StopIteration: + raise ValueError( + f"Cannot find this node's ({self.name}) outputs {self.output_names} in the next node's inputs {next_node.input_names}" + ) + + output_tensor = output_tensors[index] + return output_tensor diff --git a/stream/workload/dependency_propagation/transpose_node.py b/stream/workload/dependency_propagation/transpose_node.py index e4fb1223..d2d2fb23 100644 --- a/stream/workload/dependency_propagation/transpose_node.py +++ b/stream/workload/dependency_propagation/transpose_node.py @@ -1,11 +1,11 @@ from zigzag.datatypes import LayerOperand -from zigzag.workload.layer_node_abc import LayerNodeABC from stream.node_tensor import NodeTensor +from stream.workload.dependency_propagation.propagation_node import PropagationNode from stream.workload.node import Node -class TransposeNode(Node, LayerNodeABC): +class TransposeNode(PropagationNode): """Class that represents an onnx Transpose node.""" def __init__( @@ -14,26 +14,14 @@ def __init__( node_name: str, predecessor: int, permute_axes: list[int] | None = None, + input_names: list[str] = [], ) -> None: - Node.__init__( - self, - node_id=node_id, - node_name=node_name, - type="reshape", - onchip_energy=0, - offchip_energy=0, - runtime=0, - possible_core_allocation=[-1], - ) - LayerNodeABC.__init__(self, node_id=node_id, node_name=node_name) + op_type = "transpose" + super().__init__(node_id, node_name, op_type, input_names) self.permute_axes = permute_axes self.input_operand_source = {LayerOperand("I"): predecessor} - def transpose(self, input_tensor: NodeTensor) -> NodeTensor: - """Transpose an input tensor. - - Args: - input_tensor (np.ndarray): The input tensor - """ - return input_tensor.transpose(axes=self.permute_axes) + def propagate(self, tensor: NodeTensor, next_node: Node | None = None) -> NodeTensor: + """Transpose an input tensor.""" + return tensor.transpose(axes=self.permute_axes) diff --git a/stream/workload/node.py b/stream/workload/node.py index 06f216ac..c720ef88 100644 --- a/stream/workload/node.py +++ b/stream/workload/node.py @@ -20,17 +20,19 @@ def __init__( possible_core_allocation: list[int], core_allocation_is_fixed: bool = False, chosen_core_allocation: int | None = None, + input_names: list[str] = [], ) -> None: """Initialize the Node metaclass Args: - type (str): The type of Node. - energy (float): The energy consumption of this Node. - runtime (int): The runtime of this Node. - possible_core_allocation (int): The core id on which this Node can be mapped.
- inputs: (List[str]): The names of the input tensors of this node - outputs: (List[str]): The names of the output tensors of this node. + type: The type of Node. + energy: The energy consumption of this Node. + runtime: The runtime of this Node. + possible_core_allocation: The core ids on which this Node can be mapped. + inputs: The names of the input tensors of this node. + outputs: The names of the output tensors of this node. + chosen_core_allocation: The final core allocation of this node. + input_names: The names of the ONNX inputs of this node. """ super().__init__(node_id, node_name) @@ -41,6 +43,7 @@ def __init__( self.possible_core_allocation = possible_core_allocation self.core_allocation_is_fixed = core_allocation_is_fixed self.chosen_core_allocation = chosen_core_allocation + self.input_names = input_names # will be set by the scheduler self.start = None # will be set by the scheduler