Skip to content

Commit

Permalink
Add override handler
Browse files Browse the repository at this point in the history
  • Loading branch information
odjuricicTT committed Dec 16, 2024
1 parent 1c1cb18 commit e959b38
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 38 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy(
}
void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy(
MemoryLayoutAnalysisPolicyType value) {
enableMemoryLayoutAnalysisPolicy = true;
memoryLayoutAnalysisPolicy = value;
}

Expand Down
41 changes: 21 additions & 20 deletions tools/explorer/tt_adapter/src/tt_adapter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from . import runner, utils, mlir
import dataclasses
import enum
from ttmlir import optimizer_overrides

OPTIMIZER_DISABLED_POLICY = "Optimizer Disabled"

class OptimizationPolicy(enum.Enum):
DFSharding = "DF Sharding"
L1Interleaved = "L1 Interleaved"
OptimizerDisabled = "Optimizer Disabled"


OPTIMIZATION_POLICIES = [member.value for member in OptimizationPolicy]
OPTIMIZATION_POLICIES = {
"DF Sharding": optimizer_overrides.MemoryLayoutAnalysisPolicyType.DFSharding,
"L1 Interleaved": optimizer_overrides.MemoryLayoutAnalysisPolicyType.L1Interleaved,
OPTIMIZER_DISABLED_POLICY: False,
}


@dataclasses.dataclass
Expand All @@ -30,7 +30,7 @@ class TTAdapter(model_explorer.Adapter):
source_repo="https://github.com/tenstorrent/tt-mlir/tree/main/tools/explorer/tt_adapter",
fileExts=["mlir", "ttir"],
settings={
"optimizationPolicies": OPTIMIZATION_POLICIES,
"optimizationPolicies": list(OPTIMIZATION_POLICIES.keys()),
},
)
model_runner = None
Expand Down Expand Up @@ -60,27 +60,28 @@ def convert(
def execute(
self, model_path: str, settings: Dict
) -> model_explorer.ModelExplorerGraphs:
# TODO(odjuricic, #1178) settings need to be parsed.
# Waiting on override class for this.
override_handler = optimizer_overrides.OptimizerOverridesHandler()
override_handler.set_system_desc_path(
f"{self.model_runner.get_artifacts_dir()}/system_desc.ttsys"
)

# Parse optimization policy from settings.
optimization_policy = settings.get("optimizationPolicy")
if optimization_policy not in OPTIMIZATION_POLICIES:
raise ValueError(
f"Invalid optimization policy selected: {optimization_policy}"
)
optimization_policy = OptimizationPolicy(optimization_policy)

memory_layout_analysis_enabled = True
memory_layout_analysis_policy = optimization_policy.name

if optimization_policy == OptimizationPolicy.OptimizerDisabled:
memory_layout_analysis_enabled = False
memory_layout_analysis_policy = None
if optimization_policy == OPTIMIZER_DISABLED_POLICY:
override_handler.set_enable_optimizer(False)
else:
override_handler.set_enable_optimizer(True)
override_handler.set_enable_memory_layout_analysis(True)
override_handler.set_memory_layout_analysis_policy(
OPTIMIZATION_POLICIES[optimization_policy]
)

self.model_runner.run(
model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy
)
self.model_runner.run(model_path, override_handler.to_string())

return {"graphs": []}

Expand Down
53 changes: 50 additions & 3 deletions tools/explorer/tt_adapter/src/tt_adapter/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@
import re
from collections import defaultdict
from model_explorer import graph_builder, node_data_builder
import dataclasses

from ttmlir.dialects import tt, ttnn, ttir
from ttmlir import ir


def make_editable_kv(kv, editable):
obj = dataclasses.asdict(kv)
obj["editable"] = editable
return dataclasses.make_dataclass(
"KeyValue", ((k, type(v)) for k, v in obj.items())
)(**obj)


def get_loc_str(loc):
try:
# Constant loc( at the start of the location and ) at the end. Can just strip these characters
Expand Down Expand Up @@ -404,9 +413,15 @@ def parse_ttnn_ttnn_layout(attr):
memory_layout = layout.memory_layout_as_int
if memory_layout is not None:
result.append(
graph_builder.KeyValue(
key="memory_layout",
value=str(ttnn.TensorMemoryLayout(memory_layout)),
make_editable_kv(
graph_builder.KeyValue(
key="Tensor Memory Layout",
value=str(ttnn.TensorMemoryLayout(memory_layout)),
),
editable={
"input_type": "value_list",
"options": [str(o) for o in ttnn.TensorMemoryLayout],
},
)
)
result.append(
Expand Down Expand Up @@ -457,6 +472,38 @@ def get_attributes(self):
result = []
for attr in self.op.attributes:
result.extend(AttrHandler.parse_attr(attr))

# Add output tensor properties to the op itself
if self.op.results:
output_tensor = self.op.result
output_attrs = []
if isinstance(output_tensor.type, ir.RankedTensorType):
output_attrs = [
graph_builder.KeyValue(
key="shape", value=str(output_tensor.type.shape)
),
graph_builder.KeyValue(
key="dtype", value=str(output_tensor.type.element_type)
),
graph_builder.KeyValue(
key="rank", value=str(output_tensor.type.rank)
),
]
if hasattr(output_tensor.type, "encoding") and output_tensor.type.encoding:
if "ttnn_layout" in str(output_tensor.type.encoding):
output_attrs.extend(
AttrHandler.parse_attr(
output_tensor.type.encoding.get_named("ttnn_layout")
)
)
else:
# Parse as a standard layout
output_attrs.extend(
AttrHandler.parse_attr(
output_tensor.type.encoding.get_named("tt.layout")
)
)
result.extend(output_attrs)
return result

def make_graph_node(self):
Expand Down
20 changes: 5 additions & 15 deletions tools/explorer/tt_adapter/src/tt_adapter/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def get_error(self):
def get_progress(self):
return self.progress

def get_artifacts_dir(self):
return self._explorer_artifacts_dir

def is_busy(self):
return self.runner_thread and self.runner_thread.is_alive()

Expand Down Expand Up @@ -232,29 +235,16 @@ def compile_and_run(self, model_path, overrides_string):
self.optimized_model_path = ttnn_ir_file
self.progress = 100

def run(
self, model_path, memory_layout_analysis_enabled, memory_layout_analysis_policy
):
def run(self, model_path, compile_options):
# Check if a run is already in progress
if self.is_busy():
raise RuntimeError(
"A model is already being processed. Please wait for it to finish."
)
self.reset_state()

options = [
f'system-desc-path={f"{self._explorer_artifacts_dir}/system_desc.ttsys"}',
"enable-optimizer=true",
f"memory-layout-analysis-enabled={memory_layout_analysis_enabled}",
]
if memory_layout_analysis_policy:
options.append(
f"memory-layout-analysis-policy={memory_layout_analysis_policy}"
)
options_string = " ".join(options)

# Start compile and run in a new thread
self.runner_thread = threading.Thread(
target=self.compile_and_run_wrapper, args=(model_path, options_string)
target=self.compile_and_run_wrapper, args=(model_path, compile_options)
)
self.runner_thread.start()

0 comments on commit e959b38

Please sign in to comment.