Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimizer overrides to explorer #1604

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@
/test/unittests/Optimizer @odjuricicTT @tt-mpantic @sdjordjevicTT @nobradovictt
/test/unittests/OpModel @odjuricicTT @tt-mpantic @sdjordjevicTT @nobradovictt
/tools/ @svuckovicTT @mtopalovicTT
/tools/explorer/ @odjuricicTT @tt-mpantic @sdjordjevicTT @nobradovictt @vprajapati-tt
/tools/explorer/ @odjuricicTT @tt-mpantic @sdjordjevicTT @nobradovictt @vprajapati-tt @vcanicTT
5 changes: 2 additions & 3 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,8 @@ class OptimizerOverridesHandler {

// Wrapper methods we use to expose the adders to the python bindings
void addInputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &);
void addOutputLayoutOverridePybindWrapper(std::string, std::vector<int64_t> &,
BufferType, TensorMemoryLayout,
tt::ttnn::Layout, tt::DataType);
void addOutputLayoutOverridePybindWrapper(std::string,
OutputLayoutOverrideParams);

private:
// Flags for enabling/disabling the optimizer passes
Expand Down
17 changes: 7 additions & 10 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H

#include <string_view>

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

namespace mlir::tt::ttnn {

Expand All @@ -31,12 +28,12 @@ struct OptionNames {
};

struct OutputLayoutOverrideParams {
std::optional<SmallVector<int64_t, 2>> grid;
std::optional<BufferType> bufferType;
std::optional<TensorMemoryLayout>
tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
std::optional<Layout> memoryLayout; // ROW_MAJOR / TILE
std::optional<tt::DataType> dataType;
std::optional<SmallVector<int64_t, 2>> grid = std::nullopt;
std::optional<BufferType> bufferType = std::nullopt;
std::optional<TensorMemoryLayout> tensorMemoryLayout =
std::nullopt; // INTERLEAVED / SHARDED etc...
std::optional<Layout> memoryLayout = std::nullopt; // ROW_MAJOR / TILE
std::optional<tt::DataType> dataType = std::nullopt;

// Check if all layout parameters that are generated in LegalLayoutAnalysis
// are overridden. DataType is the only that is not.
Expand All @@ -45,7 +42,7 @@ struct OutputLayoutOverrideParams {
tensorMemoryLayout.has_value() && memoryLayout.has_value();
}

bool operator==(const OutputLayoutOverrideParams rhs) const {
bool operator==(const OutputLayoutOverrideParams &rhs) const {
if (grid.has_value() != rhs.grid.has_value()) {
return false;
}
Expand Down
9 changes: 3 additions & 6 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy(
}
void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy(
MemoryLayoutAnalysisPolicyType value) {
enableMemoryLayoutAnalysisPolicy = true;
memoryLayoutAnalysisPolicy = value;
}

Expand Down Expand Up @@ -198,13 +199,9 @@ void OptimizerOverridesHandler::addInputLayoutOverridePybindWrapper(
}

void OptimizerOverridesHandler::addOutputLayoutOverridePybindWrapper(
std::string opName, std::vector<int64_t> &grid, BufferType bufferType,
TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout,
tt::DataType dataType) {
std::string opName, OutputLayoutOverrideParams overrideParams) {
StringRef opNameStringRef(opName);
SmallVector<int64_t> gridSmallVector(grid.begin(), grid.end());
addOutputLayoutOverride(opNameStringRef, gridSmallVector, bufferType,
tensorMemoryLayout, memoryLayout, dataType);
addOutputLayoutOverride(opNameStringRef, overrideParams);
}

} // namespace mlir::tt::ttnn
36 changes: 24 additions & 12 deletions lib/Dialect/TTNN/Utils/PassOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"
#include <numeric>

namespace mlir::tt::ttnn {

Expand Down Expand Up @@ -102,32 +103,43 @@ std::string OutputLayoutOverrideParser::toString(
res += std::string(entry.getKey()) + "=";
const OutputLayoutOverrideParams &params = entry.getValue();

// Print grid values
std::vector<std::string> parts;

// Collect grid values
if (params.grid.has_value()) {
std::string gridStr;
for (size_t i = 0; i < params.grid.value().size(); ++i) {
res += std::to_string(params.grid.value()[i]);
gridStr += std::to_string(params.grid.value()[i]);
if (i < params.grid.value().size() - 1) {
res += "x";
gridStr += "x";
}
}
parts.push_back(gridStr);
}
// Print memory space and memory layout
// Collect memory space and memory layout
if (params.bufferType.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyBufferType(
params.bufferType.value()));
parts.push_back(std::string(
mlir::tt::ttnn::stringifyBufferType(params.bufferType.value())));
}
if (params.tensorMemoryLayout.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout(
params.tensorMemoryLayout.value()));
parts.push_back(std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout(
params.tensorMemoryLayout.value())));
}
if (params.memoryLayout.has_value()) {
res += ":" + std::string(mlir::tt::ttnn::stringifyLayout(
params.memoryLayout.value()));
parts.push_back(std::string(
mlir::tt::ttnn::stringifyLayout(params.memoryLayout.value())));
}
if (params.dataType.has_value()) {
res += ":" + std::string(
mlir::tt::DataTypeEnumToString(params.dataType.value()));
parts.push_back(
std::string(mlir::tt::DataTypeEnumToString(params.dataType.value())));
}

// Join parts with ":"
res += std::accumulate(parts.begin(), parts.end(), std::string(),
[](const std::string &a, const std::string &b) {
return a.empty() ? b : a + ":" + b;
});

if (++count < value.size()) {
res += ",";
}
Expand Down
53 changes: 50 additions & 3 deletions python/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include <optional>

namespace mlir::ttmlir::python {

Expand Down Expand Up @@ -133,12 +134,20 @@ void populateOptimizerOverridesModule(py::module &m) {
"grid",
[](const mlir::tt::ttnn::OutputLayoutOverrideParams &obj) {
// Getter: Convert SmallVector to std::vector
return std::vector<int64_t>(obj.grid->begin(), obj.grid->end());
if (obj.grid.has_value()) {
return std::make_optional<std::vector<int64_t>>(obj.grid->begin(),
obj.grid->end());
}
return std::make_optional<std::vector<int64_t>>();
},
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::vector<int64_t> &input) {
// Setter: Convert std::vector to SmallVector
obj.grid->clear();
if (!obj.grid.has_value()) {
obj.grid = SmallVector<int64_t, 2>();
} else {
obj.grid->clear();
}
obj.grid->append(input.begin(), input.end());
})
.def_readwrite("buffer_type",
Expand All @@ -149,7 +158,45 @@ void populateOptimizerOverridesModule(py::module &m) {
.def_readwrite("memory_layout",
&mlir::tt::ttnn::OutputLayoutOverrideParams::memoryLayout)
.def_readwrite("data_type",
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType);
&mlir::tt::ttnn::OutputLayoutOverrideParams::dataType)
.def("set_buffer_type_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto bufferType = mlir::tt::ttnn::symbolizeBufferType(value)) {
obj.bufferType = bufferType;
} else {
throw std::invalid_argument("Invalid buffer type: " + value);
}
})
.def("set_tensor_memory_layout_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto tensorMemoryLayout =
mlir::tt::ttnn::symbolizeTensorMemoryLayout(value)) {
obj.tensorMemoryLayout = tensorMemoryLayout;
} else {
throw std::invalid_argument("Invalid tensor memory layout: " +
value);
}
})
.def("set_memory_layout_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto memoryLayout = mlir::tt::ttnn::symbolizeLayout(value)) {
obj.memoryLayout = memoryLayout;
} else {
throw std::invalid_argument("Invalid memory layout: " + value);
}
})
.def("set_data_type_from_str",
[](mlir::tt::ttnn::OutputLayoutOverrideParams &obj,
const std::string &value) {
if (auto dataType = mlir::tt::DataTypeStringToEnum(value)) {
obj.dataType = dataType;
} else {
throw std::invalid_argument("Invalid data type: " + value);
}
});
}

} // namespace mlir::ttmlir::python
10 changes: 9 additions & 1 deletion python/TTNNModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,22 @@ void populateTTNNModule(py::module &m) {
.def_property_readonly(
"memref",
[](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); })
.def_property_readonly("memory_layout_as_int",
.def_property_readonly("tensor_memory_layout_as_int",
[](tt::ttnn::TTNNLayoutAttr self)
-> std::variant<uint32_t, py::object> {
if (!self.getMemLayout()) {
return py::none();
}
return static_cast<uint32_t>(
self.getMemLayout().getValue());
})
.def_property_readonly("memory_layout_as_int",
[](tt::ttnn::TTNNLayoutAttr self) {
return static_cast<uint32_t>(self.getLayout());
})
.def_property_readonly("data_type_as_int",
[](tt::ttnn::TTNNLayoutAttr self) {
return static_cast<uint32_t>(self.getDataType());
});
}
} // namespace mlir::ttmlir::python
39 changes: 21 additions & 18 deletions python/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Bindings/Python/TTMLIRModule.h"
#include <pybind11/pytypes.h>
#include <variant>

namespace mlir::ttmlir::python {

Expand All @@ -17,25 +19,26 @@ void populateUtilModule(py::module &m) {
return source;
});

m.def("get_loc_name", [](MlirLocation _loc) -> std::string {
mlir::Location loc = unwrap(_loc);
if (mlir::isa<mlir::NameLoc>(loc)) {
mlir::NameLoc nameLoc = mlir::cast<mlir::NameLoc>(loc);
return nameLoc.getName().str();
}
return "-";
});
m.def("get_loc_name",
[](MlirLocation _loc) -> std::variant<std::string, py::object> {
mlir::Location loc = unwrap(_loc);
if (mlir::isa<mlir::NameLoc>(loc)) {
mlir::NameLoc nameLoc = mlir::cast<mlir::NameLoc>(loc);
return nameLoc.getName().str();
}
return py::none();
});

m.def("get_loc_full", [](MlirLocation _loc) -> std::string {
mlir::Location loc = unwrap(_loc);
if (mlir::isa<mlir::FileLineColLoc>(loc)) {
mlir::FileLineColLoc fileLoc = mlir::cast<mlir::FileLineColLoc>(loc);
return fileLoc.getFilename().str() + ":" +
std::to_string(fileLoc.getLine()) + ":" +
std::to_string(fileLoc.getColumn());
}
return "-";
});
m.def("get_loc_full",
[](MlirLocation _loc) -> std::variant<std::string, py::object> {
mlir::Location loc = unwrap(_loc);

std::string locationStr;
llvm::raw_string_ostream output(locationStr);
loc.print(output);

return locationStr;
});
}

} // namespace mlir::ttmlir::python
2 changes: 1 addition & 1 deletion tools/explorer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ include(ExternalProject)
set(TT_EXPLORER_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/run.py)
set(TTMLIR_BUILD_BIN_DIR ${TTMLIR_BINARY_DIR}/bin)

set(MODEL_EXPLORER_VERSION "ca884d5eb3291507e7f4e76776957e231b2d9b6d")
set(MODEL_EXPLORER_VERSION "4ffe3f0c11969b8eb628e1309bcde1990bb470b7")
ExternalProject_Add(
model-explorer
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer
Expand Down
Loading
Loading