diff --git a/tools/explorer/CMakeLists.txt b/tools/explorer/CMakeLists.txt
index 44613b267..e0128691a 100644
--- a/tools/explorer/CMakeLists.txt
+++ b/tools/explorer/CMakeLists.txt
@@ -3,7 +3,7 @@ include(ExternalProject)
 set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py)
 set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin)
 
-set(MODEL_EXPLORER_VERSION "95d79ec933643c3b145537ce6fd5d9f0e735683d")
+set(MODEL_EXPLORER_VERSION "d0b53c3b7049fd41ea1caff193706272c399fac9")
 ExternalProject_Add(
   model-explorer
   PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer
@@ -20,7 +20,7 @@ add_custom_target(explorer
   COMMAND pip install $<$:-e> ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter
   COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer/src/model-explorer/src/server/package
-  DEPENDS TTMLIRPythonModules model-explorer ttrt
+  DEPENDS TTMLIRPythonModules model-explorer ttrt ttmlir-opt ttmlir-translate
 )
 
 add_custom_command(TARGET explorer POST_BUILD
diff --git a/tools/explorer/test/run_tests.py b/tools/explorer/test/run_tests.py
index 91167e86a..75925a44e 100644
--- a/tools/explorer/test/run_tests.py
+++ b/tools/explorer/test/run_tests.py
@@ -28,22 +28,53 @@ def get_test_files(paths):
     return files
 
 
-def execute_command(model_path, settings):
+def send_command(command, model_path, settings):
     cmd = {
         "extensionId": "tt_adapter",
-        "cmdId": "execute",
+        "cmdId": command,
         "modelPath": model_path,
         "deleteAfterConversion": False,
         "settings": settings,
     }
-    result = requests.post(COMMAND_URL, json=cmd)
+    return requests.post(COMMAND_URL, json=cmd, timeout=10)
+
+
+def execute_command(model_path, settings):
+    result = send_command("execute", model_path, settings)
     assert result.ok
     if "error" in result.json():
         print(result.json())
         assert False
 
 
+def wait_for_execution_to_finish(timeout):
+    for _ in range(timeout):
+        try:
+            response = send_command("status_check", "", {})
+            if response.status_code == 200 and response.json().get("graphs")[0].get(
+                "isDone"
+            ):
+                return response.json()
+        except requests.RequestException as e:
+            print(f"Request failed: {e}")
+            raise Exception("Status check request failed")
+        time.sleep(1)
+    raise RuntimeError(f"Execution did not finish within {timeout} seconds")
+
+
+def execute_command_and_wait(model_path, settings, timeout):
+    execute_command(model_path, settings)
+    adapter_response = wait_for_execution_to_finish(timeout)
+    assert "graphs" in adapter_response
+    assert len(adapter_response["graphs"]) == 1
+    response = adapter_response["graphs"][0]
+    assert response["isDone"]
+    assert response["error"] is None
+
+
 @pytest.fixture(scope="function", autouse=True)
 def start_server(request):
     server_thread = multiprocessing.Process(
@@ -53,10 +84,11 @@
     server_thread.start()
 
     # Wait for the server to start
-    for _ in range(100):  # Try for up to 10 seconds
+    for _ in range(200):  # Try for up to 20 seconds
         try:
-            response = requests.get(f"http://{HOST}:{PORT}/check_health")
+            response = requests.get(f"http://{HOST}:{PORT}/check_health", timeout=1)
             if response.status_code == 200:
+                print("Explorer server started")
                 break
         except requests.ConnectionError:
             pass
@@ -75,15 +107,7 @@ def server_shutdown():
 
 @pytest.mark.parametrize("model_path", get_test_files(TEST_LOAD_MODEL_PATHS))
 def test_load_model(model_path):
-    cmd = {
-        "extensionId": "tt_adapter",
-        "cmdId": "convert",
-        "modelPath": model_path,
-        "deleteAfterConversion": False,
-        "settings": {},
-    }
-
-    result = requests.post(COMMAND_URL, json=cmd)
+    result = send_command("convert", model_path, {})
     assert result.ok
     if "error" in result.json():
         print(result.json())
@@ -92,25 +116,31 @@
 
 @pytest.mark.parametrize("model_path", get_test_files(TEST_EXECUTE_MODEL_PATHS))
 def test_execute_model(model_path):
-    execute_command(model_path, {"optimizationPolicy": "DF Sharding"})
+    execute_command_and_wait(
+        model_path, {"optimizationPolicy": "DF Sharding"}, timeout=60
+    )
 
 
 def test_execute_mnist_l1_interleaved():
-    execute_command(
+    execute_command_and_wait(
         "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
         {"optimizationPolicy": "L1 Interleaved"},
+        timeout=60,
     )
 
 
 def test_execute_mnist_optimizer_disabled():
-    execute_command(
+    execute_command_and_wait(
         "test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
         {"optimizationPolicy": "Optimizer Disabled"},
+        timeout=60,
    )
 
 
 def test_execute_model_invalid_policy():
     with pytest.raises(AssertionError):
-        execute_command(
-            TEST_EXECUTE_MODEL_PATHS[0], {"optimizationPolicy": "Invalid Policy"}
+        execute_command_and_wait(
+            TEST_EXECUTE_MODEL_PATHS[0],
+            {"optimizationPolicy": "Invalid Policy"},
+            timeout=60,
         )
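For reference, the polling flow these tests codify can also be driven by hand against a running explorer server. A minimal sketch, assuming the server listens on localhost:8080 and that model-explorer's command endpoint is /apipost/v1/send_command (the URL, model path, and policy string are illustrative):

    import time
    import requests

    COMMAND_URL = "http://localhost:8080/apipost/v1/send_command"  # assumed endpoint

    def send_command(command, model_path="", settings=None):
        # Same request shape the tests use.
        cmd = {
            "extensionId": "tt_adapter",
            "cmdId": command,
            "modelPath": model_path,
            "deleteAfterConversion": False,
            "settings": settings or {},
        }
        return requests.post(COMMAND_URL, json=cmd, timeout=10)

    send_command("execute", "model.mlir", {"optimizationPolicy": "DF Sharding"})
    while True:
        status = send_command("status_check").json()["graphs"][0]
        print(f"progress: {status['progress']}/{status['total']}")
        if status["isDone"]:
            break
        time.sleep(1)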
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/main.py b/tools/explorer/tt_adapter/src/tt_adapter/main.py
index 3876a0911..53ea68669 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/main.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/main.py
@@ -43,11 +43,19 @@ def __init__(self):
     def convert(
         self, model_path: str, settings: Dict
     ) -> model_explorer.ModelExplorerGraphs:
+        perf_trace = None
+        if optimized_model_path := self.model_runner.get_optimized_model_path():
+            print(f"Using optimized model: {optimized_model_path}")
+            model_path = optimized_model_path
+
+            # Get performance results.
+            perf_trace = self.model_runner.get_perf_trace()
+
         module = utils.parse_mlir_file(model_path)
 
         # Convert TTIR to Model Explorer Graphs and Display/Return
-        graph = mlir.build_graph(module)
-        return {"graphs": [graph]}
+        graph, perf_data = mlir.build_graph(module, perf_trace)
+        return {"graphs": [graph], "perf_data": perf_data}
 
     def execute(
         self, model_path: str, settings: Dict
@@ -70,9 +78,24 @@
             memory_layout_analysis_enabled = False
             memory_layout_analysis_policy = None
 
-        perf_data = self.model_runner.run(
+        self.model_runner.run(
             model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy
         )
 
-        # TODO(odjuricic, #933) Parse TTNN IR and return the post optimized graph.
-        return utils.to_adapter_format({"perf_data": perf_data})
+        return {"graphs": []}
+
+    def status_check(self, model_path: str, settings: Dict) -> Dict:
+        done = not self.model_runner.is_busy()
+        logs = self.model_runner.get_logs()
+        progress = self.model_runner.get_progress()
+        error = self.model_runner.get_error()
+
+        return utils.to_adapter_format(
+            {
+                "isDone": done,
+                "progress": progress,
+                "total": 100,
+                "error": error,
+                "stdout": logs,
+            }
+        )
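Taken together with the runner changes below, execute() now only kicks off a run and returns an empty graph list; the client is expected to poll status_check() and then re-issue convert(), which picks up the optimized IR and the perf trace. A sketch of that data flow at the adapter level (object construction and the example path are illustrative):

    # After a successful execute(), a follow-up convert() re-renders from the
    # optimized TTNN IR and attaches the perf overlay.
    runner = ModelRunner()
    assert not runner.is_busy() and runner.get_error() is None
    optimized = runner.get_optimized_model_path()  # e.g. a *_ttnn.mlir under the run dir
    trace = runner.get_perf_trace()                # DataFrame from ops_perf_results.csv
    graph, overlay = mlir.build_graph(utils.parse_mlir_file(optimized), trace)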
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
index 5dc67de70..843606b06 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/mlir.py
@@ -2,9 +2,9 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 # Utility library for parsing MLIR
-
+import re
 from collections import defaultdict
-from model_explorer import graph_builder
+from model_explorer import graph_builder, node_data_builder
 
 from ttmlir.dialects import tt, ttnn, ttir
 from ttmlir import ir
@@ -15,7 +15,10 @@ def get_loc_str(loc):
         # Constant loc( at the start of the location and ) at the end. Can just strip these characters.
         loc = str(loc)
         if loc.startswith("loc(") and loc.endswith(")"):
-            res = str(loc)[4:-1]
+            # Fuzzy parse first string inside location
+            # 'loc("matmul_1"("MNISTLinear":4294967295:10))' -> matmul_1
+            # TODO(odjuricic) Need to have this pybinded.
+            res = re.search(r'"([^"]+)"', loc).group(1)
         else:
             res = loc  # This is a fallback to just visualize / see what the loc is if not processable.
     except:
@@ -427,14 +430,19 @@ def parse_ttnn_ttnn_layout(attr):
 
 class OpHandler:
+    # Help create unique ids for ops with the same location name.
+    name_dict = defaultdict(int)
+
     def __init__(self, op):
         self.op = op
+        self.location = get_loc_str(self.op.location)
+        self.id = self._create_unique_id()
 
-    def get_id(self, names: defaultdict):
-        name = get_loc_str(self.op.location)
-        name_num = names[name]
+    def _create_unique_id(self):
+        name = self.location
+        name_num = self.name_dict[name]
         id = name + "__" + str(name_num)
-        names[name] += 1
+        self.name_dict[name] += 1
         return id
 
     def get_namespace(self, parent_op=None):
@@ -451,17 +459,17 @@ def get_attributes(self):
             result.extend(AttrHandler.parse_attr(attr))
         return result
 
-    def make_graph_node(self, name_dict):
+    def make_graph_node(self):
         return graph_builder.GraphNode(
-            id=self.get_id(name_dict),
+            id=self.id,
             label=self.op.name,
             namespace=self.get_namespace(),
             attrs=self.get_attributes(),
         )
 
-    def make_constant_node(self, name_dict, constant_name):
+    def make_constant_node(self, constant_name):
         return graph_builder.GraphNode(
-            id=self.get_id(name_dict),
+            id=self._create_unique_id(),
             label=constant_name,
             namespace=self.get_namespace(),
         )
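Because name_dict is now a class attribute, the counters persist for the lifetime of the process rather than per build_graph call. A standalone sketch of the resulting id scheme (same logic, not the adapter's code):

    from collections import defaultdict

    name_dict = defaultdict(int)

    def unique_id(name):
        # Mirrors OpHandler._create_unique_id: suffix repeats of a location name.
        n = name_dict[name]
        name_dict[name] += 1
        return f"{name}__{n}"

    assert unique_id("matmul_1") == "matmul_1__0"
    assert unique_id("matmul_1") == "matmul_1__1"  # same loc name, distinct node id

One consequence: a second conversion of the same module mints fresh suffixes (matmul_1__2, ...), which is harmless for a single render but matters if ids ever need to be stable across conversions.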
@@ -478,28 +486,21 @@
 ]
 
 
-def get_locs(module):
-    name_dict = defaultdict(int)
-
-    for op in module.body.operations:
-        for region in op.regions:
-            for block in region.blocks:
-                for op in block.operations:
-                    op = OpHandler(op)
-                    _id = op.get_id(name_dict)
-    # This will now populate name_dict with all of the locations that are relevant -
-    # The keys will be all the unique locations, and the values will be the number of times that location appears
-    return name_dict
-
-
-def build_graph(module):
-    name_dict = defaultdict(int)
+def build_graph(module, perf_trace=None):
     output_connections = defaultdict(int)
     graph = graph_builder.Graph(id="tt-graph")
 
     op_to_graph_node = {}
 
+    # Prepare perf data for color overlay
+    perf_node_data = {}
+    loc_to_perf = {}
+    if perf_trace is not None:
+        for _, row in perf_trace.iterrows():
+            loc = get_loc_str(row["LOC"])
+            assert loc not in loc_to_perf
+            loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"]
+
     module_op = OpHandler(module.operation)
     module_attrs = module_op.get_attributes()
     module_attrs = dict((attr.key, attr.value) for attr in module_attrs)
@@ -514,7 +515,12 @@
                 for op in block.operations:
                     # Create all the nodes and constants in the first pass.
                     operation = OpHandler(op)
-                    graph_node = operation.make_graph_node(name_dict)
+                    graph_node = operation.make_graph_node()
+
+                    if operation.location in loc_to_perf:
+                        perf_node_data[operation.id] = node_data_builder.NodeDataResult(
+                            loc_to_perf[operation.location]
+                        )
 
                     if op.name in EMPTY_OPS:
                         append_later.append(graph_node)
@@ -531,7 +537,7 @@
                             # This is a constant and we need to create a node for it.
                             operand_node = operation.make_constant_node(
-                                name_dict, operand.get_name()
+                                operand.get_name()
                             )
                             graph.nodes.append(operand_node)
                             op_to_graph_node[operand] = operand_node
@@ -599,5 +605,20 @@
                     )
                 )
                 output_connections[source_node.id] += 1
+
+    # Add performance data to the graph color overlay, if it exists
+    overlay_data = None
+    if perf_node_data:
+        gradient = [
+            node_data_builder.GradientItem(stop=0, bgColor="yellow"),
+            node_data_builder.GradientItem(stop=1, bgColor="red"),
+        ]
+        graph_node_data = node_data_builder.GraphNodeData(
+            results=perf_node_data, gradient=gradient
+        )
+        overlay_data = node_data_builder.ModelNodeData(
+            graphsData={"tt-graph": graph_node_data}
+        )
+
     graph.groupNodeAttributes = group_node_attrs
-    return graph
+    return graph, overlay_data
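With concrete values, the overlay returned alongside the graph looks like the following (durations are made up; model-explorer is assumed to normalize the result values onto the yellow-to-red gradient):

    from model_explorer import node_data_builder

    perf_node_data = {
        # node id -> device FW duration in ns, keyed the same way as OpHandler ids
        "matmul_1__0": node_data_builder.NodeDataResult(8200),
        "add_2__0": node_data_builder.NodeDataResult(950),
    }
    gradient = [
        node_data_builder.GradientItem(stop=0, bgColor="yellow"),
        node_data_builder.GradientItem(stop=1, bgColor="red"),
    ]
    overlay_data = node_data_builder.ModelNodeData(
        graphsData={
            "tt-graph": node_data_builder.GraphNodeData(
                results=perf_node_data, gradient=gradient
            )
        }
    )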
diff --git a/tools/explorer/tt_adapter/src/tt_adapter/runner.py b/tools/explorer/tt_adapter/src/tt_adapter/runner.py
index 205944acd..b781ec38d 100644
--- a/tools/explorer/tt_adapter/src/tt_adapter/runner.py
+++ b/tools/explorer/tt_adapter/src/tt_adapter/runner.py
@@ -11,7 +11,12 @@
 import ttmlir.passes
 from . import utils, mlir
 import pandas as pd
-from model_explorer import node_data_builder
+import threading
+import queue
+
+
+class ExplorerRunException(Exception):
+    pass
 
 
 class ModelRunner:
@@ -22,6 +27,16 @@ class ModelRunner:
 
     _instance = None
     _explorer_artifacts_dir = None
+    _build_dir = None
+
+    # State variables.
+    runner_thread = None
+    runner_error = None
+    # progress should be a number between 0 and 100.
+    progress = 0
+    log_queue = queue.Queue()
+    optimized_model_path = None
+    ttrt_output_dir = None
 
     def __new__(cls, *args, **kwargs):
         if not cls._instance:
@@ -41,6 +56,7 @@ def initialize(self):
         # TODO(odjuricic, #1200) ttrt perf breaks if artifacts dir is changed from default.
         # self._explorer_artifacts_dir = os.environ['TT_MLIR_HOME'] + '/explorer-artifacts'
         self._explorer_artifacts_dir = os.environ["TT_MLIR_HOME"] + "/ttrt-artifacts"
+        self._build_dir = os.environ["TT_MLIR_HOME"] + "/build"
         os.makedirs(self._explorer_artifacts_dir, exist_ok=True)
 
         # Save the system descriptor.
@@ -48,87 +64,158 @@
             args={
                 "--save-artifacts": True,
                 "--artifact-dir": self._explorer_artifacts_dir,
+                "--quiet": True,
             }
         )()
 
-    def run(
-        self, model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy
-    ):
-        # TODO(odjuricic, #1174) This should be in a separete thread later.
-        model_name = os.path.basename(model_path).split(".")[0]
+        print("ModelRunner initialized.")
 
-        options = [
-            f'system-desc-path={f"{self._explorer_artifacts_dir}/system_desc.ttsys"}',
-            "enable-optimizer=true",
-            f"memory-layout-analysis-enabled={memory_layout_analysis_enabled}",
-        ]
-        if memory_layout_analysis_policy:
-            options.append(
-                f"memory-layout-analysis-policy={memory_layout_analysis_policy}"
-            )
-        options_string = " ".join(options)
+    def get_optimized_model_path(self):
+        return self.optimized_model_path
 
-        module = utils.parse_mlir_file(model_path)
+    def get_output_dir(self):
+        return self.ttrt_output_dir
 
-        # Collect unique locations
-        name_dict = mlir.get_locs(module)
+    def get_error(self):
+        return self.runner_error
 
-        try:
-            print("Running MLIR compile: TTIR to TTNN Backend Pipeline")
-            print("With options: ", options_string)
-            # TODO(odjuricic) When we hit compiler assert it terminates the process. We should catch this and return an error to the frontend.
-            ttmlir.passes.ttir_to_ttnn_backend_pipeline(module, options_string)
-        except Exception as e:
-            print("Error running MLIR compile: TTIR to TTNN Backend Pipeline")
-            raise e
+    def get_progress(self):
+        return self.progress
 
-        # TODO(odjuricic) Move this file somewhere else, but keep the name.
-        flatbuffer_file = model_name + ".ttnn"
-        try:
-            print("Running TTNN to Flatbuffer File")
-            ttmlir.passes.ttnn_to_flatbuffer_file(module, flatbuffer_file, {})
-        except Exception as e:
-            print("Error running TTNN to Flatbuffer File")
-            raise e
+    def is_busy(self):
+        return self.runner_thread and self.runner_thread.is_alive()
 
-        # TODO(odjuricic) validate that the module was converted to TTNN without fail
+    def get_logs(self):
+        logs = []
+        while not self.log_queue.empty():
+            logs.append(self.log_queue.get())
+        return "\n".join(logs)
 
-        if os.path.exists(f"{self._explorer_artifacts_dir}/{flatbuffer_file}"):
-            print("Removing artifacts of previous run.")
-            os.system(f"rm -rf {self._explorer_artifacts_dir}/{flatbuffer_file}")
+    def reset_state(self):
+        assert not self.is_busy()
+        self.runner_thread = None
+        self.log_queue.queue.clear()
+        self.optimized_model_path = None
+        self.runner_error = None
+        self.progress = 0
+        self.ttrt_output_dir = None
 
-        ttrt_perf_command = " ".join(
-            [
-                "ttrt",
-                "perf",
-                flatbuffer_file,
-                f"--artifact-dir={self._explorer_artifacts_dir}",
-            ]
-        )
+    def log(self, message):
+        print(message)
+        self.log_queue.put(message)
+
+    def get_perf_trace(self):
+        op_perf_file = f"{self.ttrt_output_dir}/perf/ops_perf_results.csv"
+        if not os.path.exists(op_perf_file):
+            raise FileNotFoundError(f"Performance file {op_perf_file} not found.")
+
+        return pd.read_csv(op_perf_file)
+
+    def run_in_subprocess(self, command):
+        self.log(f"Running command:\n{' '.join(command)}\n")
 
-        print("Running", ttrt_perf_command)
         process = subprocess.Popen(
-            ttrt_perf_command,
-            shell=True,
+            command,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
             text=True,
         )
 
         for line in process.stdout:
-            print(line, end="")
+            self.log(line.strip())
         process.stdout.close()
         process.wait()
 
-        if process.returncode != 0:
-            print(f"Error: TTRT process exited with code {process.returncode}")
-            raise RuntimeError("Error running TTRT")
+        return process
+ os.system(f"cp {model_path} {self.ttrt_output_dir}") + + self.progress = 10 + + ############################### Compile ################################## + + ttnn_ir_file = ( + f"{self.ttrt_output_dir}/{model_name.replace('.mlir', '_ttnn.mlir')}" + ) + compile_command = [ + f"{self._build_dir}/bin/ttmlir-opt", + f"--ttir-to-ttnn-backend-pipeline={overrides_string}", + model_path, + "-o", + ttnn_ir_file, + "--mlir-print-debuginfo", + ] + + self.log("Running compile TTIR to TTNN Backend Pipeline") + self.log("With options: " + overrides_string) + + compile_process = self.run_in_subprocess(compile_command) + if compile_process.returncode != 0: + error = "Error running compile TTIR to TTNN Backend Pipeline" + self.log(error) + raise ExplorerRunException(error) + self.progress = 20 + + ############################## Translate ################################# + + to_flatbuffer_command = [ + f"{self._build_dir}/bin/ttmlir-translate", + "--ttnn-to-flatbuffer", + ttnn_ir_file, + "-o", + flatbuffer_file, + ] + + self.log("Running TTNN to Flatbuffer File") + translate_process = self.run_in_subprocess(to_flatbuffer_command) + if translate_process.returncode != 0: + error = "Error while running TTNN to Flatbuffer File" + self.log(error) + raise ExplorerRunException(error) + self.progress = 30 + + ############################## TTRT Perf ################################# + + ttrt_perf_command = [ + "ttrt", + "perf", + flatbuffer_file, + f"--artifact-dir={self._explorer_artifacts_dir}", + ] + + ttrt_process = self.run_in_subprocess(ttrt_perf_command) + + if ttrt_process.returncode != 0: + error = "Error while running TTRT perf" + self.log(error) + raise ExplorerRunException(error) + + perf = self.get_perf_trace() columns = [ "GLOBAL CALL COUNT", "OP CODE", @@ -140,29 +227,34 @@ def run( perf = perf[columns] print(perf) - print(f"Total device duration: {perf['DEVICE FW DURATION [ns]'].sum()}ns") - - # Create the node_data type here - timing_data = list(zip(perf["LOC"], perf["DEVICE FW DURATION [ns]"])) - results = {} - for loc, duration in timing_data: - loc = mlir.get_loc_str(loc).replace("'", '"') - if loc in name_dict: - for i in range(name_dict[loc]): - results[f"{loc}__{i}"] = node_data_builder.NodeDataResult( - value=duration - ) - else: - print( - f"Location {loc} not found in graph, ops data for this op was not reported." - ) - - gradient = [ - node_data_builder.GradientItem(stop=0, bgColor="yellow"), - node_data_builder.GradientItem(stop=1, bgColor="red"), - ] + print("Total device duration: ", perf["DEVICE FW DURATION [ns]"].sum(), "ns") + + self.optimized_model_path = ttnn_ir_file + self.progress = 100 - data = node_data_builder.GraphNodeData(results=results, gradient=gradient) + def run( + self, model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy + ): + # Check if a run is already in progress + if self.is_busy(): + raise RuntimeError( + "A model is already being processed. Please wait for it to finish." 
+            )
+        self.reset_state()
+
+        options = [
+            f'system-desc-path={f"{self._explorer_artifacts_dir}/system_desc.ttsys"}',
+            "enable-optimizer=true",
+            f"memory-layout-analysis-enabled={memory_layout_analysis_enabled}",
+        ]
+        if memory_layout_analysis_policy:
+            options.append(
+                f"memory-layout-analysis-policy={memory_layout_analysis_policy}"
+            )
+        options_string = " ".join(options)
 
-        res = node_data_builder.ModelNodeData(graphsData={"tt-graph": data})
-        return res
+        # Start compile and run in a new thread
+        self.runner_thread = threading.Thread(
+            target=self.compile_and_run_wrapper, args=(model_path, options_string)
+        )
+        self.runner_thread.start()
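End to end, the runner's new contract looks like this from a caller's point of view. A hedged sketch; the model path and policy string are illustrative, and the singleton is assumed to self-initialize on first construction:

    import time
    from tt_adapter.runner import ModelRunner

    runner = ModelRunner()
    # Returns immediately; compile, translate, and ttrt perf happen on runner_thread.
    runner.run("model.mlir", True, "DFSharding")
    while runner.is_busy():
        time.sleep(1)
        print(f"{runner.get_progress()}%", runner.get_logs())
    if runner.get_error() is None:
        print("Optimized TTNN IR:", runner.get_optimized_model_path())
        print(runner.get_perf_trace().head())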