diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8e23597a4..fe0375d70 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -29,4 +29,4 @@ /test/ttmlir/Silicon/TTNN/optimizer/ @nobradovictt @odjuricicTT /test/unittests/Optimizer @nobradovictt @odjuricicTT /tools/ @svuckovicTT @mtopalovicTT -/tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt +/tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt @vcanicTT diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index b13d37564..ba90bc2ec 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -76,9 +76,8 @@ class OptimizerOverridesHandler { // Wrapper methods we use to expose the adders to the python bindings void addInputLayoutOverridePybindWrapper(std::string, std::vector &); - void addOutputLayoutOverridePybindWrapper(std::string, std::vector &, - BufferType, TensorMemoryLayout, - tt::ttnn::Layout, tt::DataType); + void addOutputLayoutOverridePybindWrapper(std::string, + OutputLayoutOverrideParams); private: // Flags for enabling/disabling the optimizer passes diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h index cd2d3585f..da0aa7430 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -5,13 +5,10 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H #define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H -#include - #include #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" namespace mlir::tt::ttnn { @@ -31,12 +28,12 @@ struct OptionNames { }; struct OutputLayoutOverrideParams { - std::optional> grid; - std::optional bufferType; - std::optional - tensorMemoryLayout; // INTERLEAVED / SHARDED etc... - std::optional memoryLayout; // ROW_MAJOR / TILE - std::optional dataType; + std::optional> grid = std::nullopt; + std::optional bufferType = std::nullopt; + std::optional tensorMemoryLayout = + std::nullopt; // INTERLEAVED / SHARDED etc... + std::optional memoryLayout = std::nullopt; // ROW_MAJOR / TILE + std::optional dataType = std::nullopt; // Check if all layout parameters that are generated in LegalLayoutAnalysis // are overridden. DataType is the only that is not. 
@@ -45,7 +42,7 @@ struct OutputLayoutOverrideParams { tensorMemoryLayout.has_value() && memoryLayout.has_value(); } - bool operator==(const OutputLayoutOverrideParams rhs) const { + bool operator==(const OutputLayoutOverrideParams &rhs) const { if (grid.has_value() != rhs.grid.has_value()) { return false; } diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index 157c1e50d..d7f0b52fd 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -22,6 +22,7 @@ void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy( } void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy( MemoryLayoutAnalysisPolicyType value) { + enableMemoryLayoutAnalysisPolicy = true; memoryLayoutAnalysisPolicy = value; } @@ -198,13 +199,9 @@ void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper( } void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper( - std::string opName, std::vector &grid, BufferType bufferType, - TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, - tt::DataType dataType) { + std::string opName, OutputLayoutOverrideParams overrideParams) { StringRef opNameStringRef(opName); - SmallVector gridSmallVector(grid.begin(), grid.end()); - addOutputLayoutOverride(opNameStringRef, gridSmallVector, bufferType, - tensorMemoryLayout, memoryLayout, dataType); + addOutputLayoutOverride(opNameStringRef, overrideParams); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp index ad59ea91c..b170dd244 100644 --- a/lib/Dialect/TTNN/Utils/PassOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" +#include namespace mlir::tt::ttnn { @@ -102,32 +103,43 @@ std::string OutputLayoutOverrideParser::toString( res += std::string(entry.getKey()) + "="; const OutputLayoutOverrideParams ¶ms = entry.getValue(); - // Print grid values + std::vector parts; + + // Collect grid values if (params.grid.has_value()) { + std::string gridStr; for (size_t i = 0; i < params.grid.value().size(); ++i) { - res += std::to_string(params.grid.value()[i]); + gridStr += std::to_string(params.grid.value()[i]); if (i < params.grid.value().size() - 1) { - res += "x"; + gridStr += "x"; } } + parts.push_back(gridStr); } - // Print memory space and memory layout + // Collect memory space and memory layout if (params.bufferType.has_value()) { - res += ":" + std::string(mlir::tt::ttnn::stringifyBufferType( - params.bufferType.value())); + parts.push_back(std::string( + mlir::tt::ttnn::stringifyBufferType(params.bufferType.value()))); } if (params.tensorMemoryLayout.has_value()) { - res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout.value())); + parts.push_back(std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout.value()))); } if (params.memoryLayout.has_value()) { - res += ":" + std::string(mlir::tt::ttnn::stringifyLayout( - params.memoryLayout.value())); + parts.push_back(std::string( + mlir::tt::ttnn::stringifyLayout(params.memoryLayout.value()))); } if (params.dataType.has_value()) { - res += ":" + std::string( - mlir::tt::DataTypeEnumToString(params.dataType.value())); + parts.push_back( + std::string(mlir::tt::DataTypeEnumToString(params.dataType.value()))); } + + // Join parts with ":" + res += 
std::accumulate(parts.begin(), parts.end(), std::string(), + [](const std::string &a, const std::string &b) { + return a.empty() ? b : a + ":" + b; + }); + if (++count < value.size()) { res += ","; } diff --git a/python/OptimizerOverrides.cpp b/python/OptimizerOverrides.cpp index 18806654c..3b901988d 100644 --- a/python/OptimizerOverrides.cpp +++ b/python/OptimizerOverrides.cpp @@ -4,6 +4,7 @@ #include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" #include "ttmlir/Bindings/Python/TTMLIRModule.h" +#include namespace mlir::ttmlir::python { @@ -133,12 +134,20 @@ void populateOptimizerOverridesModule(py::module &m) { "grid", [](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) { // Getter: Convert SmallVector to std::vector - return std::vector(obj.grid->begin(), obj.grid->end()); + if (obj.grid.has_value()) { + return std::make_optional>(obj.grid->begin(), + obj.grid->end()); + } + return std::make_optional>(); }, [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, const std::vector &input) { // Setter: Convert std::vector to SmallVector - obj.grid->clear(); + if (!obj.grid.has_value()) { + obj.grid = SmallVector(); + } else { + obj.grid->clear(); + } obj.grid->append(input.begin(), input.end()); }) .def_readwrite("buffer_type", @@ -149,7 +158,45 @@ void populateOptimizerOverridesModule(py::module &m) { .def_readwrite("memory_layout", &mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout) .def_readwrite("data_type", - &mlir::tt::ttnn::OutputLayoutOverrideParams::dataType); + &mlir::tt::ttnn::OutputLayoutOverrideParams::dataType) + .def("set_buffer_type_from_str", + [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, + const std::string &value) { + if (auto bufferType = mlir::tt::ttnn::symbolizeBufferType(value)) { + obj.bufferType = bufferType; + } else { + throw std::invalid_argument("Invalid buffer type: " + value); + } + }) + .def("set_tensor_memory_layout_from_str", + [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, + const std::string &value) { + if (auto tensorMemoryLayout = + mlir::tt::ttnn::symbolizeTensorMemoryLayout(value)) { + obj.tensorMemoryLayout = tensorMemoryLayout; + } else { + throw std::invalid_argument("Invalid tensor memory layout: " + + value); + } + }) + .def("set_memory_layout_from_str", + [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, + const std::string &value) { + if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(value)) { + obj.memoryLayout = memoryLayout; + } else { + throw std::invalid_argument("Invalid memory layout: " + value); + } + }) + .def("set_data_type_from_str", + [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj, + const std::string &value) { + if (auto dataType = mlir::tt::DataTypeStringToEnum(value)) { + obj.dataType = dataType; + } else { + throw std::invalid_argument("Invalid data type: " + value); + } + }); } } // namespace mlir::ttmlir::python diff --git a/python/TTNNModule.cpp b/python/TTNNModule.cpp index a7df3b619..6b2710be7 100644 --- a/python/TTNNModule.cpp +++ b/python/TTNNModule.cpp @@ -156,7 +156,7 @@ void populateTTNNModule(py::module &m) { .def_property_readonly( "memref", [](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); }) - .def_property_readonly("memory_layout_as_int", + .def_property_readonly("tensor_memory_layout_as_int", [](tt::ttnn::TTNNLayoutAttr self) -> std::variant { if (!self.getMemLayout()) { @@ -164,6 +164,16 @@ void populateTTNNModule(py::module &m) { } return static_cast( self.getMemLayout().getValue()); + }) + .def_property_readonly("memory_layout_as_int", + 
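Taken together, the binding changes above let Python code build an output layout override field by field from plain strings. A minimal usage sketch, assuming the ttmlir Python bindings are built and mirroring how the explorer adapter later in this diff drives them (the op name and values are illustrative, not part of the patch):

    from ttmlir import optimizer_overrides

    handler = optimizer_overrides.OptimizerOverridesHandler()
    handler.set_enable_optimizer(True)

    params = optimizer_overrides.OutputLayoutOverrideParams()
    params.grid = [8, 8]                                      # setter copies into the optional SmallVector
    params.set_buffer_type_from_str("dram")                   # unknown names raise ValueError
    params.set_tensor_memory_layout_from_str("interleaved")
    params.set_memory_layout_from_str("tile")
    params.set_data_type_from_str("f32")

    handler.add_output_layout_override("matmul_1", params)
    # The per-op parts are joined with ":" in grid:buffer:tensor-memory-layout:layout:dtype order,
    # so the serialized override should read roughly "matmul_1=8x8:dram:interleaved:tile:f32".
    print(handler.to_string())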
[](tt::ttnn::TTNNLayoutAttr self) { + return static_cast(self.getLayout()); + }) + .def_property_readonly("data_type_as_int", + [](tt::ttnn::TTNNLayoutAttr self) { + return static_cast(self.getDataType()); }); + // .def_property_readonly("data_type", + // &tt::ttnn::TTNNLayoutAttr::getDataType); } } // namespace mlir::ttmlir::python diff --git a/python/Util.cpp b/python/Util.cpp index c562306bc..b8bf220de 100644 --- a/python/Util.cpp +++ b/python/Util.cpp @@ -3,6 +3,8 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Bindings/Python/TTMLIRModule.h" +#include +#include namespace mlir::ttmlir::python { @@ -17,25 +19,26 @@ void populateUtilModule(py::module &m) { return source; }); - m.def("get_loc_name", [](MlirLocation _loc) -> std::string { - mlir::Location loc = unwrap(_loc); - if (mlir::isa(loc)) { - mlir::NameLoc nameLoc = mlir::cast(loc); - return nameLoc.getName().str(); - } - return "-"; - }); + m.def("get_loc_name", + [](MlirLocation _loc) -> std::variant { + mlir::Location loc = unwrap(_loc); + if (mlir::isa(loc)) { + mlir::NameLoc nameLoc = mlir::cast(loc); + return nameLoc.getName().str(); + } + return py::none(); + }); - m.def("get_loc_full", [](MlirLocation _loc) -> std::string { - mlir::Location loc = unwrap(_loc); - if (mlir::isa(loc)) { - mlir::FileLineColLoc fileLoc = mlir::cast(loc); - return fileLoc.getFilename().str() + ":" + - std::to_string(fileLoc.getLine()) + ":" + - std::to_string(fileLoc.getColumn()); - } - return "-"; - }); + m.def("get_loc_full", + [](MlirLocation _loc) -> std::variant { + mlir::Location loc = unwrap(_loc); + + std::string locationStr; + llvm::raw_string_ostream output(locationStr); + loc.print(output); + + return locationStr; + }); } } // namespace mlir::ttmlir::python diff --git a/tools/explorer/CMakeLists.txt b/tools/explorer/CMakeLists.txt index 387955854..aac0dd056 100644 --- a/tools/explorer/CMakeLists.txt +++ b/tools/explorer/CMakeLists.txt @@ -3,7 +3,7 @@ include(ExternalProject) set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py) set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin) -set(MODEL_EXPLORER_VERSION "ca884d5eb3291507e7f4e76776957e231b2d9b6d") +set(MODEL_EXPLORER_VERSION "4ffe3f0c11969b8eb628e1309bcde1990bb470b7") ExternalProject_Add( model-explorer PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer diff --git a/tools/explorer/test/run_tests.py b/tools/explorer/test/run_tests.py index 485104fbb..a3684cf72 100644 --- a/tools/explorer/test/run_tests.py +++ b/tools/explorer/test/run_tests.py @@ -16,11 +16,43 @@ "test/ttmlir/Dialect/TTNN/optimizer/mnist_sharding.mlir", "tools/explorer/test/models/*.mlir", ] +MNIST_SHARDING_PATH = "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir" TEST_EXECUTE_MODEL_PATHS = [ - "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir", + MNIST_SHARDING_PATH, ] +@pytest.fixture(scope="function", autouse=True) +def start_server(request): + """Start the model explorer server before running tests and stop it after.""" + server_thread = multiprocessing.Process( + target=model_explorer.visualize, + kwargs={"extensions": ["tt_adapter"], "host": HOST, "port": PORT}, + ) + server_thread.start() + + # Wait for the server to start + for _ in range(200): # Try for up to 20 seconds + try: + response = requests.get(f"http://{HOST}:{PORT}/check_health", timeout=1) + if response.status_code == 200: + print("Explorer server started") + break + except requests.ConnectionError: + pass + finally: + time.sleep(0.1) + else: + raise RuntimeError("Server did not start within the expected time") + + # 
Terminate the server and wait for it to finish. + def server_shutdown(): + server_thread.terminate() + server_thread.join() + + request.addfinalizer(server_shutdown) + + def get_test_files(paths): files = [] for path in paths: @@ -28,7 +60,7 @@ def get_test_files(paths): return files -def send_command(command, model_path, settings): +def send_command(command, model_path, settings={}): cmd = { "extensionId": "tt_adapter", "cmdId": command, @@ -51,7 +83,7 @@ def execute_command(model_path, settings): def wait_for_execution_to_finish(timeout): for _ in range(timeout): try: - response = send_command("status_check", "", {}) + response = send_command("status_check", "") if response.status_code == 200 and response.json().get("graphs")[0].get( "isDone" ): @@ -60,9 +92,7 @@ def wait_for_execution_to_finish(timeout): print(f"Request failed: {e}") raise Exception("Status check request failed") time.sleep(1) - raise RuntimeError( - f"Execution did not finish within {MODEL_EXECUTION_TIMEOUT} seconds" - ) + raise RuntimeError(f"Execution did not finish within {timeout} seconds") def execute_command_and_wait(model_path, settings, timeout): @@ -75,43 +105,18 @@ def execute_command_and_wait(model_path, settings, timeout): assert response["error"] is None -@pytest.fixture(scope="function", autouse=True) -def start_server(request): - server_thread = multiprocessing.Process( - target=model_explorer.visualize, - kwargs={"extensions": ["tt_adapter"], "host": HOST, "port": PORT}, - ) - server_thread.start() - - # Wait for the server to start - for _ in range(200): # Try for up to 20 seconds - try: - response = requests.get(f"http://{HOST}:{PORT}/check_health", timeout=1) - if response.status_code == 200: - print("Explorer server started") - break - except requests.ConnectionError: - pass - finally: - time.sleep(0.1) - else: - raise RuntimeError("Server did not start within the expected time") - - # Terminate the server and wait for it to finish. 
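Note that the `else` attached to the polling loop in the fixture above is Python's for/else: the branch runs only when the loop finishes without hitting `break`, i.e. when the health check never succeeded. A self-contained illustration of the pattern (the check below is a stand-in, not the real request):

    import time

    def check_health():
        # Hypothetical stand-in for the requests.get(".../check_health") call.
        return False

    for _ in range(3):
        if check_health():
            break
        time.sleep(0.1)
    else:
        # Reached only because no iteration broke out of the loop.
        raise RuntimeError("Server did not start within the expected time")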
- def server_shutdown(): - server_thread.terminate() - server_thread.join() - - request.addfinalizer(server_shutdown) - - -@pytest.mark.parametrize("model_path", get_test_files(TEST_LOAD_MODEL_PATHS)) -def test_load_model(model_path): - result = send_command("convert", model_path, {}) +def convert_command_and_assert(model_path): + result = send_command("convert", model_path) assert result.ok if "error" in result.json(): print(result.json()) assert False + return result.json() + + +@pytest.mark.parametrize("model_path", get_test_files(TEST_LOAD_MODEL_PATHS)) +def test_load_model(model_path): + convert_command_and_assert(model_path) @pytest.mark.parametrize("model_path", get_test_files(TEST_EXECUTE_MODEL_PATHS)) @@ -119,22 +124,56 @@ def test_execute_model(model_path): execute_command_and_wait( model_path, {"optimizationPolicy": "DF Sharding"}, timeout=60 ) + convert_command_and_assert(model_path) def test_execute_mnist_l1_interleaved(): execute_command_and_wait( - "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir", + MNIST_SHARDING_PATH, {"optimizationPolicy": "Greedy L1 Interleaved"}, timeout=60, ) + convert_command_and_assert(MNIST_SHARDING_PATH) def test_execute_mnist_optimizer_disabled(): execute_command_and_wait( - "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir", + MNIST_SHARDING_PATH, {"optimizationPolicy": "Optimizer Disabled"}, timeout=60, ) + convert_command_and_assert(MNIST_SHARDING_PATH) + + +def test_execute_mnist_with_overrides(): + overrides = { + 'loc("matmul_1"("MNISTLinear":4294967295:10))__17': { + "named_location": "matmul_1", + "attributes": [ + {"key": "data_type", "value": "f32"}, + {"key": "memory_layout", "value": "tile"}, + {"key": "buffer_type", "value": "dram"}, + {"key": "tensor_memory_layout", "value": "interleaved"}, + {"key": "grid_shape", "value": "[8,8]"}, + ], + } + } + execute_command_and_wait( + MNIST_SHARDING_PATH, + {"optimizationPolicy": "DF Sharding", "overrides": overrides}, + timeout=60, + ) + convert_command_and_assert(MNIST_SHARDING_PATH) + + +def test_execute_and_check_perf_data_exists(): + execute_command_and_wait( + MNIST_SHARDING_PATH, + {"optimizationPolicy": "DF Sharding"}, + timeout=60, + ) + result = convert_command_and_assert(MNIST_SHARDING_PATH) + assert "perf_data" in result["graphs"][0] def test_execute_model_invalid_policy(): diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py index 9d0307d11..d494721ae 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/main.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py @@ -6,16 +6,16 @@ from . 
import runner, utils, mlir import dataclasses import enum +from ttmlir import optimizer_overrides +OPTIMIZER_DISABLED_POLICY = "Optimizer Disabled" -class OptimizationPolicy(enum.Enum): - DFSharding = "DF Sharding" - GreedyL1Interleaved = "Greedy L1 Interleaved" - BFInterleaved = "BF Interleaved" - OptimizerDisabled = "Optimizer Disabled" - - -OPTIMIZATION_POLICIES = [member.value for member in OptimizationPolicy] +OPTIMIZATION_POLICIES = { + "DF Sharding": optimizer_overrides.MemoryLayoutAnalysisPolicyType.DFSharding, + "Greedy L1 Interleaved": optimizer_overrides.MemoryLayoutAnalysisPolicyType.GreedyL1Interleaved, + "BF Interleaved": optimizer_overrides.MemoryLayoutAnalysisPolicyType.BFInterleaved, + OPTIMIZER_DISABLED_POLICY: False, +} @dataclasses.dataclass @@ -23,6 +23,51 @@ class TTAdapterMetadata(model_explorer.AdapterMetadata): settings: Dict[str, list] = dataclasses.field(default_factory=dict) +def settings_to_overrides(settings, artifacts_dir): + override_handler = optimizer_overrides.OptimizerOverridesHandler() + override_handler.set_system_desc_path(f"{artifacts_dir}/system_desc.ttsys") + + # Parse optimization policy from settings. + optimization_policy = settings.get("optimizationPolicy") + if optimization_policy not in OPTIMIZATION_POLICIES: + raise ValueError(f"Invalid optimization policy selected: {optimization_policy}") + + if optimization_policy == OPTIMIZER_DISABLED_POLICY: + override_handler.set_enable_optimizer(False) + else: + override_handler.set_enable_optimizer(True) + override_handler.set_enable_memory_layout_analysis(True) + override_handler.set_memory_layout_analysis_policy( + OPTIMIZATION_POLICIES[optimization_policy] + ) + + # Convert settings to output layout overrides. + if settings.get("overrides"): + for op_id, overrides in settings["overrides"].items(): + output_layout_override = optimizer_overrides.OutputLayoutOverrideParams() + op_loc = overrides["named_location"] + for attr in overrides["attributes"]: + match attr["key"]: + case "data_type": + output_layout_override.set_data_type_from_str(attr["value"]) + case "memory_layout": + output_layout_override.set_memory_layout_from_str(attr["value"]) + case "buffer_type": + output_layout_override.set_buffer_type_from_str(attr["value"]) + case "tensor_memory_layout": + output_layout_override.set_tensor_memory_layout_from_str( + attr["value"] + ) + case "grid_shape": + output_layout_override.grid = [ + int(x) for x in attr["value"].strip("[]").split(",") + ] + case _: + raise ValueError(f"Invalid override attribute: {attr['key']}") + override_handler.add_output_layout_override(op_loc, output_layout_override) + return override_handler + + class TTAdapter(model_explorer.Adapter): metadata = TTAdapterMetadata( id="tt_adapter", @@ -31,7 +76,7 @@ class TTAdapter(model_explorer.Adapter): source_repo="https://github.com/tenstorrent/tt-mlir/tree/main/tools/explorer/tt_adapter", fileExts=["mlir", "ttir"], settings={ - "optimizationPolicies": OPTIMIZATION_POLICIES, + "optimizationPolicies": list(OPTIMIZATION_POLICIES.keys()), }, ) model_runner = None @@ -44,45 +89,39 @@ def __init__(self): def convert( self, model_path: str, settings: Dict ) -> model_explorer.ModelExplorerGraphs: - perf_trace = None - if optimized_model_path := self.model_runner.get_optimized_model_path(): + if optimized_model_path := self.model_runner.get_optimized_model_path( + model_path + ): print(f"Using optimized model: {optimized_model_path}") - model_path = optimized_model_path - # Get performance results. 
- perf_trace = self.model_runner.get_perf_trace() + perf_trace = self.model_runner.get_perf_trace(model_path) - module = utils.parse_mlir_file(model_path) + module = utils.parse_mlir_file(optimized_model_path) + + # Convert TTIR to Model Explorer Graphs and Display/Return + graph, perf_data = mlir.build_graph(module, perf_trace) + if perf_data: + # TODO(odjuricic) We should replace the perf_data with overlays once this is fixed on FE. + graph = utils.add_to_dataclass(graph, "perf_data", perf_data.graphsData) + + if overrides := self.model_runner.get_overrides(model_path): + graph = utils.add_to_dataclass(graph, "overrides", overrides) + else: + module = utils.parse_mlir_file(model_path) + + # Convert TTIR to Model Explorer Graphs and Display/Return + graph, _ = mlir.build_graph(module) - # Convert TTIR to Model Explorer Graphs and Display/Return - graph, perf_data = mlir.build_graph(module, perf_trace) - if perf_data: - graph = utils.add_to_dataclass(graph, "perf_data", perf_data.graphsData) return {"graphs": [graph]} def execute( self, model_path: str, settings: Dict ) -> model_explorer.ModelExplorerGraphs: - # TODO(odjuricic, #1178) settings need to be parsed. - # Waiting on override class for this. - - # Parse optimization policy from settings. - optimization_policy = settings.get("optimizationPolicy") - if optimization_policy not in OPTIMIZATION_POLICIES: - raise ValueError( - f"Invalid optimization policy selected: {optimization_policy}" - ) - optimization_policy = OptimizationPolicy(optimization_policy) - - memory_layout_analysis_enabled = True - memory_layout_analysis_policy = optimization_policy.name - - if optimization_policy == OptimizationPolicy.OptimizerDisabled: - memory_layout_analysis_enabled = False - memory_layout_analysis_policy = None - + override_handler = settings_to_overrides( + settings, self.model_runner.get_artifacts_dir() + ) self.model_runner.run( - model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy + model_path, override_handler.to_string(), settings.get("overrides", None) ) return {"graphs": []} diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py index eac036a38..3ceeea731 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py @@ -5,19 +5,29 @@ import re from collections import defaultdict from model_explorer import graph_builder, node_data_builder +import dataclasses from ttmlir.dialects import tt, ttnn, ttir from ttmlir import ir, util -def get_loc_str(loc): - try: - res = util.get_loc_name(loc) - if res == "-": - res = util.get_loc_full(loc) - except: - res = "unknown" - return res +# TODO(odjuricic): Also change the KeyValue to support editable instead of this. +def make_editable_kv(kv, editable): + obj = dataclasses.asdict(kv) + obj["editable"] = editable + return dataclasses.make_dataclass( + "KeyValue", ((k, type(v)) for k, v in obj.items()) + )(**obj) + + +def parse_loc_string(loc_str): + """ + This can be replaced by ttmlir.ir.Module.parse, but requires some further wodo to extract the actual location object from the module. 
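For reference, a couple of hand-checked examples of what the regex in parse_loc_string extracts; the inputs match the LOC strings seen in the perf trace and the override keys used in run_tests.py (a sketch, assuming the adapter package from this repo is importable):

    from tt_adapter.mlir import parse_loc_string

    parse_loc_string('loc("matmul_1"("MNISTLinear":4294967295:10))')  # -> "matmul_1"
    parse_loc_string('loc(unknown)')                                   # -> None (no quoted name to capture)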
+ """ + match = re.match(r'^loc\("([^"]+)"', loc_str) + if match: + return match.group(1) + return None class AttrHandler: @@ -395,29 +405,73 @@ def parse_ttnn_ttnn_layout(attr): layout = ttnn.ir.TTNNLayoutAttr.maybe_downcast(attr) result = [] result.append(graph_builder.KeyValue(key="linear", value=str(layout.linear))) - memory_layout = layout.memory_layout_as_int + memory_layout = layout.tensor_memory_layout_as_int if memory_layout is not None: result.append( - graph_builder.KeyValue( - key="memory_layout", - value=str(ttnn.TensorMemoryLayout(memory_layout)), + make_editable_kv( + graph_builder.KeyValue( + key="tensor_memory_layout", + value=str(ttnn.TensorMemoryLayout(memory_layout)), + ), + editable={ + "input_type": "value_list", + "options": [str(o) for o in ttnn.TensorMemoryLayout], + }, ) ) result.append( - graph_builder.KeyValue( - key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + make_editable_kv( + graph_builder.KeyValue( + key="grid_shape", value="x".join(map(str, layout.grid_attr.shape)) + ), + editable={ + "input_type": "grid", + "separator": "x", + "min_value": 1, + "max_value": 100, + "step": 1, + }, ) ) result.append( graph_builder.KeyValue(key="memref_shape", value=str(layout.memref.shape)) ) + buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) result.append( - graph_builder.KeyValue(key="memref_rank", value=str(layout.memref.rank)) + make_editable_kv( + graph_builder.KeyValue( + key="buffer_type", value=str(ttnn.BufferType(buffer_attr.value)) + ), + editable={ + "input_type": "value_list", + "options": [str(o) for o in ttnn.BufferType], + }, + ) ) - buffer_attr = ttnn.ir.BufferTypeAttr.maybe_downcast(layout.memref.memory_space) + result.append( - graph_builder.KeyValue( - key="memref_memory_space", value=str(ttnn.BufferType(buffer_attr.value)) + make_editable_kv( + graph_builder.KeyValue( + key="memory_layout", + value=str(ttnn.Layout(layout.memory_layout_as_int)), + ), + editable={ + "input_type": "value_list", + "options": [str(o) for o in ttnn.Layout], + }, + ) + ) + + result.append( + make_editable_kv( + graph_builder.KeyValue( + key="data_type", + value=str(tt.DataType(layout.data_type_as_int)), + ), + editable={ + "input_type": "value_list", + "options": [str(o) for o in tt.DataType], + }, ) ) return result @@ -429,11 +483,12 @@ class OpHandler: def __init__(self, op): self.op = op - self.location = get_loc_str(self.op.location) + self.named_location = util.get_loc_name(self.op.location) + self.full_location = util.get_loc_full(self.op.location) self.id = self._create_unique_id() def _create_unique_id(self): - name = self.location + name = self.full_location if self.full_location else "unknown" name_num = self.name_dict[name] id = name + "__" + str(name_num) self.name_dict[name] += 1 @@ -441,16 +496,60 @@ def _create_unique_id(self): def get_namespace(self, parent_op=None): op = self.op if not parent_op else parent_op - name = get_loc_str(op.location) + name = util.get_loc_name(op.location) if op.parent and op.parent.name != "builtin.module": - return self.get_namespace(op.parent) + "/" + name - return name + parent_name = self.get_namespace(op.parent) + if parent_name: + return parent_name + "/" + name + return name or "" def get_attributes(self): # Parse Op Attributes themselves result = [] for attr in self.op.attributes: result.extend(AttrHandler.parse_attr(attr)) + + # Add location as an attribute + if self.named_location: + result.append( + graph_builder.KeyValue(key="named_location", value=self.named_location) + 
) + if self.full_location: + result.append( + graph_builder.KeyValue(key="full_location", value=self.full_location) + ) + + # Add output tensor attriributes to the op itself + if self.op.results: + output_tensor = self.op.result + output_attrs = [] + if isinstance(output_tensor.type, ir.RankedTensorType): + output_attrs = [ + graph_builder.KeyValue( + key="shape", value=str(output_tensor.type.shape) + ), + graph_builder.KeyValue( + key="dtype", value=str(output_tensor.type.element_type) + ), + graph_builder.KeyValue( + key="rank", value=str(output_tensor.type.rank) + ), + ] + if hasattr(output_tensor.type, "encoding") and output_tensor.type.encoding: + if "ttnn_layout" in str(output_tensor.type.encoding): + output_attrs.extend( + AttrHandler.parse_attr( + output_tensor.type.encoding.get_named("ttnn_layout") + ) + ) + else: + # Parse as a standard layout + output_attrs.extend( + AttrHandler.parse_attr( + output_tensor.type.encoding.get_named("tt.layout") + ) + ) + result.extend(output_attrs) return result def make_graph_node(self): @@ -477,6 +576,7 @@ def make_constant_node(self, constant_name): FILTERED_OPS = [ "ttnn.deallocate", "ttnn.get_device", + *EMPTY_OPS, ] @@ -491,9 +591,10 @@ def build_graph(module, perf_trace=None): loc_to_perf = {} if perf_trace is not None: for _, row in perf_trace.iterrows(): - loc = get_loc_str(row["LOC"]) + loc = parse_loc_string(row["LOC"]) assert loc not in loc_to_perf - loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"] + if loc: + loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"] module_op = OpHandler(module.operation) module_attrs = module_op.get_attributes() @@ -511,12 +612,15 @@ def build_graph(module, perf_trace=None): operation = OpHandler(op) graph_node = operation.make_graph_node() - if operation.location in loc_to_perf: + if ( + operation.named_location in loc_to_perf + and operation.op.name not in EMPTY_OPS + ): perf_node_data[operation.id] = node_data_builder.NodeDataResult( - loc_to_perf[operation.location] + loc_to_perf[operation.named_location] ) - if op.name in EMPTY_OPS: + if op.name not in FILTERED_OPS and op.name in EMPTY_OPS: append_later.append(graph_node) elif op.name not in FILTERED_OPS: graph.nodes.append(graph_node) diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py index b781ec38d..577d20f80 100644 --- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py +++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py @@ -19,24 +19,40 @@ class ExplorerRunException(Exception): pass +class ModelState: + """ + After a model is compiled and executed we keep track of all additional data that was created. + """ + + # Path to the compiled TTNN IR file. + optimized_model_path = None + # Path to the output directory where ttrt dumps all model files (perf trace, memory state, etc) + model_output_dir = None + # Overrides, changes that the user made to op configurations. + overrides = None + + class ModelRunner: """ ModelRunner is a singleton class used for compilation and running of models. Ensuring only one can be run at a time. This is necessary because the adaptor class is reinitialized on every request from the frontend, so it cannot keep state. """ + # Global static runner state. Initialized once. _instance = None _explorer_artifacts_dir = None _build_dir = None - # State variables. + # Singleton runner state. Initialized on every run. runner_thread = None runner_error = None + log_queue = queue.Queue() # progress should be a number between 0 and 100. 
progress = 0 - log_queue = queue.Queue() - optimized_model_path = None - ttrt_output_dir = None + + # State for models that have been executed. + # Contains a mapping from model path to ModelState. + model_state = dict() def __new__(cls, *args, **kwargs): if not cls._instance: @@ -70,11 +86,18 @@ def initialize(self): print("ModelRunner initialized.") - def get_optimized_model_path(self): - return self.optimized_model_path + def get_optimized_model_path(self, model_path): + if model_path in self.model_state: + return self.model_state[model_path].optimized_model_path + return None - def get_output_dir(self): - return self.ttrt_output_dir + def get_output_dir(self, model_path): + return self.model_state[model_path].model_output_dir + + def get_overrides(self, model_path): + if model_path in self.model_state: + return self.model_state[model_path].overrides + return None def get_error(self): return self.runner_error @@ -82,6 +105,9 @@ def get_error(self): def get_progress(self): return self.progress + def get_artifacts_dir(self): + return self._explorer_artifacts_dir + def is_busy(self): return self.runner_thread and self.runner_thread.is_alive() @@ -91,28 +117,31 @@ def get_logs(self): logs.append(self.log_queue.get()) return "\n".join(logs) - def reset_state(self): + def reset_state(self, model_path): assert not self.is_busy() self.runner_thread = None - self.log_queue.queue.clear() - self.optimized_model_path = None self.runner_error = None self.progress = 0 - self.ttrt_output_dir = None + self.log_queue.queue.clear() + + if model_path in self.model_state: + del self.model_state[model_path] def log(self, message): print(message) self.log_queue.put(message) - def get_perf_trace(self): - op_perf_file = f"{self.ttrt_output_dir}/perf/ops_perf_results.csv" + def get_perf_trace(self, model_path): + op_perf_file = ( + f"{self.model_state[model_path].model_output_dir}/perf/ops_perf_results.csv" + ) if not os.path.exists(op_perf_file): raise FileNotFoundError(f"Performance file {op_perf_file} not found.") return pd.read_csv(op_perf_file) def run_in_subprocess(self, command): - self.log(f"Running command:\n{''.join(command)}\n") + self.log(f"Running command:\n{' '.join(command)}\n") process = subprocess.Popen( command, @@ -145,22 +174,24 @@ def compile_and_run_wrapper(self, model_path, overrides_string): def compile_and_run(self, model_path, overrides_string): model_name = os.path.basename(model_path) flatbuffer_file = model_name + ".ttnn" - self.ttrt_output_dir = self._explorer_artifacts_dir + "/" + flatbuffer_file + state = self.model_state[model_path] + + state.model_output_dir = self._explorer_artifacts_dir + "/" + flatbuffer_file - if os.path.exists(self.ttrt_output_dir): + if os.path.exists(state.model_output_dir): self.log("Removing artifacts of previous run.") - os.system(f"rm -rf {self.ttrt_output_dir}") + os.system(f"rm -rf {state.model_output_dir}") - os.makedirs(self.ttrt_output_dir) + os.makedirs(state.model_output_dir) # Copy the model to the run directory. 
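Because the runner stays a singleton but now keeps one ModelState per model, every accessor takes the model path as its key. A brief consumption sketch, assuming a built explorer environment (the model path and options string are hypothetical):

    from tt_adapter.runner import ModelRunner

    runner = ModelRunner()          # __new__ always returns the same instance
    runner.run("/tmp/mnist_sharding.mlir", "enable-optimizer=true", overrides=None)

    # Once the background thread finishes, per-model results are looked up by path:
    optimized = runner.get_optimized_model_path("/tmp/mnist_sharding.mlir")  # None if never run
    perf = runner.get_perf_trace("/tmp/mnist_sharding.mlir")                 # DataFrame from ops_perf_results.csv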
- os.system(f"cp {model_path} {self.ttrt_output_dir}") + os.system(f"cp {model_path} {state.model_output_dir}") self.progress = 10 ############################### Compile ################################## ttnn_ir_file = ( - f"{self.ttrt_output_dir}/{model_name.replace('.mlir', '_ttnn.mlir')}" + f"{state.model_output_dir}/{model_name.replace('.mlir', '_ttnn.mlir')}" ) compile_command = [ f"{self._build_dir}/bin/ttmlir-opt", @@ -215,7 +246,7 @@ def compile_and_run(self, model_path, overrides_string): self.log(error) raise ExplorerRunException(error) - perf = self.get_perf_trace() + perf = self.get_perf_trace(model_path) columns = [ "GLOBAL CALL COUNT", "OP CODE", @@ -229,32 +260,21 @@ def compile_and_run(self, model_path, overrides_string): print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns") - self.optimized_model_path = ttnn_ir_file + state.optimized_model_path = ttnn_ir_file self.progress = 100 - def run( - self, model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy - ): + def run(self, model_path, compile_options, overrides): # Check if a run is already in progress if self.is_busy(): raise RuntimeError( "A model is already being processed. Please wait for it to finish." ) - self.reset_state() - - options = [ - f'system-desc-path={f"{self._explorer_artifacts_dir}/system_desc.ttsys"}', - "enable-optimizer=true", - f"memory-layout-analysis-enabled={memory_layout_analysis_enabled}", - ] - if memory_layout_analysis_policy: - options.append( - f"memory-layout-analysis-policy={memory_layout_analysis_policy}" - ) - options_string = " ".join(options) + self.reset_state(model_path) + self.model_state[model_path] = ModelState() + self.model_state[model_path].overrides = overrides # Start compile and run in a new thread self.runner_thread = threading.Thread( - target=self.compile_and_run_wrapper, args=(model_path, options_string) + target=self.compile_and_run_wrapper, args=(model_path, compile_options) ) self.runner_thread.start()