Skip to content

Commit

Permalink
TTNN Rendering Support in TT-Explorer (#1298)
Browse files Browse the repository at this point in the history
* Added maybe_downcast & hardened TT Attrs and Types to include better support

* Removed manual maybe_downcast, added tt_class

* Removed redundant imports

* Lint Fixes

* new MLIR module for parsing TTNN modules

* Added TTNNLayout Support + Fixes

* editable on Debug, minor fixes

* Requested Changes

* Removed stale import

* Removed stale import
  • Loading branch information
vprajapati-tt authored Nov 25, 2024
1 parent 1609d01 commit 02df31c
Show file tree
Hide file tree
Showing 12 changed files with 643 additions and 158 deletions.
3 changes: 3 additions & 0 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipPhysicalCoresAttrGet(
MlirAttribute *dram, size_t dramSize, MlirAttribute *eth, size_t ethSize,
MlirAttribute *eth_inactive, size_t eth_inactiveSize);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx,
int64_t y, int64_t x);

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 5 additions & 0 deletions include/ttmlir-c/TTNNAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#ifndef TTMLIR_C_TTNNATTRS_H
#define TTMLIR_C_TTNNATTRS_H

#include "mlir-c/AffineMap.h"
#include "ttmlir-c/Dialects.h"

#ifdef __cplusplus
Expand Down Expand Up @@ -44,6 +45,10 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx,
int64_t y,
int64_t x);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTNNTTNNLayoutAttrGet(
MlirContext ctx, MlirAffineMap linear, MlirAttribute grid, MlirType memref,
unsigned memLayout);

#ifdef __cplusplus
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,8 @@ MlirAttribute ttmlirTTChipPhysicalCoresAttrGet(
ethVec, ethInactiveVec));
}

MlirAttribute ttmlirTTCoreCoordAttrGet(MlirContext ctx, int64_t y, int64_t x) {
return wrap(CoreCoordAttr::get(unwrap(ctx), y, x));
}

} // namespace mlir::tt
10 changes: 10 additions & 0 deletions lib/CAPI/TTNNAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,14 @@ MlirAttribute ttmlirTTNNMeshShapeAttrGet(MlirContext ctx, int64_t y,
return wrap(MeshShapeAttr::get(unwrap(ctx), y, x));
}

MlirAttribute ttmlirTTNNTTNNLayoutAttrGet(MlirContext ctx, MlirAffineMap linear,
MlirAttribute grid, MlirType memref,
unsigned memLayout) {
mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr);
return wrap(TTNNLayoutAttr::get(unwrap(ctx), affineMap,
mlir::cast<GridAttr>(unwrap(grid)),
mlir::cast<MemRefType>(unwrap(memref)),
static_cast<TensorMemoryLayout>(memLayout)));
}

} // namespace mlir::tt::ttnn
25 changes: 21 additions & 4 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ void populateTTModule(py::module &m) {
return static_cast<uint32_t>(la.getOobVal());
})
.def_property_readonly("grid_attr", &tt::LayoutAttr::getGrid)
.def_property_readonly("memref", &tt::LayoutAttr::getMemref)
.def_property_readonly(
"memref", [](tt::LayoutAttr self) { return wrap(self.getMemref()); })
.def_property_readonly("memory_space", &tt::LayoutAttr::getMemorySpace)
.def_property_readonly("memory_space_as_int",
[](tt::LayoutAttr la) {
Expand All @@ -99,6 +100,8 @@ void populateTTModule(py::module &m) {
})
.def_property_readonly("shard_shape", &tt::LayoutAttr::getShardShape)
.def_property_readonly("memory_layout", &tt::LayoutAttr::getMemLayout)
.def_property_readonly(
"linear", [](tt::LayoutAttr self) { return wrap(self.getLinear()); })
.def_property_readonly("memory_layout_as_int", [](tt::LayoutAttr la) {
return static_cast<uint32_t>(la.getMemLayout());
});
Expand Down Expand Up @@ -236,6 +239,14 @@ void populateTTModule(py::module &m) {
return self.getEthInactive().vec();
});

tt_attribute_class<tt::CoreCoordAttr>(m, "CoreCoordAttr")
.def_static("get",
[](MlirContext ctx, int64_t y, int64_t x) {
return wrap(tt::CoreCoordAttr::get(unwrap(ctx), y, x));
})
.def_property_readonly("y", &tt::CoreCoordAttr::getY)
.def_property_readonly("x", &tt::CoreCoordAttr::getX);

tt_attribute_class<tt::ChipCoordAttr>(m, "ChipCoordAttr")
.def_static("get",
[](MlirContext ctx, unsigned rack, unsigned shelf, unsigned y,
Expand Down Expand Up @@ -430,8 +441,11 @@ void populateTTModule(py::module &m) {
return mlir::cast<tt::DeviceAttr>(unwrap(self));
})
.def_property_readonly("grid_attr", &tt::DeviceAttr::getWorkerGrid)
.def_property_readonly("l1_map", &tt::DeviceAttr::getL1Map)
.def_property_readonly("dram_map", &tt::DeviceAttr::getDramMap)
.def_property_readonly(
"l1_map", [](tt::DeviceAttr self) { return wrap(self.getL1Map()); })
.def_property_readonly(
"dram_map",
[](tt::DeviceAttr self) { return wrap(self.getDramMap()); })
.def_property_readonly(
"mesh_shape",
[](tt::DeviceAttr const &self) { return self.getMeshShape().vec(); })
Expand All @@ -447,7 +461,10 @@ void populateTTModule(py::module &m) {
unwrap(ctx), SmallVector<std::int64_t>{height, width},
static_cast<tt::DataType>(dataType)));
})
.def_property_readonly("data_type", &tt::TileType::getDataType)
.def_property_readonly("data_type_as_int",
[](tt::TileType self) {
return static_cast<uint32_t>(self.getDataType());
})
.def_property_readonly("shape", [](tt::TileType const &tile) {
return std::vector<int64_t>({tile.getHeight(), tile.getWidth()});
});
Expand Down
23 changes: 23 additions & 0 deletions python/TTNNModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "mlir/CAPI/AffineMap.h"
#include "ttmlir/Bindings/Python/TTMLIRModule.h"

namespace mlir::ttmlir::python {
Expand Down Expand Up @@ -127,5 +128,27 @@ void populateTTNNModule(py::module &m) {
})
.def_property_readonly("y", &tt::ttnn::MeshShapeAttr::getY)
.def_property_readonly("x", &tt::ttnn::MeshShapeAttr::getX);

tt_attribute_class<tt::ttnn::TTNNLayoutAttr>(m, "TTNNLayoutAttr")
.def_static("get",
[](MlirContext ctx, MlirAffineMap linear, MlirAttribute grid,
MlirType memref, unsigned memLayout) {
return wrap(tt::ttnn::TTNNLayoutAttr::get(
unwrap(ctx), mlir::cast<AffineMap>(unwrap(linear)),
mlir::cast<tt::GridAttr>(unwrap(grid)),
mlir::cast<MemRefType>(unwrap(memref)),
static_cast<tt::ttnn::TensorMemoryLayout>(memLayout)));
})
.def_property_readonly(
"linear",
[](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getLinear()); })
.def_property_readonly("grid_attr", &tt::ttnn::TTNNLayoutAttr::getGrid)
.def_property_readonly(
"memref",
[](tt::ttnn::TTNNLayoutAttr self) { return wrap(self.getMemref()); })
.def_property_readonly(
"memory_layout_as_int", [](tt::ttnn::TTNNLayoutAttr self) {
return static_cast<uint32_t>(self.getMemLayout());
});
}
} // namespace mlir::ttmlir::python
1 change: 1 addition & 0 deletions python/ttmlir/dialects/ttnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
# SPDX-License-Identifier: Apache-2.0

from ._ttnn_ops_gen import *
from ._ttnn_enum_gen import *
from .._mlir_libs._ttmlir import register_dialect, ttnn_ir as ir
2 changes: 1 addition & 1 deletion tools/explorer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ ExternalProject_Add(

add_custom_target(explorer
COMMENT "Building tt-explorer... ${TTMLIR_BIN_DIR}"
COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter
COMMAND pip install $<$<CONFIG:Debug>:-e> ${CMAKE_CURRENT_SOURCE_DIR}/tt_adapter
COMMAND pip install ${CMAKE_CURRENT_SOURCE_DIR}/model-explorer/src/model-explorer/src/server/package

DEPENDS TTMLIRPythonModules model-explorer ttrt
Expand Down
4 changes: 2 additions & 2 deletions tools/explorer/tt_adapter/src/tt_adapter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict
import model_explorer
from . import ttir, runner, utils
from . import runner, utils, mlir
import dataclasses
import enum

Expand Down Expand Up @@ -46,7 +46,7 @@ def convert(
module = utils.parse_mlir_file(model_path)

# Convert TTIR to Model Explorer Graphs and Display/Return
graph = ttir.ttir_to_graph(module)
graph = mlir.build_graph(module)
return {"graphs": [graph]}

def execute(
Expand Down
Loading

0 comments on commit 02df31c

Please sign in to comment.