diff --git a/include/ttmlir-c/TTAttrs.h b/include/ttmlir-c/TTAttrs.h index 0948c1fab..850a37887 100644 --- a/include/ttmlir-c/TTAttrs.h +++ b/include/ttmlir-c/TTAttrs.h @@ -22,13 +22,17 @@ ttmlirTTChipCapabilityAttrGet(MlirContext ctx, uint32_t chipCapability); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTArchAttrGet(MlirContext ctx, uint32_t arch); +MLIR_CAPI_EXPORTED MlirAttribute +ttmlirTTDataTypeAttrGet(MlirContext ctx, uint16_t *supportedDataTypes); + MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipDescAttrGet( MlirContext ctx, MlirAttribute arch, int64_t *grid, size_t gridSize, unsigned l1Size, unsigned numDramChannels, unsigned dramChannelSize, unsigned nocL1AddressAlignBytes, unsigned pcieAddressAlignBytes, unsigned nocDRAMAddressAlignBytes, unsigned l1UnreservedBase, unsigned eriscL1UnreservedBase, unsigned dramUnreservedBase, - MlirAttribute chipPhysicalCores); + MlirAttribute chipPhysicalCores, MlirAttribute *supportedDataTypes, + MlirAttribute *supportedTileSizes); MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTChipCoordAttrGet( MlirContext ctx, unsigned rack, unsigned shelf, unsigned y, unsigned x); diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td index 9d596c038..251078bc0 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td @@ -35,6 +35,7 @@ def TT_DataType : I32EnumAttr<"DataType", "TT DataTypes", TT_UInt16, TT_UInt8 ]> { + let genSpecializedAttr = 0; let cppNamespace = "::mlir::tt"; let stringToSymbolFnName = "DataTypeStringToEnum"; let symbolToStringFnName = "DataTypeEnumToString"; diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index 0403c8fbf..d143f2f8d 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -51,6 +51,10 @@ def TT_ArchAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } +def TT_DataTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def TT_CoreCoordAttr : TT_Attr<"CoreCoord", "core_coord"> { let summary = "TT core_coord attribute"; let description = [{ @@ -61,6 +65,16 @@ def TT_CoreCoordAttr : TT_Attr<"CoreCoord", "core_coord"> { let assemblyFormat = "`(` $y `,` $x `)`"; } +def TT_TileSizeAttr : TT_Attr<"TileSize", "tile_size"> { + let summary = "TT tile_size attribute"; + let description = [{ + TT tile_size attribute containing a supported Tensix tile shape. + }]; + + let parameters = (ins "int64_t":$y, "int64_t":$x); + let assemblyFormat = "`(` $y `x` $x `)`"; +} + def TT_ChipPhysicalCoresAttr : TT_Attr<"ChipPhysicalCores", "chip_physical_cores"> { let summary = "TT chip_physical_cores attribute"; @@ -89,7 +103,9 @@ def TT_ChipDescAttr : TT_Attr<"ChipDesc", "chip_desc"> { "unsigned":$l1UnreservedBase, "unsigned":$eriscL1UnreservedBase, "unsigned":$dramUnreservedBase, - "ChipPhysicalCoresAttr":$chipPhysicalCores); + "ChipPhysicalCoresAttr":$chipPhysicalCores, + ArrayRefParameter<"DataTypeAttr">:$supportedDataTypes, + ArrayRefParameter<"TileSizeAttr">:$supportedTileSizes); let assemblyFormat = [{`{` `arch` `=` $arch `,` `grid` `=` custom($grid) `,` `l1_size` `=` $l1Size `,` @@ -101,7 +117,9 @@ def TT_ChipDescAttr : TT_Attr<"ChipDesc", "chip_desc"> { `l1_unreserved_base` `=` $l1UnreservedBase `,` `erisc_l1_unreserved_base` `=` $eriscL1UnreservedBase `,` `dram_unreserved_base` `=` $dramUnreservedBase `,` - `physical_cores` `=` $chipPhysicalCores `}`}]; + `physical_cores` `=` $chipPhysicalCores `,` + `supported_data_types` `=` `[` $supportedDataTypes `]` `,` + `supported_tile_sizes` `=` `[` $supportedTileSizes `]` `}`}]; let extraClassDeclaration = [{ unsigned getUsableL1Size() const { return getL1Size() - getL1UnreservedBase(); } diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 052c5524d..42a828761 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -16,7 +16,7 @@ enum Arch: uint { Blackhole = 2, } -enum DataType: ushort { +enum DataType: uint16 { Float32 = 0, Float16 = 1, BFloat16 = 2, @@ -106,6 +106,8 @@ table ChipDesc { erisc_l1_unreserved_base: uint32; dram_unreserved_base: uint32; physical_cores: ChipPhysicalCores; + supported_data_types: [DataType]; + supported_tile_sizes: [Dim2d]; } struct ChipCoord { diff --git a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h index fcf0aa889..fa8e67466 100644 --- a/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h +++ b/include/ttmlir/Target/Utils/MLIRToFlatbuffer.h @@ -108,6 +108,17 @@ inline ::tt::target::Arch toFlatbuffer(FlatbufferObjectCache &, ArchAttr arch) { } } +// Overloaded function for DataTypeAttr +inline ::tt::target::DataType toFlatbuffer(FlatbufferObjectCache &cache, + const DataTypeAttr &dtypeAttr) { + return toFlatbuffer(cache, dtypeAttr.getValue()); +} + +inline ::tt::target::Dim2d toFlatbuffer(FlatbufferObjectCache &cache, + TileSizeAttr tileSize) { + return ::tt::target::Dim2d(tileSize.getY(), tileSize.getX()); +} + inline ::tt::target::ChipCapability toFlatbuffer(FlatbufferObjectCache &, ChipCapabilityAttr capabilityAttr) { auto capabilities = capabilityAttr.getValue(); @@ -233,7 +244,9 @@ toFlatbuffer(FlatbufferObjectCache &cache, ChipDescAttr chipDesc) { chipDesc.getPcieAddressAlignBytes(), chipDesc.getNocDRAMAddressAlignBytes(), chipDesc.getL1UnreservedBase(), chipDesc.getEriscL1UnreservedBase(), chipDesc.getDramUnreservedBase(), - toFlatbuffer(cache, chipDesc.getChipPhysicalCores())); + toFlatbuffer(cache, chipDesc.getChipPhysicalCores()), + toFlatbuffer(cache, chipDesc.getSupportedDataTypes()), + toFlatbuffer(cache, chipDesc.getSupportedTileSizes())); } inline flatbuffers::Offset<::tt::target::SystemDesc> diff --git a/lib/CAPI/TTAttrs.cpp b/lib/CAPI/TTAttrs.cpp index bed6ad92a..07db90b51 100644 --- a/lib/CAPI/TTAttrs.cpp +++ b/lib/CAPI/TTAttrs.cpp @@ -26,20 +26,29 @@ MlirAttribute ttmlirTTArchAttrGet(MlirContext ctx, uint32_t arch) { return wrap(ArchAttr::get(unwrap(ctx), static_cast(arch))); } +MlirAttribute ttmlirTTDataTypeAttrGet(MlirContext ctx, + uint16_t *supportedDataTypes) { + return wrap(DataTypeAttr::get(unwrap(ctx), + static_cast(*supportedDataTypes))); +} + MlirAttribute ttmlirTTChipDescAttrGet( MlirContext ctx, MlirAttribute arch, int64_t *grid, size_t gridSize, unsigned l1Size, unsigned numDramChannels, unsigned dramChannelSize, unsigned nocL1AddressAlignBytes, unsigned pcieAddressAlignBytes, unsigned nocDRAMAddressAlignBytes, unsigned l1UnreservedBase, unsigned eriscL1UnreservedBase, unsigned dramUnreservedBase, - MlirAttribute chipPhysicalCores) { + MlirAttribute chipPhysicalCores, MlirAttribute *supportedDataTypes, + MlirAttribute *supportedTileSizes) { std::vector gridVec(grid, grid + gridSize); return wrap(ChipDescAttr::get( unwrap(ctx), mlir::dyn_cast(unwrap(arch)), gridVec, l1Size, numDramChannels, dramChannelSize, nocL1AddressAlignBytes, pcieAddressAlignBytes, nocDRAMAddressAlignBytes, l1UnreservedBase, eriscL1UnreservedBase, dramUnreservedBase, - mlir::dyn_cast(unwrap(chipPhysicalCores)))); + mlir::dyn_cast(unwrap(chipPhysicalCores)), + mlir::dyn_cast(unwrap(*supportedDataTypes)), + mlir::dyn_cast(unwrap(*supportedTileSizes)))); } MlirAttribute ttmlirTTChipCoordAttrGet(MlirContext ctx, unsigned rack, diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp index b6cafdfb4..cf0e8b0c6 100644 --- a/lib/Dialect/TT/IR/TTOpsTypes.cpp +++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp @@ -29,6 +29,43 @@ mlir::tt::SystemDescAttr mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) { // Populate a dummy n150 SmallVector gridShape = {8, 8}; + + // populate a placeholder for supported tile sizes + SmallVector supported_data_types; + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::Float32)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::Float16)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFloat16)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float8)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat8)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float4)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat4)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float2)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat2)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt32)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt16)); + supported_data_types.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt8)); + + // populate a placeholder for supported tile sizes + SmallVector supported_tile_sizes; + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 4, 16)); + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 16, 16)); + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 32, 16)); + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 4, 32)); + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 16, 32)); + supported_tile_sizes.push_back(tt::TileSizeAttr::get(context, 32, 32)); + SmallVector workerCores; workerCores.reserve(gridShape[0] * gridShape[1]); for (std::int64_t y = 0; y < gridShape[0]; ++y) { @@ -50,7 +87,8 @@ mlir::tt::SystemDescAttr::getDefault(MLIRContext *context) { context, tt::ArchAttr::get(context, tt::Arch::WormholeB0), gridShape, 1499136, 12, (1 << 30), 16, 32, 32, 0, 0, 0, tt::ChipPhysicalCoresAttr::get(context, workerCores, dramCores, - {}, {})), + {}, {}), + supported_data_types, supported_tile_sizes), }, // Chip Descriptor Indices { @@ -117,6 +155,7 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) { eth_inactive_cores.emplace_back( tt::CoreCoordAttr::get(context, core->y(), core->x())); } + // Create ChipPhysicalCoresAttr from the list of CoreCoordAttr instances auto chip_physical_cores_attr = tt::ChipPhysicalCoresAttr::get( context, worker_cores, dram_cores, eth_cores, eth_inactive_cores); @@ -134,6 +173,68 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) { break; } + std::vector supported_data_types_attr; + + for (auto it : *(element->supported_data_types())) { + switch (it) { + case ::tt::target::DataType::Float32: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::Float32)); + break; + case ::tt::target::DataType::Float16: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::Float16)); + break; + case ::tt::target::DataType::BFloat16: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFloat16)); + break; + case ::tt::target::DataType::BFP_Float8: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float8)); + break; + case ::tt::target::DataType::BFP_BFloat8: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat8)); + break; + case ::tt::target::DataType::BFP_Float4: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float4)); + break; + case ::tt::target::DataType::BFP_BFloat4: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat4)); + break; + case ::tt::target::DataType::BFP_Float2: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_Float2)); + break; + case ::tt::target::DataType::BFP_BFloat2: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::BFP_BFloat2)); + break; + case ::tt::target::DataType::UInt32: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt32)); + break; + case ::tt::target::DataType::UInt16: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt16)); + break; + case ::tt::target::DataType::UInt8: + supported_data_types_attr.push_back( + tt::DataTypeAttr::get(context, tt::DataType::UInt8)); + break; + } + } + + SmallVector supported_tile_sizes_attr; + + for (auto it : *(element->supported_tile_sizes())) { + supported_tile_sizes_attr.push_back( + tt::TileSizeAttr::get(context, it->y(), it->x())); + } + auto current_chip_desc_attr = tt::ChipDescAttr::get( context, tt::ArchAttr::get(context, arch), {element->grid_size()->y(), element->grid_size()->x()}, @@ -142,7 +243,8 @@ mlir::tt::SystemDescAttr::getFromPath(MLIRContext *context, std::string &path) { element->pcie_address_align_bytes(), element->noc_dram_address_align_bytes(), element->l1_unreserved_base(), element->erisc_l1_unreserved_base(), element->dram_unreserved_base(), - chip_physical_cores_attr); + chip_physical_cores_attr, supported_data_types_attr, + supported_tile_sizes_attr); chip_desc_list.push_back(current_chip_desc_attr); } diff --git a/python/TTModule.cpp b/python/TTModule.cpp index 434a3143f..8ebf4c9a7 100644 --- a/python/TTModule.cpp +++ b/python/TTModule.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include #include #include "ttmlir/Bindings/Python/TTMLIRModule.h" @@ -10,6 +11,7 @@ #include "mlir/CAPI/IR.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Target/Common/types_generated.h" namespace mlir::ttmlir::python { void populateTTModule(py::module &m) { @@ -108,6 +110,12 @@ void populateTTModule(py::module &m) { tt::ArchAttr::get(unwrap(ctx), static_cast(arch))); }); + py::class_(m, "DataTypeAttr") + .def_static("get", [](MlirContext ctx, uint16_t *supportedDataTypes) { + return wrap(tt::DataTypeAttr::get( + unwrap(ctx), static_cast(*supportedDataTypes))); + }); + py::class_(m, "ChipDescAttr") .def_static( "get", @@ -116,7 +124,9 @@ void populateTTModule(py::module &m) { unsigned dramChannelSize, unsigned nocL1AddressAlignBytes, unsigned pcieAddressAlignBytes, unsigned nocDRAMAddressAlignBytes, unsigned l1UnreservedBase, unsigned eriscL1UnreservedBase, - unsigned dramUnreservedBase, MlirAttribute chipPhysicalCores) { + unsigned dramUnreservedBase, MlirAttribute chipPhysicalCores, + MlirAttribute supportedDataTypes, + MlirAttribute supportedTileSizes) { return wrap(tt::ChipDescAttr::get( unwrap(ctx), mlir::cast(unwrap(arch)), grid, l1Size, numDramChannels, dramChannelSize, @@ -124,7 +134,9 @@ void populateTTModule(py::module &m) { nocDRAMAddressAlignBytes, l1UnreservedBase, eriscL1UnreservedBase, dramUnreservedBase, mlir::dyn_cast( - unwrap(chipPhysicalCores)))); + unwrap(chipPhysicalCores)), + mlir::cast(unwrap(supportedDataTypes)), + mlir::cast(unwrap(supportedTileSizes)))); }); py::class_(m, "ChipCoordAttr") diff --git a/runtime/lib/common/system_desc.cpp b/runtime/lib/common/system_desc.cpp index d52e5555e..a6df9d798 100644 --- a/runtime/lib/common/system_desc.cpp +++ b/runtime/lib/common/system_desc.cpp @@ -5,6 +5,7 @@ #include "tt/runtime/utils.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" +#include "types_generated.h" #include #include @@ -157,6 +158,7 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) { std::vector<::flatbuffers::Offset> chipDescs; std::vector chipDescIndices; std::vector<::tt::target::ChipCapability> chipCapabilities; + // Ignore for now std::vector<::tt::target::ChipCoord> chipCoords = { ::tt::target::ChipCoord(0, 0, 0, 0)}; @@ -170,12 +172,33 @@ getCurrentSystemDescImpl(const ::tt::tt_metal::DeviceMesh &deviceMesh) { // Extract physical core coordinates for worker, dram, eth cores auto chipPhysicalCores = createChipPhysicalCores(device, fbb); + // The following is temporary place-holder value to be replaced by API + // value. + std::vector<::tt::target::DataType> supportedDataTypesVector = { + ::tt::target::DataType::Float32, ::tt::target::DataType::Float16, + ::tt::target::DataType::BFloat16, ::tt::target::DataType::BFP_Float8, + ::tt::target::DataType::BFP_BFloat8, ::tt::target::DataType::BFP_Float4, + ::tt::target::DataType::BFP_BFloat4, ::tt::target::DataType::BFP_Float2, + ::tt::target::DataType::BFP_BFloat2, ::tt::target::DataType::UInt32, + ::tt::target::DataType::UInt16, ::tt::target::DataType::UInt8}; + + auto supportedDataTypes = fbb.CreateVector(supportedDataTypesVector); + + std::vector<::tt::target::Dim2d> supportedTileSizesVector = { + ::tt::target::Dim2d(4, 16), ::tt::target::Dim2d(16, 16), + ::tt::target::Dim2d(32, 16), ::tt::target::Dim2d(4, 32), + ::tt::target::Dim2d(16, 32), ::tt::target::Dim2d(32, 32)}; + + auto supportedTileSizes = + fbb.CreateVectorOfStructs(supportedTileSizesVector); + chipDescs.push_back(::tt::target::CreateChipDesc( fbb, toFlatbuffer(device->arch()), &deviceGrid, device->l1_size_per_core(), device->num_dram_channels(), device->dram_size_per_channel(), L1_ALIGNMENT, PCIE_ALIGNMENT, DRAM_ALIGNMENT, L1_UNRESERVED_BASE, ERISC_L1_UNRESERVED_BASE, - DRAM_UNRESERVED_BASE, chipPhysicalCores)); + DRAM_UNRESERVED_BASE, chipPhysicalCores, supportedDataTypes, + supportedTileSizes)); chipDescIndices.push_back(device->id()); // Derive chip capability ::tt::target::ChipCapability chipCapability =