diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py index d0c49b7af..3876a0911 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/main.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py @@ -70,9 +70,9 @@ def execute( memory_layout_analysis_enabled = False memory_layout_analysis_policy = None - ttnn_ir = self.model_runner.run( + perf_data = self.model_runner.run( model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy ) # TODO(odjuricic, #933) Parse TTNN IR and return the post optimized graph. - return {"graphs": []} + return utils.to_adapter_format({"perf_data": perf_data}) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py index b9ae471ca..6b064b155 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -12,7 +12,12 @@ def get_loc_str(loc): try: - res = str(loc).split('"')[1] + # Constant loc( at the start of the location and ) at the end. Can just strip these characters + loc = str(loc) + if loc.startswith("loc(") and loc.endswith(")"): + res = str(loc)[4:-1] + else: + res = loc # This is a fallback to just visualize / see what the loc is if not processable. except: res = "unknown" return res @@ -471,6 +476,21 @@ def make_constant_node(self, name_dict, constant_name): ] +def get_locs(module): + name_dict = defaultdict(int) + + for op in module.body.operations: + for region in op.regions: + for block in region.blocks: + for op in block.operations: + op = OpHandler(op) + _id = op.get_id(name_dict) + # This will now populate name_dict with all of the locations that are relevant + + # The keys will be all the unique locations, and the values will be the number of times that location appears + return name_dict + + def build_graph(module): name_dict = defaultdict(int) output_connections = defaultdict(int) @@ -479,7 +499,11 @@ def build_graph(module): op_to_graph_node = {} module_op = OpHandler(module.operation) - graph.nodes.append(module_op.make_graph_node(name_dict)) + module_attrs = module_op.get_attributes() + module_attrs = dict((attr.key, attr.value) for attr in module_attrs) + # Add module attributes to the graph as "namespace attributes" + group_node_attrs = {} + group_node_attrs[module_op.get_namespace()] = module_attrs for op in module.body.operations: append_later = [] @@ -567,5 +591,5 @@ def build_graph(module): ) ) output_connections[source_node.id] += 1 - + graph.groupNodeAttributes = group_node_attrs return graph diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py index 65da2ec2b..205944acd 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py @@ -9,8 +9,9 @@ # os.environ["TTRT_LOGGER_LEVEL"] = "ERROR" from ttrt import API as ttrt import ttmlir.passes -from . import utils +from . import utils, mlir import pandas as pd +from model_explorer import node_data_builder class ModelRunner: @@ -69,6 +70,9 @@ def run( module = utils.parse_mlir_file(model_path) + # Collect unique locations + name_dict = mlir.get_locs(module) + try: print("Running MLIR compile: TTIR to TTNN Backend Pipeline") print("With options: ", options_string) @@ -131,8 +135,34 @@ def run( "DEVICE FW DURATION [ns]", "CORE COUNT", "OUTPUT_0_MEMORY", + "LOC", ] perf = perf[columns] print(perf) - print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns") + print(f"Total device duration: {perf['DEVICE FW DURATION [ns]'].sum()}ns") + + # Create the node_data type here + timing_data = list(zip(perf["LOC"], perf["DEVICE FW DURATION [ns]"])) + results = {} + for loc, duration in timing_data: + loc = mlir.get_loc_str(loc).replace("'", '"') + if loc in name_dict: + for i in range(name_dict[loc]): + results[f"{loc}__{i}"] = node_data_builder.NodeDataResult( + value=duration + ) + else: + print( + f"Location {loc} not found in graph, ops data for this op was not reported." + ) + + gradient = [ + node_data_builder.GradientItem(stop=0, bgColor="yellow"), + node_data_builder.GradientItem(stop=1, bgColor="red"), + ] + + data = node_data_builder.GraphNodeData(results=results, gradient=gradient) + + res = node_data_builder.ModelNodeData(graphsData={"tt-graph": data}) + return res diff --git a/tools/explorer/tt_adapter/src/tt_adapter/utils.py b/tools/explorer/tt_adapter/src/tt_adapter/utils.py index bca7e640b..4b404a204 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/utils.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/utils.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import ttmlir +from dataclasses import make_dataclass def parse_mlir_file(model_path): @@ -11,3 +12,11 @@ def parse_mlir_file(model_path): ttmlir.dialects.ttnn.register_dialect(ctx) module = ttmlir.ir.Module.parse(model_file.read(), ctx) return module + + +def to_dataclass(obj: dict, dc_name: str = "tempClass"): + return make_dataclass(dc_name, ((k, type(v)) for k, v in obj.items()))(**obj) + + +def to_adapter_format(obj: dict): + return {"graphs": [to_dataclass(obj)]}