Skip to content

Commit

Permalink
Move explorer execution to background (#1504)
Browse files Browse the repository at this point in the history
* Use subprocesses for compile and to_flatbuffer in order for asserts to not kill the server
* Setup high level error messages for execution steps
* Send incremental logs to the frontend
  • Loading branch information
odjuricicTT authored Dec 13, 2024
1 parent b0c0c2b commit 30a7a9e
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 139 deletions.
4 changes: 2 additions & 2 deletions tools/explorer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ include(ExternalProject)
set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py)
set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin)

set(MODEL_EXPLORER_VERSION "95d79ec933643c3b145537ce6fd5d9f0e735683d")
set(MODEL_EXPLORER_VERSION "d0b53c3b7049fd41ea1caff193706272c399fac9")
ExternalProject_Add(
model-explorer
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer
Expand All @@ -20,7 +20,7 @@ add_custom_target(explorer
COMMAND pip install $<$<CONFIG:Debug>:-e> ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter
COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer/src/model-explorer/src/server/package

DEPENDS TTMLIRPythonModules model-explorer ttrt
DEPENDS TTMLIRPythonModules model-explorer ttrt ttmlir-opt ttmlir-translate
)

add_custom_command(TARGET explorer POST_BUILD
Expand Down
68 changes: 49 additions & 19 deletions tools/explorer/test/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,53 @@ def get_test_files(paths):
return files


def execute_command(model_path, settings):
def send_command(command, model_path, settings):
cmd = {
"extensionId": "tt_adapter",
"cmdId": "execute",
"cmdId": command,
"modelPath": model_path,
"deleteAfterConversion": False,
"settings": settings,
}

result = requests.post(COMMAND_URL, json=cmd)
return requests.post(COMMAND_URL, json=cmd, timeout=10)


def execute_command(model_path, settings):
    """Issue an 'execute' command for *model_path* and fail the test on any error."""
    response = send_command("execute", model_path, settings)
    assert response.ok
    payload = response.json()
    if "error" in payload:
        # Print the server-side error payload so the failure is diagnosable.
        print(payload)
        assert False


def wait_for_execution_to_finish(timeout):
    """Poll the explorer server once per second until execution completes.

    Args:
        timeout: Maximum number of one-second polls before giving up.

    Returns:
        The decoded JSON status payload once the first graph reports ``isDone``.

    Raises:
        Exception: If a status_check request fails at the transport level.
        RuntimeError: If execution does not finish within ``timeout`` seconds.
    """
    for _ in range(timeout):
        try:
            response = send_command("status_check", "", {})
            if response.status_code == 200:
                status = response.json()
                # Guard against a missing or empty "graphs" list instead of
                # crashing with TypeError/IndexError on the subscript.
                graphs = status.get("graphs") or []
                if graphs and graphs[0].get("isDone"):
                    return status
        except requests.RequestException as e:
            print(f"Request failed: {e}")
            # Chain the original exception so the transport error stays visible.
            raise Exception("Status check request failed") from e
        time.sleep(1)
    # Report the timeout actually passed in, not the global default.
    raise RuntimeError(f"Execution did not finish within {timeout} seconds")


def execute_command_and_wait(model_path, settings, timeout):
    """Kick off execution of *model_path* and block until it finishes cleanly."""
    execute_command(model_path, settings)
    status = wait_for_execution_to_finish(timeout)
    assert "graphs" in status
    graphs = status["graphs"]
    assert len(graphs) == 1
    graph_status = graphs[0]
    # Execution must have completed without a reported error.
    assert graph_status["isDone"]
    assert graph_status["error"] is None


@pytest.fixture(scope="function", autouse=True)
def start_server(request):
server_thread = multiprocessing.Process(
Expand All @@ -53,10 +84,11 @@ def start_server(request):
server_thread.start()

# Wait for the server to start
for _ in range(100): # Try for up to 10 seconds
for _ in range(200): # Try for up to 20 seconds
try:
response = requests.get(f"http://{HOST}:{PORT}/check_health")
response = requests.get(f"http://{HOST}:{PORT}/check_health", timeout=1)
if response.status_code == 200:
print("Explorer server started")
break
except requests.ConnectionError:
pass
Expand All @@ -75,15 +107,7 @@ def server_shutdown():

@pytest.mark.parametrize("model_path", get_test_files(TEST_LOAD_MODEL_PATHS))
def test_load_model(model_path):
cmd = {
"extensionId": "tt_adapter",
"cmdId": "convert",
"modelPath": model_path,
"deleteAfterConversion": False,
"settings": {},
}

result = requests.post(COMMAND_URL, json=cmd)
result = send_command("convert", model_path, {})
assert result.ok
if "error" in result.json():
print(result.json())
Expand All @@ -92,25 +116,31 @@ def test_load_model(model_path):

@pytest.mark.parametrize("model_path", get_test_files(TEST_EXECUTE_MODEL_PATHS))
def test_execute_model(model_path):
execute_command(model_path, {"optimizationPolicy": "DF Sharding"})
execute_command_and_wait(
model_path, {"optimizationPolicy": "DF Sharding"}, timeout=60
)


def test_execute_mnist_l1_interleaved():
execute_command(
execute_command_and_wait(
"test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
{"optimizationPolicy": "L1 Interleaved"},
timeout=60,
)


def test_execute_mnist_optimizer_disabled():
execute_command(
execute_command_and_wait(
"test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir",
{"optimizationPolicy": "Optimizer Disabled"},
timeout=60,
)


def test_execute_model_invalid_policy():
with pytest.raises(AssertionError):
execute_command(
TEST_EXECUTE_MODEL_PATHS[0], {"optimizationPolicy": "Invalid Policy"}
execute_command_and_wait(
TEST_EXECUTE_MODEL_PATHS[0],
{"optimizationPolicy": "Invalid Policy"},
timeout=60,
)
33 changes: 28 additions & 5 deletions tools/explorer/tt_adapter/src/tt_adapter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,19 @@ def __init__(self):
def convert(
self, model_path: str, settings: Dict
) -> model_explorer.ModelExplorerGraphs:
perf_trace = None
if optimized_model_path := self.model_runner.get_optimized_model_path():
print(f"Using optimized model: {optimized_model_path}")
model_path = optimized_model_path

# Get performance results.
perf_trace = self.model_runner.get_perf_trace()

module = utils.parse_mlir_file(model_path)

# Convert TTIR to Model Explorer Graphs and Display/Return
graph = mlir.build_graph(module)
return {"graphs": [graph]}
graph, perf_data = mlir.build_graph(module, perf_trace)
return {"graphs": [graph], "perf_data": perf_data}

def execute(
self, model_path: str, settings: Dict
Expand All @@ -70,9 +78,24 @@ def execute(
memory_layout_analysis_enabled = False
memory_layout_analysis_policy = None

perf_data = self.model_runner.run(
self.model_runner.run(
model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy
)

# TODO(odjuricic, #933) Parse TTNN IR and return the post optimized graph.
return utils.to_adapter_format({"perf_data": perf_data})
return {"graphs": []}

def status_check(self, model_path: str, settings: Dict) -> Dict:
    """Report background-execution progress to the frontend.

    Args:
        model_path: Unused; present to match the adapter command signature.
        settings: Unused; present to match the adapter command signature.

    Returns:
        Adapter-formatted dict carrying the completion flag, progress
        percentage (out of a total of 100), accumulated runner logs, and
        any execution error.  (The previous ``-> bool`` annotation was
        wrong: the method returns a dict, not a boolean.)
    """
    # Execution is considered finished once the runner is no longer busy.
    done = not self.model_runner.is_busy()
    logs = self.model_runner.get_logs()
    progress = self.model_runner.get_progress()
    error = self.model_runner.get_error()

    return utils.to_adapter_format(
        {
            "isDone": done,
            "progress": progress,
            "total": 100,
            "error": error,
            "stdout": logs,
        }
    )
83 changes: 52 additions & 31 deletions tools/explorer/tt_adapter/src/tt_adapter/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0
# Utility library for parsing MLIR

import re
from collections import defaultdict
from model_explorer import graph_builder
from model_explorer import graph_builder, node_data_builder

from ttmlir.dialects import tt, ttnn, ttir
from ttmlir import ir
Expand All @@ -15,7 +15,10 @@ def get_loc_str(loc):
# Constant loc( at the start of the location and ) at the end. Can just strip these characters
loc = str(loc)
if loc.startswith("loc(") and loc.endswith(")"):
res = str(loc)[4:-1]
# Fuzzy parse first string inside location
# 'loc("matmul_1"("MNISTLinear":4294967295:10))' -> matmul_1
# TODO(odjuricic) Need to have this pybinded.
res = re.search(r'"([^"]+)"', loc).group(1)
else:
res = loc # This is a fallback to just visualize / see what the loc is if not processable.
except:
Expand Down Expand Up @@ -427,14 +430,19 @@ def parse_ttnn_ttnn_layout(attr):


class OpHandler:
# Help create unique ids for ops with the same location name.
name_dict = defaultdict(int)

def __init__(self, op):
self.op = op
self.location = get_loc_str(self.op.location)
self.id = self._create_unique_id()

def get_id(self, names: defaultdict):
name = get_loc_str(self.op.location)
name_num = names[name]
def _create_unique_id(self):
name = self.location
name_num = self.name_dict[name]
id = name + "__" + str(name_num)
names[name] += 1
self.name_dict[name] += 1
return id

def get_namespace(self, parent_op=None):
Expand All @@ -451,17 +459,17 @@ def get_attributes(self):
result.extend(AttrHandler.parse_attr(attr))
return result

def make_graph_node(self, name_dict):
def make_graph_node(self):
return graph_builder.GraphNode(
id=self.get_id(name_dict),
id=self.id,
label=self.op.name,
namespace=self.get_namespace(),
attrs=self.get_attributes(),
)

def make_constant_node(self, name_dict, constant_name):
def make_constant_node(self, constant_name):
return graph_builder.GraphNode(
id=self.get_id(name_dict),
id=self._create_unique_id(),
label=constant_name,
namespace=self.get_namespace(),
)
Expand All @@ -478,28 +486,21 @@ def make_constant_node(self, name_dict, constant_name):
]


def get_locs(module):
name_dict = defaultdict(int)

for op in module.body.operations:
for region in op.regions:
for block in region.blocks:
for op in block.operations:
op = OpHandler(op)
_id = op.get_id(name_dict)
# This will now populate name_dict with all of the locations that are relevant

# The keys will be all the unique locations, and the values will be the number of times that location appears
return name_dict


def build_graph(module):
name_dict = defaultdict(int)
def build_graph(module, perf_trace=None):
output_connections = defaultdict(int)
graph = graph_builder.Graph(id="tt-graph")

op_to_graph_node = {}

# Prepare perf data for color overlay
perf_node_data = {}
loc_to_perf = {}
if perf_trace is not None:
for _, row in perf_trace.iterrows():
loc = get_loc_str(row["LOC"])
assert loc not in loc_to_perf
loc_to_perf[loc] = row["DEVICE FW DURATION [ns]"]

module_op = OpHandler(module.operation)
module_attrs = module_op.get_attributes()
module_attrs = dict((attr.key, attr.value) for attr in module_attrs)
Expand All @@ -514,7 +515,12 @@ def build_graph(module):
for op in block.operations:
# Create all the nodes and constants in the first pass.
operation = OpHandler(op)
graph_node = operation.make_graph_node(name_dict)
graph_node = operation.make_graph_node()

if operation.location in loc_to_perf:
perf_node_data[operation.id] = node_data_builder.NodeDataResult(
loc_to_perf[operation.location]
)

if op.name in EMPTY_OPS:
append_later.append(graph_node)
Expand All @@ -531,7 +537,7 @@ def build_graph(module):

# This is a constant and we need to create a node for it.
operand_node = operation.make_constant_node(
name_dict, operand.get_name()
operand.get_name()
)
graph.nodes.append(operand_node)
op_to_graph_node[operand] = operand_node
Expand Down Expand Up @@ -599,5 +605,20 @@ def build_graph(module):
)
)
output_connections[source_node.id] += 1

# Add performance data to the graph color overlay, if it exists
overlay_data = None
if perf_node_data:
gradient = [
node_data_builder.GradientItem(stop=0, bgColor="yellow"),
node_data_builder.GradientItem(stop=1, bgColor="red"),
]
graph_node_data = node_data_builder.GraphNodeData(
results=perf_node_data, gradient=gradient
)
overlay_data = node_data_builder.ModelNodeData(
graphsData={"tt-graph": graph_node_data}
)

graph.groupNodeAttributes = group_node_attrs
return graph
return graph, overlay_data
Loading

0 comments on commit 30a7a9e

Please sign in to comment.