Skip to content

Commit

Permalink
Added supported_data_types and supported_tile_sizes to ChipDesc. (#433)…
Browse files Browse the repository at this point in the history
… (#494)

Currently using placeholder values.
  • Loading branch information
ddilbazTT authored Aug 27, 2024
1 parent 1bcf83d commit 287a2ef
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 12 deletions.
6 changes: 5 additions & 1 deletion include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ ttmlirTTChipCapabilityAttrGet(MlirContext ctx, uint32_t chipCapability);
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTArchAttrGet(MlirContext ctx,
uint32_t arch);

MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTDataTypeAttrGet(MlirContext ctx, uint16_t *supportedDataTypes);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet(
MlirContext ctx, MlirAttribute arch, int64_t *grid, size_t gridSize,
unsigned l1Size, unsigned numDramChannels, unsigned dramChannelSize,
unsigned nocL1AddressAlignBytes, unsigned pcieAddressAlignBytes,
unsigned nocDRAMAddressAlignBytes, unsigned l1UnreservedBase,
unsigned eriscL1UnreservedBase, unsigned dramUnreservedBase,
MlirAttribute chipPhysicalCores);
MlirAttribute chipPhysicalCores, MlirAttribute *supportedDataTypes,
MlirAttribute *supportedTileSizes);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet(
MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x);
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def TT_DataType : I32EnumAttr<"DataType", "TT DataTypes",
TT_UInt16,
TT_UInt8
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
let stringToSymbolFnName = "DataTypeStringToEnum";
let symbolToStringFnName = "DataTypeEnumToString";
Expand Down
22 changes: 20 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def TT_ArchAttr : EnumAttr<TT_Dialect, TT_Arch, "arch"> {
let assemblyFormat = "`<` $value `>`";
}

// Attribute wrapper for the TT_DataType enum in the TT dialect. Its assembly
// form is the enum value enclosed in angle brackets.
def TT_DataTypeAttr : EnumAttr<TT_Dialect, TT_DataType, "supportedDataTypes"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_CoreCoordAttr : TT_Attr<"CoreCoord", "core_coord"> {
let summary = "TT core_coord attribute";
let description = [{
Expand All @@ -61,6 +65,16 @@ def TT_CoreCoordAttr : TT_Attr<"CoreCoord", "core_coord"> {
let assemblyFormat = "`(` $y `,` $x `)`";
}

// A single tile shape stored as two int64_t parameters, y followed by x.
// The assembly form is `(<y> x <x>)`, with a literal `x` between the values.
def TT_TileSizeAttr : TT_Attr<"TileSize", "tile_size"> {
let summary = "TT tile_size attribute";
let description = [{
TT tile_size attribute containing a supported Tensix tile shape.
}];

let parameters = (ins "int64_t":$y, "int64_t":$x);
let assemblyFormat = "`(` $y `x` $x `)`";
}


def TT_ChipPhysicalCoresAttr : TT_Attr<"ChipPhysicalCores", "chip_physical_cores"> {
let summary = "TT chip_physical_cores attribute";
Expand Down Expand Up @@ -89,7 +103,9 @@ def TT_ChipDescAttr : TT_Attr<"ChipDesc", "chip_desc"> {
"unsigned":$l1UnreservedBase,
"unsigned":$eriscL1UnreservedBase,
"unsigned":$dramUnreservedBase,
"ChipPhysicalCoresAttr":$chipPhysicalCores);
"ChipPhysicalCoresAttr":$chipPhysicalCores,
ArrayRefParameter<"DataTypeAttr">:$supportedDataTypes,
ArrayRefParameter<"TileSizeAttr">:$supportedTileSizes);
let assemblyFormat = [{`{` `arch` `=` $arch `,`
`grid` `=` custom<DimensionList>($grid) `,`
`l1_size` `=` $l1Size `,`
Expand All @@ -101,7 +117,9 @@ def TT_ChipDescAttr : TT_Attr<"ChipDesc", "chip_desc"> {
`l1_unreserved_base` `=` $l1UnreservedBase `,`
`erisc_l1_unreserved_base` `=` $eriscL1UnreservedBase `,`
`dram_unreserved_base` `=` $dramUnreservedBase `,`
`physical_cores` `=` $chipPhysicalCores `}`}];
`physical_cores` `=` $chipPhysicalCores `,`
`supported_data_types` `=` `[` $supportedDataTypes `]` `,`
`supported_tile_sizes` `=` `[` $supportedTileSizes `]` `}`}];

let extraClassDeclaration = [{
unsigned getUsableL1Size() const { return getL1Size() - getL1UnreservedBase(); }
Expand Down
4 changes: 3 additions & 1 deletion include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ enum Arch: uint {
Blackhole = 2,
}

enum DataType: ushort {
enum DataType: uint16 {
Float32 = 0,
Float16 = 1,
BFloat16 = 2,
Expand Down Expand Up @@ -106,6 +106,8 @@ table ChipDesc {
erisc_l1_unreserved_base: uint32;
dram_unreserved_base: uint32;
physical_cores: ChipPhysicalCores;
supported_data_types: [DataType];
supported_tile_sizes: [Dim2d];
}

struct ChipCoord {
Expand Down
15 changes: 14 additions & 1 deletion include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ inline ::tt::target::Arch toFlatbuffer(FlatbufferObjectCache &, ArchAttr arch) {
}
}

// DataTypeAttr overload: unwraps the enum value held by the attribute and
// forwards it to the toFlatbuffer overload that accepts that value.
inline ::tt::target::DataType toFlatbuffer(FlatbufferObjectCache &cache,
                                           const DataTypeAttr &dtypeAttr) {
  const auto dataType = dtypeAttr.getValue();
  return toFlatbuffer(cache, dataType);
}

// TileSizeAttr overload: builds a flatbuffer Dim2d from the attribute's y and
// x values, in that order. The cache argument is not used here.
inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache,
                                        TileSizeAttr tileSize) {
  const auto tileY = tileSize.getY();
  const auto tileX = tileSize.getX();
  return ::tt::target::Dim2d(tileY, tileX);
}

inline ::tt::target::ChipCapability
toFlatbuffer(FlatbufferObjectCache &, ChipCapabilityAttr capabilityAttr) {
auto capabilities = capabilityAttr.getValue();
Expand Down Expand Up @@ -233,7 +244,9 @@ toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) {
chipDesc.getPcieAddressAlignBytes(),
chipDesc.getNocDRAMAddressAlignBytes(), chipDesc.getL1UnreservedBase(),
chipDesc.getEriscL1UnreservedBase(), chipDesc.getDramUnreservedBase(),
toFlatbuffer(cache, chipDesc.getChipPhysicalCores()));
toFlatbuffer(cache, chipDesc.getChipPhysicalCores()),
toFlatbuffer(cache, chipDesc.getSupportedDataTypes()),
toFlatbuffer(cache, chipDesc.getSupportedTileSizes()));
}

inline flatbuffers::Offset<::tt::target::SystemDesc>
Expand Down
13 changes: 11 additions & 2 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,29 @@ MlirAttribute ttmlirTTArchAttrGet(MlirContext ctx, uint32_t arch) {
return wrap(ArchAttr::get(unwrap(ctx), static_cast<Arch>(arch)));
}

// C API entry point: creates a DataTypeAttr in `ctx` from the value that
// `supportedDataTypes` points to. The pointer is dereferenced unconditionally,
// so callers must pass a non-null pointer.
MlirAttribute ttmlirTTDataTypeAttrGet(MlirContext ctx,
                                      uint16_t *supportedDataTypes) {
  const auto dataType = static_cast<DataType>(*supportedDataTypes);
  return wrap(DataTypeAttr::get(unwrap(ctx), dataType));
}

// C API entry point that builds a ChipDescAttr from plain C arguments.
// `grid` is read as an array of `gridSize` elements and copied into a local
// vector before being handed to ChipDescAttr::get.
// NOTE(review): `supportedDataTypes` and `supportedTileSizes` are each
// dereferenced exactly once, so only the first attribute behind each pointer
// is used and both pointers must be non-null — confirm whether an element
// count parameter is intended.
// NOTE(review): this listing carries both the pre-change and post-change
// lines of the diff (e.g. the two `chipPhysicalCores` parameter lines).
MlirAttribute ttmlirTTChipDescAttrGet(
MlirContext ctx, MlirAttribute arch, int64_t *grid, size_t gridSize,
unsigned l1Size, unsigned numDramChannels, unsigned dramChannelSize,
unsigned nocL1AddressAlignBytes, unsigned pcieAddressAlignBytes,
unsigned nocDRAMAddressAlignBytes, unsigned l1UnreservedBase,
unsigned eriscL1UnreservedBase, unsigned dramUnreservedBase,
MlirAttribute chipPhysicalCores) {
MlirAttribute chipPhysicalCores, MlirAttribute *supportedDataTypes,
MlirAttribute *supportedTileSizes) {
std::vector<int64_t> gridVec(grid, grid + gridSize);
return wrap(ChipDescAttr::get(
unwrap(ctx), mlir::dyn_cast<ArchAttr>(unwrap(arch)), gridVec, l1Size,
numDramChannels, dramChannelSize, nocL1AddressAlignBytes,
pcieAddressAlignBytes, nocDRAMAddressAlignBytes, l1UnreservedBase,
eriscL1UnreservedBase, dramUnreservedBase,
mlir::dyn_cast<ChipPhysicalCoresAttr>(unwrap(chipPhysicalCores))));
mlir::dyn_cast<ChipPhysicalCoresAttr>(unwrap(chipPhysicalCores)),
mlir::dyn_cast<DataTypeAttr>(unwrap(*supportedDataTypes)),
mlir::dyn_cast<TileSizeAttr>(unwrap(*supportedTileSizes))));
}

MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack,
Expand Down
106 changes: 104 additions & 2 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,43 @@ mlir::tt::SystemDescAttr
mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
// Populate a dummy n150
SmallVector<std::int64_t> gridShape = {8, 8};

// populate a placeholder for supported data types
SmallVector<tt::DataTypeAttr> supported_data_types;
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::Float32));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::Float16));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFloat16));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float8));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat8));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float4));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat4));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float2));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat2));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt32));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt16));
supported_data_types.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt8));

// populate a placeholder for supported tile sizes
SmallVector<tt::TileSizeAttr> supported_tile_sizes;
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 4, 16));
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 16, 16));
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 32, 16));
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 4, 32));
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 16, 32));
supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 32, 32));

SmallVector<CoreCoordAttr> workerCores;
workerCores.reserve(gridShape[0] * gridShape[1]);
for (std::int64_t y = 0; y < gridShape[0]; ++y) {
Expand All @@ -50,7 +87,8 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) {
context, tt::ArchAttr::get(context, tt::Arch::WormholeB0),
gridShape, 1499136, 12, (1 << 30), 16, 32, 32, 0, 0, 0,
tt::ChipPhysicalCoresAttr::get(context, workerCores, dramCores,
{}, {})),
{}, {}),
supported_data_types, supported_tile_sizes),
},
// Chip Descriptor Indices
{
Expand Down Expand Up @@ -117,6 +155,7 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
eth_inactive_cores.emplace_back(
tt::CoreCoordAttr::get(context, core->y(), core->x()));
}

// Create ChipPhysicalCoresAttr from the list of CoreCoordAttr instances
auto chip_physical_cores_attr = tt::ChipPhysicalCoresAttr::get(
context, worker_cores, dram_cores, eth_cores, eth_inactive_cores);
Expand All @@ -134,6 +173,68 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
break;
}

std::vector<tt::DataTypeAttr> supported_data_types_attr;

for (auto it : *(element->supported_data_types())) {
switch (it) {
case ::tt::target::DataType::Float32:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::Float32));
break;
case ::tt::target::DataType::Float16:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::Float16));
break;
case ::tt::target::DataType::BFloat16:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFloat16));
break;
case ::tt::target::DataType::BFP_Float8:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float8));
break;
case ::tt::target::DataType::BFP_BFloat8:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat8));
break;
case ::tt::target::DataType::BFP_Float4:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float4));
break;
case ::tt::target::DataType::BFP_BFloat4:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat4));
break;
case ::tt::target::DataType::BFP_Float2:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_Float2));
break;
case ::tt::target::DataType::BFP_BFloat2:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat2));
break;
case ::tt::target::DataType::UInt32:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt32));
break;
case ::tt::target::DataType::UInt16:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt16));
break;
case ::tt::target::DataType::UInt8:
supported_data_types_attr.push_back(
tt::DataTypeAttr::get(context, tt::DataType::UInt8));
break;
}
}

SmallVector<tt::TileSizeAttr> supported_tile_sizes_attr;

for (auto it : *(element->supported_tile_sizes())) {
supported_tile_sizes_attr.push_back(
tt::TileSizeAttr::get(context, it->y(), it->x()));
}

auto current_chip_desc_attr = tt::ChipDescAttr::get(
context, tt::ArchAttr::get(context, arch),
{element->grid_size()->y(), element->grid_size()->x()},
Expand All @@ -142,7 +243,8 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) {
element->pcie_address_align_bytes(),
element->noc_dram_address_align_bytes(), element->l1_unreserved_base(),
element->erisc_l1_unreserved_base(), element->dram_unreserved_base(),
chip_physical_cores_attr);
chip_physical_cores_attr, supported_data_types_attr,
supported_tile_sizes_attr);
chip_desc_list.push_back(current_chip_desc_attr);
}

Expand Down
16 changes: 14 additions & 2 deletions python/TTModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include <vector>

#include "ttmlir/Bindings/Python/TTMLIRModule.h"
Expand All @@ -10,6 +11,7 @@
#include "mlir/CAPI/IR.h"

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Target/Common/types_generated.h"

namespace mlir::ttmlir::python {
void populateTTModule(py::module &m) {
Expand Down Expand Up @@ -108,6 +110,12 @@ void populateTTModule(py::module &m) {
tt::ArchAttr::get(unwrap(ctx), static_cast<tt::Arch>(arch)));
});

py::class_<tt::DataTypeAttr>(m, "DataTypeAttr")
.def_static("get", [](MlirContext ctx, uint16_t *supportedDataTypes) {
return wrap(tt::DataTypeAttr::get(
unwrap(ctx), static_cast<tt::DataType>(*supportedDataTypes)));
});

py::class_<tt::ChipDescAttr>(m, "ChipDescAttr")
.def_static(
"get",
Expand All @@ -116,15 +124,19 @@ void populateTTModule(py::module &m) {
unsigned dramChannelSize, unsigned nocL1AddressAlignBytes,
unsigned pcieAddressAlignBytes, unsigned nocDRAMAddressAlignBytes,
unsigned l1UnreservedBase, unsigned eriscL1UnreservedBase,
unsigned dramUnreservedBase, MlirAttribute chipPhysicalCores) {
unsigned dramUnreservedBase, MlirAttribute chipPhysicalCores,
MlirAttribute supportedDataTypes,
MlirAttribute supportedTileSizes) {
return wrap(tt::ChipDescAttr::get(
unwrap(ctx), mlir::cast<tt::ArchAttr>(unwrap(arch)), grid,
l1Size, numDramChannels, dramChannelSize,
nocL1AddressAlignBytes, pcieAddressAlignBytes,
nocDRAMAddressAlignBytes, l1UnreservedBase,
eriscL1UnreservedBase, dramUnreservedBase,
mlir::dyn_cast<tt::ChipPhysicalCoresAttr>(
unwrap(chipPhysicalCores))));
unwrap(chipPhysicalCores)),
mlir::cast<tt::DataTypeAttr>(unwrap(supportedDataTypes)),
mlir::cast<tt::TileSizeAttr>(unwrap(supportedTileSizes))));
});

py::class_<tt::ChipCoordAttr>(m, "ChipCoordAttr")
Expand Down
25 changes: 24 additions & 1 deletion runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
#include "types_generated.h"
#include <cstdint>
#include <vector>

Expand Down Expand Up @@ -157,6 +158,7 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
std::vector<::flatbuffers::Offset<tt::target::ChipDesc>> chipDescs;
std::vector<uint32_t> chipDescIndices;
std::vector<::tt::target::ChipCapability> chipCapabilities;

// Ignore for now
std::vector<::tt::target::ChipCoord> chipCoords = {
::tt::target::ChipCoord(0, 0, 0, 0)};
Expand All @@ -170,12 +172,33 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) {
// Extract physical core coordinates for worker, dram, eth cores
auto chipPhysicalCores = createChipPhysicalCores(device, fbb);

// The following are temporary placeholder values, to be replaced by values
// obtained from the API.
std::vector<::tt::target::DataType> supportedDataTypesVector = {
::tt::target::DataType::Float32, ::tt::target::DataType::Float16,
::tt::target::DataType::BFloat16, ::tt::target::DataType::BFP_Float8,
::tt::target::DataType::BFP_BFloat8, ::tt::target::DataType::BFP_Float4,
::tt::target::DataType::BFP_BFloat4, ::tt::target::DataType::BFP_Float2,
::tt::target::DataType::BFP_BFloat2, ::tt::target::DataType::UInt32,
::tt::target::DataType::UInt16, ::tt::target::DataType::UInt8};

auto supportedDataTypes = fbb.CreateVector(supportedDataTypesVector);

std::vector<::tt::target::Dim2d> supportedTileSizesVector = {
::tt::target::Dim2d(4, 16), ::tt::target::Dim2d(16, 16),
::tt::target::Dim2d(32, 16), ::tt::target::Dim2d(4, 32),
::tt::target::Dim2d(16, 32), ::tt::target::Dim2d(32, 32)};

auto supportedTileSizes =
fbb.CreateVectorOfStructs(supportedTileSizesVector);

chipDescs.push_back(::tt::target::CreateChipDesc(
fbb, toFlatbuffer(device->arch()), &deviceGrid,
device->l1_size_per_core(), device->num_dram_channels(),
device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT,
DRAM_ALIGNMENT, L1_UNRESERVED_BASE, ERISC_L1_UNRESERVED_BASE,
DRAM_UNRESERVED_BASE, chipPhysicalCores));
DRAM_UNRESERVED_BASE, chipPhysicalCores, supportedDataTypes,
supportedTileSizes));
chipDescIndices.push_back(device->id());
// Derive chip capability
::tt::target::ChipCapability chipCapability =
Expand Down

0 comments on commit 287a2ef

Please sign in to comment.