diff --git a/python/TTNNModule.cpp b/python/TTNNModule.cpp
index a7df3b619..f1374ba06 100644
--- a/python/TTNNModule.cpp
+++ b/python/TTNNModule.cpp
@@ -164,6 +164,10 @@ void populateTTNNModule(py::module &m) {
            }
            return static_cast<uint32_t>(
                self.getMemLayout().getValue());
-          });
+          })
+      .def_property_readonly("is_tiled", &tt::ttnn::TTNNLayoutAttr::isTiled)
+      // TODO: fix the data_type binding.
+      .def_property_readonly("data_type",
+                             &tt::ttnn::TTNNLayoutAttr::getDataType);
 }
 } // namespace mlir::ttmlir::python
diff --git a/python/Util.cpp b/python/Util.cpp
index c562306bc..b8bf220de 100644
--- a/python/Util.cpp
+++ b/python/Util.cpp
@@ -3,6 +3,8 @@
 // SPDX-License-Identifier: Apache-2.0

 #include "ttmlir/Bindings/Python/TTMLIRModule.h"
+#include <string>
+#include <variant>

 namespace mlir::ttmlir::python {

@@ -17,25 +19,26 @@ void populateUtilModule(py::module &m) {
     return source;
   });

-  m.def("get_loc_name", [](MlirLocation _loc) -> std::string {
-    mlir::Location loc = unwrap(_loc);
-    if (mlir::isa<mlir::NameLoc>(loc)) {
-      mlir::NameLoc nameLoc = mlir::cast<mlir::NameLoc>(loc);
-      return nameLoc.getName().str();
-    }
-    return "-";
-  });
+  m.def("get_loc_name",
+        [](MlirLocation _loc) -> std::variant<std::string, py::none> {
+          mlir::Location loc = unwrap(_loc);
+          if (mlir::isa<mlir::NameLoc>(loc)) {
+            mlir::NameLoc nameLoc = mlir::cast<mlir::NameLoc>(loc);
+            return nameLoc.getName().str();
+          }
+          return py::none();
+        });

-  m.def("get_loc_full", [](MlirLocation _loc) -> std::string {
-    mlir::Location loc = unwrap(_loc);
-    if (mlir::isa<mlir::FileLineColLoc>(loc)) {
-      mlir::FileLineColLoc fileLoc = mlir::cast<mlir::FileLineColLoc>(loc);
-      return fileLoc.getFilename().str() + ":" +
-             std::to_string(fileLoc.getLine()) + ":" +
-             std::to_string(fileLoc.getColumn());
-    }
-    return "-";
-  });
+  m.def("get_loc_full",
+        [](MlirLocation _loc) -> std::variant<std::string, py::none> {
+          mlir::Location loc = unwrap(_loc);
+
+          std::string locationStr;
+          llvm::raw_string_ostream output(locationStr);
+          loc.print(output);
+
+          return locationStr;
+        });
 }

 } // namespace mlir::ttmlir::python
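A note on the Util.cpp change: get_loc_name now returns None instead of the old "-" sentinel when the location carries no name, and get_loc_full prints the whole location via Location::print rather than hand-assembling file:line:col. A minimal sketch of the resulting Python-side contract (op stands in for any operation handle; the fallback chain mirrors what mlir.py does below):

    from ttmlir import util

    named = util.get_loc_name(op.location)  # str, or None when the location is not a NameLoc
    full = util.get_loc_full(op.location)   # printed form, e.g. 'loc("matmul_1"("model.mlir":10:3))'
    label = named if named is not None else full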
diff --git a/tools/explorer/CMakeLists.txt b/tools/explorer/CMakeLists.txt
index 387955854..93a539709 100644
--- a/tools/explorer/CMakeLists.txt
+++ b/tools/explorer/CMakeLists.txt
@@ -3,7 +3,7 @@ include(ExternalProject)
 set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py)
 set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin)

-set(MODEL_EXPLORER_VERSION "ca884d5eb3291507e7f4e76776957e231b2d9b6d")
+set(MODEL_EXPLORER_VERSION "8ec112eaee8006301039ee34c55d3751a1b82c14")
 ExternalProject_Add(
   model-explorer
   PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer
diff --git a/tools/explorer/test/run_tests.py b/tools/explorer/test/run_tests.py
index 485104fbb..b7ba421c3 100644
--- a/tools/explorer/test/run_tests.py
+++ b/tools/explorer/test/run_tests.py
@@ -16,8 +16,9 @@
     "test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir",
     "tools/explorer/test/models/*.mlir",
 ]
+MNIST_SHARDING_PATH = "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir"
 TEST_EXECUTE_MODEL_PATHS = [
-    "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
+    MNIST_SHARDING_PATH,
 ]


@@ -28,7 +29,7 @@ def get_test_files(paths):
     return files


-def send_command(command, model_path, settings):
+def send_command(command, model_path, settings={}):
     cmd = {
         "extensionId": "tt_adapter",
         "cmdId": command,
@@ -51,7 +52,7 @@ def execute_command(model_path, settings):
 def wait_for_execution_to_finish(timeout):
     for _ in range(timeout):
         try:
-            response = send_command("status_check", "", {})
+            response = send_command("status_check", "")
             if response.status_code == 200 and response.json().get("graphs")[0].get(
                 "isDone"
             ):
@@ -60,9 +61,7 @@ def wait_for_execution_to_finish(timeout):
             print(f"Request failed: {e}")
             raise Exception("Status check request failed")
         time.sleep(1)
-    raise RuntimeError(
-        f"Execution did not finish within {MODEL_EXECUTION_TIMEOUT} seconds"
-    )
+    raise RuntimeError(f"Execution did not finish within {timeout} seconds")


 def execute_command_and_wait(model_path, settings, timeout):
@@ -107,7 +106,7 @@ def server_shutdown():

 @pytest.mark.parametrize("model_path", get_test_files(TEST_LOAD_MODEL_PATHS))
 def test_load_model(model_path):
-    result = send_command("convert", model_path, {})
+    result = send_command("convert", model_path)
     assert result.ok
     if "error" in result.json():
         print(result.json())
@@ -119,22 +118,25 @@ def test_execute_model(model_path):
     execute_command_and_wait(
         model_path, {"optimizationPolicy": "DF Sharding"}, timeout=60
     )
+    send_command("convert", model_path)


 def test_execute_mnist_l1_interleaved():
     execute_command_and_wait(
-        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
+        MNIST_SHARDING_PATH,
         {"optimizationPolicy": "Greedy L1 Interleaved"},
         timeout=60,
     )
+    send_command("convert", MNIST_SHARDING_PATH)


 def test_execute_mnist_optimizer_disabled():
     execute_command_and_wait(
-        "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
+        MNIST_SHARDING_PATH,
         {"optimizationPolicy": "Optimizer Disabled"},
         timeout=60,
     )
+    send_command("convert", MNIST_SHARDING_PATH)


 def test_execute_model_invalid_policy():
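One caveat with the new send_command signature: a dict default like settings={} is evaluated once and shared across every call. The helper only reads settings, so this is harmless here; if it ever starts mutating the dict, the usual idiom would be (a hypothetical variant, not part of this patch):

    def send_command(command, model_path, settings=None):
        # Normalize inside the function instead of sharing one dict object across calls.
        settings = {} if settings is None else settings
        ...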
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py
index d3f516925..f8614c75f 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/main.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py
@@ -44,25 +44,39 @@ def __init__(self):
     def convert(
         self, model_path: str, settings: Dict
     ) -> model_explorer.ModelExplorerGraphs:
-        perf_trace = None
-        if optimized_model_path := self.model_runner.get_optimized_model_path():
+        if optimized_model_path := self.model_runner.get_optimized_model_path(
+            model_path
+        ):
             print(f"Using optimized model: {optimized_model_path}")
-            model_path = optimized_model_path
-            # Get performance results.
-            perf_trace = self.model_runner.get_perf_trace()
+            perf_trace = self.model_runner.get_perf_trace(model_path)
+
+            module = utils.parse_mlir_file(optimized_model_path)
+
+            # Convert TTIR to Model Explorer Graphs and Display/Return
+            graph, perf_data = mlir.build_graph(module, perf_trace)
+            if perf_data:
+                # TODO(odjuricic): We can probably edit the actual graph response or create our own instead of just adding to the dataclass.
+                graph = utils.add_to_dataclass(
+                    graph, "overlays", {"Performance Trace": perf_data.graphsData}
+                )
+
+            if overrides := self.model_runner.get_overrides(model_path):
+                graph = utils.add_to_dataclass(graph, "overrides", overrides)
+        else:
+            module = utils.parse_mlir_file(model_path)
-        module = utils.parse_mlir_file(model_path)
+            # Convert TTIR to Model Explorer Graphs and Display/Return
+            graph, _ = mlir.build_graph(module)
-        # Convert TTIR to Model Explorer Graphs and Display/Return
-        graph, perf_data = mlir.build_graph(module, perf_trace)
-        if perf_data:
-            graph = utils.add_to_dataclass(graph, "perf_data", perf_data.graphsData)
         return {"graphs": [graph]}

     def execute(
         self, model_path: str, settings: Dict
     ) -> model_explorer.ModelExplorerGraphs:
+
+        print("SETTINGS: ", settings)
+
         override_handler = optimizer_overrides.OptimizerOverridesHandler()
         override_handler.set_system_desc_path(
             f"{self.model_runner.get_artifacts_dir()}/system_desc.ttsys"
@@ -84,7 +98,9 @@ def execute(
             OPTIMIZATION_POLICIES[optimization_policy]
         )

-        self.model_runner.run(model_path, override_handler.to_string())
+        self.model_runner.run(
+            model_path, override_handler.to_string(), settings.get("overrides", None)
+        )

         return {"graphs": []}
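convert() now serves two cases: before execution it renders the original module, and after execution it renders the optimized IR with the perf trace attached under "overlays" and any user edits echoed back under "overrides". The add_to_dataclass helper it leans on lives in utils.py, which is not part of this patch; a sketch of what such a helper can look like, following the same rebuild-the-dataclass pattern as make_editable_kv in mlir.py (an assumption, not the actual utils implementation):

    import dataclasses

    def add_to_dataclass(obj, key, value):
        # Rebuild the dataclass type with one extra field and copy the values over.
        fields = dataclasses.asdict(obj)
        fields[key] = value
        return dataclasses.make_dataclass(
            obj.__class__.__name__, fields.keys()
        )(**fields)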
+ """ + match = re.match(r'^loc\("([^"]+)"', loc_str) + if match: + return match.group(1) + return None class AttrHandler: @@ -409,7 +410,7 @@ def parse_ttnn_ttnn_layout(attr): result.append( make_editable_kv( graph_builder.KeyValue( - key="Tensor Memory Layout", + key="tensor_memory_layout", value=str(ttnn.TensorMemoryLayout(memory_layout)), ), editable={ @@ -419,22 +420,59 @@ def parse_ttnn_ttnn_layout(attr): ) ) result.append( - graph_builder.KeyValue( - key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + make_editable_kv( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ), + editable={ + "input_type": "grid", + "separator": "x", + "min_value": 1, + "max_value": 100, + "step": 1, + }, ) ) result.append( graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) ) + buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) result.append( - graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + make_editable_kv( + graph_builder.KeyValue( + key="buffer_type", value=str(ttnn.BufferType(buffer_attr.value)) + ), + editable={ + "input_type": "value_list", + "options": [str(o) for o in ttnn.BufferType], + }, + ) ) - buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) + result.append( - graph_builder.KeyValue( - key="memref_memory_space", value=str(ttnn.BufferType(buffer_attr.value)) + make_editable_kv( + graph_builder.KeyValue( + key="is_tiled", + value=str(layout.is_tiled), + ), + editable={ + "input_type": "value_list", + "options": ["True", "False"], + }, ) ) + + # result.append( + # make_editable_kv( + # graph_builder.KeyValue( + # key="data_type", value=str(layout.data_type), + # ), + # editable={ + # "input_type": "value_list", + # "options": [str(o) for o in ttnn.DataType], + # }, + # ) + # ) return result @@ -444,11 +482,14 @@ class OpHandler: def __init__(self, op): self.op = op - self.location = get_loc_str(self.op.location) + self.named_location = util.get_loc_name(self.op.location) + self.full_location = util.get_loc_full(self.op.location) self.id = self._create_unique_id() def _create_unique_id(self): - name = self.location + # Change this to something else PLEASE. 
@@ -444,11 +482,14 @@ class OpHandler:

     def __init__(self, op):
         self.op = op
-        self.location = get_loc_str(self.op.location)
+        self.named_location = util.get_loc_name(self.op.location)
+        self.full_location = util.get_loc_full(self.op.location)
         self.id = self._create_unique_id()

     def _create_unique_id(self):
-        name = self.location
+        # Change this to something else PLEASE.
+        # I think that we can just use unique numbers here.
+        name = self.full_location if self.full_location else "unknown"
         name_num = self.name_dict[name]
         id = name + "__" + str(name_num)
         self.name_dict[name] += 1
@@ -456,10 +497,12 @@
     def get_namespace(self, parent_op=None):
         op = self.op if not parent_op else parent_op
-        name = get_loc_str(op.location)
+        name = util.get_loc_name(op.location)
         if op.parent and op.parent.name != "builtin.module":
-            return self.get_namespace(op.parent) + "/" + name
-        return name
+            parent_name = self.get_namespace(op.parent)
+            if parent_name:
+                return parent_name + "/" + (name or "")
+        return name or ""

     def get_attributes(self):
         # Parse Op Attributes themselves
@@ -467,7 +510,17 @@ def get_attributes(self):
         for attr in self.op.attributes:
             result.extend(AttrHandler.parse_attr(attr))

-        # Add output tensor properties to the op itself
+        # Add location as an attribute
+        if self.named_location:
+            result.append(
+                graph_builder.KeyValue(key="named_location", value=self.named_location)
+            )
+        if self.full_location:
+            result.append(
+                graph_builder.KeyValue(key="full_location", value=self.full_location)
+            )
+
+        # Add output tensor attributes to the op itself
         if self.op.results:
             output_tensor = self.op.result
             output_attrs = []
@@ -538,9 +591,11 @@ def build_graph(module, perf_trace=None):
     loc_to_perf = {}
     if perf_trace is not None:
         for _, row in perf_trace.iterrows():
-            loc = get_loc_str(row["LOC"])
-            assert loc not in loc_to_perf
-            loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"]
+            loc = parse_loc_string(row["LOC"])
+            # TODO(odjuricic): Locations seem to be missing from the ttrt perf CSV.
+            # assert loc not in loc_to_perf
+            if loc:
+                loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"]

     module_op = OpHandler(module.operation)
     module_attrs = module_op.get_attributes()
@@ -558,9 +613,12 @@ def build_graph(module, perf_trace=None):
         operation = OpHandler(op)
         graph_node = operation.make_graph_node()

-        if operation.location in loc_to_perf:
+        if (
+            operation.named_location in loc_to_perf
+            and operation.op.name != "ttnn.empty"
+        ):
             perf_node_data[operation.id] = node_data_builder.NodeDataResult(
-                loc_to_perf[operation.location]
+                loc_to_perf[operation.named_location]
             )

         if op.name in EMPTY_OPS:
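The perf overlay join now goes through named locations on both sides: build_graph reduces each CSV row's LOC column to a name with parse_loc_string, and an op's node data is looked up by its named_location, with ttnn.empty ops excluded from the overlay. A condensed sketch of the flow with a made-up row:

    row = {"LOC": 'loc("matmul_1"("model.mlir":10:3))', "DEVICE FW DURATION [ns]": 5120}
    loc_to_perf[parse_loc_string(row["LOC"])] = row["DEVICE FW DURATION [ns]"]
    # later, for an op whose named_location == "matmul_1":
    #   perf_node_data[op.id] = node_data_builder.NodeDataResult(5120)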
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py
index ba9ed9fad..0366dc15a 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py
@@ -19,24 +19,40 @@ class ExplorerRunException(Exception):
     pass


+class ModelState:
+    """
+    After a model is compiled and executed, we need to keep track of all the
+    additional data that was created.
+    """
+
+    # Path to the compiled TTNN IR file.
+    optimized_model_path = None
+    # Path to the output directory where ttrt dumps all model files (perf trace, memory state, etc.).
+    model_output_dir = None
+    # Overrides: changes that the user made to op configurations.
+    overrides = None
+
+
 class ModelRunner:
     """
     ModelRunner is a singleton class used for compiling and running models,
     ensuring only one can be run at a time. This is necessary because the
     adapter class is reinitialized on every request from the frontend, so it
     cannot keep state.
     """

+    # Global static runner state. Initialized once.
     _instance = None
     _explorer_artifacts_dir = None
     _build_dir = None

-    # State variables.
+    # Singleton runner state. Initialized on every run.
     runner_thread = None
     runner_error = None
+    log_queue = queue.Queue()
     # progress should be a number between 0 and 100.
     progress = 0
-    log_queue = queue.Queue()
-    optimized_model_path = None
-    ttrt_output_dir = None
+
+    # State for models that have been executed.
+    # Contains a mapping from model path to ModelState.
+    model_state = dict()

     def __new__(cls, *args, **kwargs):
         if not cls._instance:
@@ -70,11 +86,22 @@ def initialize(self):

         print("ModelRunner initialized.")

-    def get_optimized_model_path(self):
-        return self.optimized_model_path
+    def get_optimized_model_path(self, model_path):
+        return (
+            self.model_state[model_path].optimized_model_path
+            if model_path in self.model_state
+            else None
+        )

-    def get_output_dir(self):
-        return self.ttrt_output_dir
+    def get_output_dir(self, model_path):
+        return self.model_state[model_path].model_output_dir
+
+    def get_overrides(self, model_path):
+        return (
+            self.model_state[model_path].overrides
+            if model_path in self.model_state
+            else None
+        )

     def get_error(self):
         return self.runner_error
@@ -94,21 +121,24 @@ def get_logs(self):
             logs.append(self.log_queue.get())
         return "\n".join(logs)

-    def reset_state(self):
+    def reset_state(self, model_path):
         assert not self.is_busy()
         self.runner_thread = None
-        self.log_queue.queue.clear()
-        self.optimized_model_path = None
         self.runner_error = None
         self.progress = 0
-        self.ttrt_output_dir = None
+        self.log_queue.queue.clear()
+
+        if model_path in self.model_state:
+            del self.model_state[model_path]

     def log(self, message):
         print(message)
         self.log_queue.put(message)

-    def get_perf_trace(self):
-        op_perf_file = f"{self.ttrt_output_dir}/perf/ops_perf_results.csv"
+    def get_perf_trace(self, model_path):
+        op_perf_file = (
+            f"{self.model_state[model_path].model_output_dir}/perf/ops_perf_results.csv"
+        )
         if not os.path.exists(op_perf_file):
             raise FileNotFoundError(f"Performance file {op_perf_file} not found.")

@@ -148,22 +178,24 @@ def compile_and_run_wrapper(self, model_path, overrides_string):

     def compile_and_run(self, model_path, overrides_string):
         model_name = os.path.basename(model_path)
         flatbuffer_file = model_name + ".ttnn"
-        self.ttrt_output_dir = self._explorer_artifacts_dir + "/" + flatbuffer_file
+        state = self.model_state[model_path]
+
+        state.model_output_dir = self._explorer_artifacts_dir + "/" + flatbuffer_file

-        if os.path.exists(self.ttrt_output_dir):
+        if os.path.exists(state.model_output_dir):
             self.log("Removing artifacts of previous run.")
-            os.system(f"rm -rf {self.ttrt_output_dir}")
+            os.system(f"rm -rf {state.model_output_dir}")

-        os.makedirs(self.ttrt_output_dir)
+        os.makedirs(state.model_output_dir)

         # Copy the model to the run directory.
- os.system(f"cp {model_path} {self.ttrt_output_dir}") + os.system(f"cp {model_path} {state.model_output_dir}") self.progress = 10 ############################### Compile ################################## ttnn_ir_file = ( - f"{self.ttrt_output_dir}/{model_name.replace('.mlir', '_ttnn.mlir')}" + f"{state.model_output_dir}/{model_name.replace('.mlir', '_ttnn.mlir')}" ) compile_command = [ f"{self._build_dir}/bin/ttmlir-opt", @@ -218,7 +250,7 @@ def compile_and_run(self, model_path, overrides_string): self.log(error) raise ExplorerRunException(error) - perf = self.get_perf_trace() + perf = self.get_perf_trace(model_path) columns = [ "GLOBAL CALL COUNT", "OP CODE", @@ -232,16 +264,18 @@ def compile_and_run(self, model_path, overrides_string): print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns") - self.optimized_model_path = ttnn_ir_file + state.optimized_model_path = ttnn_ir_file self.progress = 100 - def run(self, model_path, compile_options): + def run(self, model_path, compile_options, overrides): # Check if a run is already in progress if self.is_busy(): raise RuntimeError( "A model is already being processed. Please wait for it to finish." ) - self.reset_state() + self.reset_state(model_path) + self.model_state[model_path] = ModelState() + self.model_state[model_path].overrides = overrides # Start compile and run in a new thread self.runner_thread = threading.Thread(