Skip to content

Commit

Permalink
Add sharding support for ttnn backend (#541)
Browse files Browse the repository at this point in the history
* Adds a tensor memory layout attribute to LayoutAttr, to be consumed by the TTNN backend. This will need refactoring in the future, as the metal backend does not need this field #596 
* Adds runtime support for generating sharded memory configs accordingly. Currently only BlockSharded is supported
* Adds eltwise sharding tests under Silicon/TTNN/sharded
  • Loading branch information
jnie-TT authored Sep 3, 2024
1 parent 75bd688 commit c75811b
Show file tree
Hide file tree
Showing 35 changed files with 738 additions and 206 deletions.
11 changes: 6 additions & 5 deletions include/ttmlir-c/TTAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,19 @@ MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTSystemDescAttrGet(
MlirAttribute *chipCoords, size_t chipCoordsSize,
MlirAttribute *chipChannels, size_t chipChannelsSize);

MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx,
MlirAffineMap linear,
unsigned oobVal,
MlirAttribute grid,
MlirType memref);
// Creates a tt::LayoutAttr from an affine map, an OOBVal enum value, a grid
// attribute, the shard memref type, and a TensorMemoryLayout enum value
// (added for TTNN sharding support).
MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear, unsigned oobVal,
MlirAttribute grid, MlirType memref, unsigned memLayout);

// Creates a tt::MemorySpaceAttr wrapping the given MemorySpace enum value.
MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTMemorySpaceAttrGet(MlirContext ctx, uint32_t memorySpace);

// Creates a tt::OOBValAttr wrapping the given OOBVal enum value.
MLIR_CAPI_EXPORTED MlirAttribute ttmlirTTOOBValAttrGet(MlirContext ctx,
uint32_t oobVal);

// Creates a tt::TensorMemoryLayoutAttr wrapping the given TensorMemoryLayout
// enum value (none_layout, interleaved, single_bank, or a sharded variant).
MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTTensorMemoryLayoutAttrGet(MlirContext ctx, uint32_t memLayout);

// Creates a tt::IteratorTypeAttr wrapping the given IteratorType enum value.
MLIR_CAPI_EXPORTED MlirAttribute
ttmlirTTIteratorTypeAttrGet(MlirContext ctx, uint32_t iteratorType);

Expand Down
43 changes: 39 additions & 4 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ def TT_MemorySpace : I32EnumAttr<"MemorySpace", "TT MemorySpace",
let cppNamespace = "::mlir::tt";
}

// Tensor memory layout cases: how a tensor's data is distributed across
// device memory. Consumed by the TTNN backend; currently only BlockSharded
// has runtime sharding support.
def TT_NoneLayout : I32EnumAttrCase<"NoneLayout", 0, "none_layout">;
def TT_Interleaved : I32EnumAttrCase<"Interleaved", 1, "interleaved">;
def TT_SingleBank : I32EnumAttrCase<"SingleBank", 2, "single_bank">;
def TT_HeightSharded : I32EnumAttrCase<"HeightSharded", 3, "height_sharded">;
def TT_WidthSharded : I32EnumAttrCase<"WidthSharded", 4, "width_sharded">;
def TT_BlockSharded : I32EnumAttrCase<"BlockSharded", 5, "block_sharded">;

// Enum attribute grouping all TensorMemoryLayout cases. genSpecializedAttr
// is disabled because the dialect declares the specialized attribute
// separately (TT_TensorMemoryLayoutAttr in TTOpsTypes.td).
def TT_TensorMemoryLayout : I32EnumAttr<"TensorMemoryLayout", "TT TensorMemoryLayout",
[
TT_NoneLayout,
TT_Interleaved,
TT_SingleBank,
TT_HeightSharded,
TT_WidthSharded,
TT_BlockSharded,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
}

def TT_Parallel : I32EnumAttrCase<"Parallel", 0, "parallel">;
def TT_Systolic : I32EnumAttrCase<"Systolic", 1, "systolic">;
def TT_Broadcast : I32EnumAttrCase<"Broadcast", 2, "broadcast">;
Expand Down Expand Up @@ -109,20 +129,35 @@ def TT_OperandConstraintDRAM : I32BitEnumAttrCaseBit<"DRAM", 1, "dram">;
def TT_OperandConstraintL1 : I32BitEnumAttrCaseBit<"L1", 2, "l1">;
def TT_OperandConstraintScalar : I32BitEnumAttrCaseBit<"Scalar", 3, "scalar">;
def TT_OperandConstraintTile : I32BitEnumAttrCaseBit<"Tile", 4, "tile">;
def TT_OperandConstraintAny : I32BitEnumAttrCaseGroup<"Any", [TT_OperandConstraintSystem, TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile], "any">;
def TT_OperandConstraintAnyDevice : I32BitEnumAttrCaseGroup<"AnyDevice", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile], "any_device">;
def TT_OperandConstraintAnyDeviceTile : I32BitEnumAttrCaseGroup<"AnyDeviceTile", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintTile], "any_device_tile">;

// Per-layout operand-constraint bits (bits 5-10), mirroring the
// TensorMemoryLayout enum cases so constraints can restrict which memory
// layouts an operand may take.
def TT_OperandConstraintNoneLayout : I32BitEnumAttrCaseBit<"NoneLayout", 5, "none_layout">;
def TT_OperandConstraintInterleaved : I32BitEnumAttrCaseBit<"Interleaved", 6, "interleaved">;
def TT_OperandConstraintSingleBank : I32BitEnumAttrCaseBit<"SingleBank", 7, "single_bank">;
def TT_OperandConstraintHeightSharded : I32BitEnumAttrCaseBit<"HeightSharded", 8, "height_sharded">;
def TT_OperandConstraintWidthSharded : I32BitEnumAttrCaseBit<"WidthSharded", 9, "width_sharded">;
def TT_OperandConstraintBlockSharded : I32BitEnumAttrCaseBit<"BlockSharded", 10, "block_sharded">;
// Convenience groups: "any_layout" covers every layout bit; the Any /
// AnyDevice / AnyDeviceTile groups fold it in; "l1_block_sharded" pins an
// operand to L1 memory with a block-sharded layout.
def TT_OperandConstraintAnyLayout : I32BitEnumAttrCaseGroup<"AnyLayout", [TT_OperandConstraintNoneLayout, TT_OperandConstraintInterleaved, TT_OperandConstraintSingleBank, TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded], "any_layout">;
def TT_OperandConstraintAny : I32BitEnumAttrCaseGroup<"Any", [TT_OperandConstraintSystem, TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any">;
def TT_OperandConstraintAnyDevice : I32BitEnumAttrCaseGroup<"AnyDevice", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any_device">;
def TT_OperandConstraintAnyDeviceTile : I32BitEnumAttrCaseGroup<"AnyDeviceTile", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any_device_tile">;
def TT_OperandConstraintL1BlockSharded : I32BitEnumAttrCaseGroup<"L1BlockSharded", [TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintBlockSharded], "l1_block_sharded">;
def TT_OperandConstraint : I32BitEnumAttr<"OperandConstraint", "TT Operand Constraints",
[
TT_OperandConstraintSystem,
TT_OperandConstraintDRAM,
TT_OperandConstraintL1,
TT_OperandConstraintScalar,
TT_OperandConstraintTile,
TT_OperandConstraintNoneLayout,
TT_OperandConstraintInterleaved,
TT_OperandConstraintSingleBank,
TT_OperandConstraintHeightSharded,
TT_OperandConstraintWidthSharded,
TT_OperandConstraintBlockSharded,
TT_OperandConstraintAnyLayout,
TT_OperandConstraintAny,
TT_OperandConstraintAnyDevice,
TT_OperandConstraintAnyDeviceTile,
TT_OperandConstraintL1BlockSharded,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ inline bool isDeviceMemorySpace(MemorySpace memorySpace) {
memorySpace == MemorySpace::DeviceL1;
}

// Returns true when the given memory space refers to device-local L1 memory.
inline bool isL1MemorySpace(MemorySpace memorySpace) {
  switch (memorySpace) {
  case MemorySpace::DeviceL1:
    return true;
  default:
    return false;
  }
}

// Returns true for any of the sharded tensor memory layouts
// (height-, width-, or block-sharded); false for all other layouts.
inline bool isShardedMemoryLayout(TensorMemoryLayout layout) {
  switch (layout) {
  case TensorMemoryLayout::HeightSharded:
  case TensorMemoryLayout::WidthSharded:
  case TensorMemoryLayout::BlockSharded:
    return true;
  default:
    return false;
  }
}

inline void printDimensionList(::mlir::AsmPrinter &printer,
::llvm::ArrayRef<int64_t> shape) {
printer.printDimensionList(shape);
Expand Down
21 changes: 15 additions & 6 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,9 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
let parameters = (ins AttrParameter<"AffineMap", "An affine map that defines how the logical tensor dimensions map to a grid shape.">:$linear,
AttrParameter<"OOBVal", "A tracked out of bounds value that fills padding space.">:$oob_val,
AttrParameter<"GridAttr", "The grid shape that this tensor is divided onto.">:$grid,
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref);
let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref `>`";
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref,
AttrParameter<"TensorMemoryLayout", "The layout of the tensor in memory.">:$mem_layout);
let assemblyFormat = "`<` $linear`,` $oob_val`,` $grid`,` $memref`,` $mem_layout `>`";

let extraClassDeclaration = [{
static LayoutAttr get(::mlir::MLIRContext *context,
Expand All @@ -255,29 +256,33 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
MemorySpace memorySpace = MemorySpace::System,
GridAttr grid = {},
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}},
OOBVal oobVal = OOBVal::Undef);
OOBVal oobVal = OOBVal::Undef,
TensorMemoryLayout memLayout = TensorMemoryLayout::NoneLayout);
static LayoutAttr get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace = MemorySpace::System,
GridAttr grid = {},
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}},
OOBVal oobVal = OOBVal::Undef);
OOBVal oobVal = OOBVal::Undef,
TensorMemoryLayout memLayout = TensorMemoryLayout::NoneLayout);
static LayoutAttr get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace,
GridAttr grid,
Type elementType);
Type elementType,
TensorMemoryLayout memLayout);
LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid, ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
LayoutAttr withGrid(::mlir::MLIRContext *context,
RankedTensorType ty,
GridAttr grid,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
LayoutAttr withElementType(::mlir::MLIRContext *context, Type elementType);
LayoutAttr withMemorySpace(::mlir::MLIRContext *context, MemorySpace memorySpace);

LayoutAttr withMemoryLayout(::mlir::MLIRContext *context, TensorMemoryLayout memLayout);
MemorySpace getMemorySpace() const;
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool hasShardedTensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
Expand Down Expand Up @@ -337,6 +342,10 @@ def TT_MemorySpaceAttr : EnumAttr<TT_Dialect, TT_MemorySpace, "memory_space"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_TensorMemoryLayoutAttr : EnumAttr<TT_Dialect, TT_TensorMemoryLayout, "tensor_memory_layout"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_OOBValAttr : EnumAttr<TT_Dialect, TT_OOBVal, "oob_val"> {
let assemblyFormat = "`<` $value `>`";
}
Expand Down
17 changes: 13 additions & 4 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI
- Some combination of the above

```llvm
#layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>>
#layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>>
#layout = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #system>, none_layout>
#layout1 = #tt.layout<8192x128x1, undef, <1x1>, memref<64x128xf32, #l1_>, none_layout>
%1 = "ttir.to_layout"(%arg0, %0) : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout1>) -> tensor<64x128xf32, #layout1>
```
}];
Expand All @@ -89,8 +89,17 @@ def TTIR_ToLayoutOp : TTIR_Op<"to_layout", [DestinationStyleOpInterface, TTIROpI
// TODO return below, but we need a way to properly create an ArrayAttr:
// return {OperandConstraint::Any, OperandConstraint::Any};
}
// Returns a tuple of booleans indicating if the op changes layout, grid, format, or memory space.
std::tuple<bool, bool, bool, bool> compoundComponents();

// Result of compoundComponents(): one flag per transformation category a
// ttir.to_layout op may perform. An op with more than one flag set is
// "compound" and can be split into single-category ops by the
// ttir-split-compound-layout pass.
struct CompoundComponents {
// True if the op changes the tensor's layout (linear affine map).
bool isLayoutChange;
// True if the op changes the grid the tensor is divided onto.
bool isGridChange;
// True if the op changes the element format/data type.
bool isFormatChange;
// True if the op moves the tensor between memory spaces.
bool isMemorySpaceChange;
// True if the op changes the tensor memory layout (e.g. interleaved
// vs. sharded).
bool isMemoryLayoutChange;
};

// Returns booleans indicating if the op changes layout, grid, format, memory space or memory layout.
CompoundComponents compoundComponents();
}];

let hasVerifier = 1;
Expand Down
6 changes: 5 additions & 1 deletion include/ttmlir/Dialect/TTIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
"::mlir::tt::MemorySpace",
/*default=*/"::mlir::tt::MemorySpace::DeviceDRAM",
"Set the default memory space for layout pass to prefer for operation operands, if not constrained">,
Option<"defaultDeviceMemoryLayout", "default-device-memory-layout",
"::mlir::tt::TensorMemoryLayout",
/*default=*/"::mlir::tt::TensorMemoryLayout::Interleaved",
"Set the default memory layout for layout pass to prefer for operation operands that are on device, if not constrained">
];
}

def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
let summary = "Split compound layouts.";
let description = [{
A single to_layout op in ttir can simultaneously perform multiple layout transformations
at once, including changing layout, format, or memory space. This pass splits each of
at once, including changing layout, format, memory space or memory layout. This pass splits each of
these transformation categories into separate to_layout ops.
}];
}
Expand Down
10 changes: 10 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,21 @@ enum ChipCapability: uint32 (bit_flags) {
HostMMIO = 1,
}

// Serialized tensor memory layout. Values deliberately mirror the MLIR
// ::mlir::tt::TensorMemoryLayout enum cases (NoneLayout..BlockSharded) so
// the toFlatbuffer conversion is a one-to-one mapping.
enum TensorMemoryLayout: ushort {
NoneLayout = 0,
Interleaved = 1,
SingleBank = 2,
HeightSharded = 3,
WidthSharded = 4,
BlockSharded = 5,
}

// Describes the physical memory footprint of a tensor shard
// (serialized from the MLIR memref of a LayoutAttr).
table MemoryDesc {
// Shard shape; presumably in elements — confirm against producer.
shape: [int];
tile_shape: Dim2d;
data_type: DataType;
memory_space: MemorySpace;
// Added for sharding support: how the tensor is laid out in memory.
memory_layout: TensorMemoryLayout;
// Total size in bytes (element size times the product of the dims).
size: uint64;
}

Expand Down
27 changes: 24 additions & 3 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ inline ::tt::target::OOBVal toFlatbuffer(FlatbufferObjectCache &,
}
}

// Converts the MLIR TensorMemoryLayout enum to the equivalent flatbuffer
// target enum, case by case. The FlatbufferObjectCache parameter is unused;
// it is present so this overload matches the generic toFlatbuffer
// calling convention used by the other converters in this header.
inline ::tt::target::TensorMemoryLayout
toFlatbuffer(FlatbufferObjectCache &, TensorMemoryLayout memLayout) {
// NOTE(review): there is no return after the switch, so this relies on every
// enum case being covered; an out-of-range value would fall off the end of a
// non-void function (UB, and -Wreturn-type warns). Consider adding
// llvm_unreachable after the switch — confirm against project conventions.
switch (memLayout) {
case TensorMemoryLayout::NoneLayout:
return ::tt::target::TensorMemoryLayout::NoneLayout;
case TensorMemoryLayout::Interleaved:
return ::tt::target::TensorMemoryLayout::Interleaved;
case TensorMemoryLayout::SingleBank:
return ::tt::target::TensorMemoryLayout::SingleBank;
case TensorMemoryLayout::HeightSharded:
return ::tt::target::TensorMemoryLayout::HeightSharded;
case TensorMemoryLayout::WidthSharded:
return ::tt::target::TensorMemoryLayout::WidthSharded;
case TensorMemoryLayout::BlockSharded:
return ::tt::target::TensorMemoryLayout::BlockSharded;
}
}

inline std::uint64_t getElementSizeBytes(DataType dtype) {
switch (dtype) {
case DataType::Float32:
Expand Down Expand Up @@ -344,7 +362,8 @@ arrayAttrToFlatbuffer(FlatbufferObjectCache &cache,
}

inline flatbuffers::Offset<::tt::target::MemoryDesc>
memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref) {
memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref,
::mlir::tt::TensorMemoryLayout memLayout) {
auto shapeInt64 = memref.getShape();
std::vector<int32_t> shape(shapeInt64.begin(), shapeInt64.end());
DataType dtype = DataType::Float32;
Expand All @@ -360,6 +379,7 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref) {
dtype = elementTypeToDataType(elementType);
elementSize = getElementSizeBytes(dtype);
}

std::uint64_t size = elementSize;
for (auto dim : shapeInt64) {
size *= dim;
Expand All @@ -370,7 +390,7 @@ memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref) {
toFlatbuffer(
cache,
mlir::cast<MemorySpaceAttr>(memref.getMemorySpace()).getValue()),
size);
toFlatbuffer(cache, memLayout), size);
}

inline flatbuffers::Offset<::tt::target::LayoutDesc>
Expand All @@ -385,7 +405,8 @@ layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr,
return ::tt::target::CreateLayoutDescDirect(
*cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()),
&coreRangeSet,
cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer));
cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer,
layoutAttr.getMemLayout()));
}

inline flatbuffers::Offset<::tt::target::TensorDesc>
Expand Down
11 changes: 9 additions & 2 deletions lib/CAPI/TTAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,13 @@ MlirAttribute ttmlirTTSystemDescAttrGet(

// C-API wrapper: builds a tt::LayoutAttr from opaque MLIR C handles.
// NOTE: the scraped diff retained both the pre- and post-change parameter
// list and cast lines (duplicate `MlirType memref) {` and duplicate memref
// cast), which is not valid C++; this is the reconstructed post-change body.
MlirAttribute ttmlirTTLayoutAttrGet(MlirContext ctx, MlirAffineMap linear,
                                    unsigned oobVal, MlirAttribute grid,
                                    MlirType memref, unsigned memLayout) {
  // The affine map crosses the C API as an opaque pointer; rehydrate it.
  mlir::AffineMap affineMap = mlir::AffineMap::getFromOpaquePointer(linear.ptr);
  return wrap(LayoutAttr::get(unwrap(ctx), affineMap,
                              static_cast<OOBVal>(oobVal),
                              mlir::cast<GridAttr>(unwrap(grid)),
                              mlir::cast<MemRefType>(unwrap(memref)),
                              static_cast<TensorMemoryLayout>(memLayout)));
}

MlirAttribute ttmlirTTMemorySpaceAttrGet(MlirContext ctx,
Expand All @@ -131,6 +132,12 @@ MlirAttribute ttmlirTTOOBValAttrGet(MlirContext ctx, uint32_t oobVal) {
return wrap(OOBValAttr::get(unwrap(ctx), static_cast<tt::OOBVal>(oobVal)));
}

// C-API wrapper: converts a raw uint32_t crossing the C boundary into a
// tt::TensorMemoryLayoutAttr in the given context.
MlirAttribute ttmlirTTTensorMemoryLayoutAttrGet(MlirContext ctx,
                                                uint32_t memLayout) {
  auto layout = static_cast<tt::TensorMemoryLayout>(memLayout);
  return wrap(TensorMemoryLayoutAttr::get(unwrap(ctx), layout));
}

MlirAttribute ttmlirTTIteratorTypeAttrGet(MlirContext ctx,
uint32_t iteratorType) {
return wrap(IteratorTypeAttr::get(
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TT/IR/TTDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ struct TTOpAsmDialectInterface : public OpAsmDialectInterface {
os << "any_device";
} else if (value == OperandConstraint::AnyDeviceTile) {
os << "any_device_tile";
} else if (value == OperandConstraint::L1BlockSharded) {
os << "l1_block_sharded";
} else {
os << "operand_constraint";
}
Expand Down
Loading

0 comments on commit c75811b

Please sign in to comment.