Skip to content

Commit

Permalink
Add support for overrides in explorer
Browse files Browse the repository at this point in the history
  • Loading branch information
odjuricicTT committed Dec 27, 2024
1 parent cf80a1a commit d9a8568
Show file tree
Hide file tree
Showing 13 changed files with 468 additions and 201 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@
/test/ttmlir/Silicon/TTNN/optimizer/ @nobradovictt @odjuricicTT
/test/unittests/Optimizer @nobradovictt @odjuricicTT
/tools/ @svuckovicTT @mtopalovicTT
/tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt
/tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt @vcanicTT
5 changes: 2 additions & 3 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ class OptimizerOverridesHandler {

// Wrapper methods we use to expose the adders to the python bindings
void addInputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &);
void addOutputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &,
BufferType, TensorMemoryLayout,
tt::ttnn::Layout, tt::DataType);
void addOutputLayoutOverridePybindWrapper(std::string,
OutputLayoutOverrideParams);

private:
// Flags for enabling/disabling the optimizer passes
Expand Down
17 changes: 7 additions & 10 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H

#include <string_view>

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

namespace mlir::tt::ttnn {

Expand All @@ -31,12 +28,12 @@ struct OptionNames {
};

struct OutputLayoutOverrideParams {
std::optional<SmallVector<int64_t, 2>> grid;
std::optional<BufferType> bufferType;
std::optional<TensorMemoryLayout>
tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
std::optional<Layout> memoryLayout; // ROW_MAJOR / TILE
std::optional<tt::DataType> dataType;
std::optional<SmallVector<int64_t, 2>> grid = std::nullopt;
std::optional<BufferType> bufferType = std::nullopt;
std::optional<TensorMemoryLayout> tensorMemoryLayout =
std::nullopt; // INTERLEAVED / SHARDED etc...
std::optional<Layout> memoryLayout = std::nullopt; // ROW_MAJOR / TILE
std::optional<tt::DataType> dataType = std::nullopt;

// Check if all layout parameters that are generated in LegalLayoutAnalysis
// are overridden. DataType is the only one that is not.
Expand All @@ -45,7 +42,7 @@ struct OutputLayoutOverrideParams {
tensorMemoryLayout.has_value() && memoryLayout.has_value();
}

bool operator==(const OutputLayoutOverrideParams rhs) const {
bool operator==(const OutputLayoutOverrideParams &rhs) const {
if (grid.has_value() != rhs.grid.has_value()) {
return false;
}
Expand Down
9 changes: 3 additions & 6 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy(
}
// Stores the requested memory layout analysis policy and also sets
// enableMemoryLayoutAnalysisPolicy to true.
void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy(
    MemoryLayoutAnalysisPolicyType value) {
  memoryLayoutAnalysisPolicy = value;
  enableMemoryLayoutAnalysisPolicy = true;
}

Expand Down Expand Up @@ -198,13 +199,9 @@ void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper(
}

void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper(
std::string opName, std::vector<int64_t> &grid, BufferType bufferType,
TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout,
tt::DataType dataType) {
std::string opName, OutputLayoutOverrideParams overrideParams) {
StringRef opNameStringRef(opName);
SmallVector<int64_t> gridSmallVector(grid.begin(), grid.end());
addOutputLayoutOverride(opNameStringRef, gridSmallVector, bufferType,
tensorMemoryLayout, memoryLayout, dataType);
addOutputLayoutOverride(opNameStringRef, overrideParams);
}

} // namespace mlir::tt::ttnn
36 changes: 24 additions & 12 deletions lib/Dialect/TTNN/Utils/PassOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"
#include <numeric>

namespace mlir::tt::ttnn {

Expand Down Expand Up @@ -102,32 +103,43 @@ std::string OutputLayoutOverrideParser::toString(
res += std::string(entry.getKey()) + "=";
const OutputLayoutOverrideParams &params = entry.getValue();

// Print grid values
std::vector<std::string> parts;

// Collect grid values
if (params.grid.has_value()) {
std::string gridStr;
for (size_t i = 0; i < params.grid.value().size(); ++i) {
res += std::to_string(params.grid.value()[i]);
gridStr += std::to_string(params.grid.value()[i]);
if (i < params.grid.value().size() - 1) {
res += "x";
gridStr += "x";
}
}
parts.push_back(gridStr);
}
// Print memory space and memory layout
// Collect memory space and memory layout
if (params.bufferType.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyBufferType(
params.bufferType.value()));
parts.push_back(std::string(
mlir::tt::ttnn::stringifyBufferType(params.bufferType.value())));
}
if (params.tensorMemoryLayout.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout(
params.tensorMemoryLayout.value()));
parts.push_back(std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout(
params.tensorMemoryLayout.value())));
}
if (params.memoryLayout.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyLayout(
params.memoryLayout.value()));
parts.push_back(std::string(
mlir::tt::ttnn::stringifyLayout(params.memoryLayout.value())));
}
if (params.dataType.has_value()) {
res += ":" + std::string(
mlir::tt::DataTypeEnumToString(params.dataType.value()));
parts.push_back(
std::string(mlir::tt::DataTypeEnumToString(params.dataType.value())));
}

// Join parts with ":"
res += std::accumulate(parts.begin(), parts.end(), std::string(),
[](const std::string &a, const std::string &b) {
return a.empty() ? b : a + ":" + b;
});

if (++count < value.size()) {
res += ",";
}
Expand Down
53 changes: 50 additions & 3 deletions python/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include <optional>

namespace mlir::ttmlir::python {

Expand Down Expand Up @@ -133,12 +134,20 @@ void populateOptimizerOverridesModule(py::module &m) {
"grid",
[](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) {
// Getter: Convert SmallVector to std::vector
return std::vector<int64_t>(obj.grid->begin(), obj.grid->end());
if (obj.grid.has_value()) {
return std::make_optional<std::vector<int64_t>>(obj.grid->begin(),
obj.grid->end());
}
return std::make_optional<std::vector<int64_t>>();
},
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::vector<int64_t> &input) {
// Setter: Convert std::vector to SmallVector
obj.grid->clear();
if (!obj.grid.has_value()) {
obj.grid = SmallVector<int64_t, 2>();
} else {
obj.grid->clear();
}
obj.grid->append(input.begin(), input.end());
})
.def_readwrite("buffer_type",
Expand All @@ -149,7 +158,45 @@ void populateOptimizerOverridesModule(py::module &m) {
.def_readwrite("memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
.def_readwrite("data_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType);
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
.def("set_buffer_type_from_str",
     // Parses `value` with symbolizeBufferType and stores the result in
     // bufferType; throws std::invalid_argument when parsing yields nothing.
     [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
        const std::string &value) {
       const auto parsed = mlir::tt::ttnn::symbolizeBufferType(value);
       if (!parsed) {
         throw std::invalid_argument("Invalid buffer type: " + value);
       }
       obj.bufferType = parsed;
     })
.def("set_tensor_memory_layout_from_str",
     // Parses `value` with symbolizeTensorMemoryLayout and stores the result
     // in tensorMemoryLayout; throws std::invalid_argument when parsing
     // yields nothing.
     [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
        const std::string &value) {
       const auto parsed =
           mlir::tt::ttnn::symbolizeTensorMemoryLayout(value);
       if (!parsed) {
         throw std::invalid_argument("Invalid tensor memory layout: " +
                                     value);
       }
       obj.tensorMemoryLayout = parsed;
     })
.def("set_memory_layout_from_str",
     // Parses `value` with symbolizeLayout and stores the result in
     // memoryLayout; throws std::invalid_argument when parsing yields nothing.
     [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
        const std::string &value) {
       const auto parsed = mlir::tt::ttnn::symbolizeLayout(value);
       if (!parsed) {
         throw std::invalid_argument("Invalid memory layout: " + value);
       }
       obj.memoryLayout = parsed;
     })
.def("set_data_type_from_str",
     // Parses `value` with DataTypeStringToEnum and stores the result in
     // dataType; throws std::invalid_argument when parsing yields nothing.
     [](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
        const std::string &value) {
       const auto parsed = mlir::tt::DataTypeStringToEnum(value);
       if (!parsed) {
         throw std::invalid_argument("Invalid data type: " + value);
       }
       obj.dataType = parsed;
     });
}

} // namespace mlir::ttmlir::python
12 changes: 11 additions & 1 deletion python/TTNNModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,24 @@ void populateTTNNModule(py::module &m) {
.def_property_readonly(
"memref",
[](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); })
.def_property_readonly("memory_layout_as_int",
.def_property_readonly("tensor_memory_layout_as_int",
[](tt::ttnn::TTNNLayoutAttr self)
-> std::variant<uint32_t, py::object> {
if (!self.getMemLayout()) {
return py::none();
}
return static_cast<uint32_t>(
self.getMemLayout().getValue());
})
// Read-only property: the value of getLayout() converted to uint32_t.
.def_property_readonly("memory_layout_as_int",
                       [](tt::ttnn::TTNNLayoutAttr self) -> uint32_t {
                         const auto layout = self.getLayout();
                         return static_cast<uint32_t>(layout);
                       })
// Read-only property: the value of getDataType() converted to uint32_t.
.def_property_readonly("data_type_as_int",
                       [](tt::ttnn::TTNNLayoutAttr self) -> uint32_t {
                         const auto dataType = self.getDataType();
                         return static_cast<uint32_t>(dataType);
                       });
// .def_property_readonly("data_type",
// &tt::ttnn::TTNNLayoutAttr::getDataType);
}
} // namespace mlir::ttmlir::python
39 changes: 21 additions & 18 deletions python/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include <pybind11/pytypes.h>
#include <variant>

namespace mlir::ttmlir::python {

Expand All @@ -17,25 +19,26 @@ void populateUtilModule(py::module &m) {
return source;
});

m.def("get_loc_name", [](MlirLocation _loc) -> std::string {
mlir::Location loc = unwrap(_loc);
if (mlir::isa<mlir::NameLoc>(loc)) {
mlir::NameLoc nameLoc = mlir::cast<mlir::NameLoc>(loc);
return nameLoc.getName().str();
}
return "-";
});
// Returns the name carried by a NameLoc as a string; every other kind of
// location yields Python None.
m.def("get_loc_name",
      [](MlirLocation _loc) -> std::variant<std::string, py::object> {
        mlir::Location loc = unwrap(_loc);
        // dyn_cast does the kind check and the cast in one step, replacing
        // the separate isa<> + cast<> pair.
        if (auto nameLoc = mlir::dyn_cast<mlir::NameLoc>(loc)) {
          return nameLoc.getName().str();
        }
        return py::none();
      });

m.def("get_loc_full", [](MlirLocation _loc) -> std::string {
mlir::Location loc = unwrap(_loc);
if (mlir::isa<mlir::FileLineColLoc>(loc)) {
mlir::FileLineColLoc fileLoc = mlir::cast<mlir::FileLineColLoc>(loc);
return fileLoc.getFilename().str() + ":" +
std::to_string(fileLoc.getLine()) + ":" +
std::to_string(fileLoc.getColumn());
}
return "-";
});
// Returns the text produced by printing the location with Location::print.
// The lambda only ever produces a string, so the return type is declared as
// std::string rather than a string/py::object variant.
m.def("get_loc_full", [](MlirLocation _loc) -> std::string {
  mlir::Location loc = unwrap(_loc);

  std::string locationStr;
  llvm::raw_string_ostream output(locationStr);
  loc.print(output);

  return locationStr;
});
}

} // namespace mlir::ttmlir::python
2 changes: 1 addition & 1 deletion tools/explorer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ include(ExternalProject)
set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py)
set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin)

set(MODEL_EXPLORER_VERSION "ca884d5eb3291507e7f4e76776957e231b2d9b6d")
set(MODEL_EXPLORER_VERSION "4ffe3f0c11969b8eb628e1309bcde1990bb470b7")
ExternalProject_Add(
model-explorer
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer
Expand Down
Loading

0 comments on commit d9a8568

Please sign in to comment.