Create tensor grids with same rank as device grid bug fix #500 (#515)
Previously we created tensor layout grids of rank 1 or 2, depending on the
tensor's rank. This is incorrect: the tensor grid must have the same rank as
the device grid.

The fix is to use the device grid's rank in the layout type converter so
that, by default, a tensor layout gets a grid of matching rank.
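A minimal sketch of the behavior change, in plain C++ that only models the
logic (defaultTensorGrid is illustrative, not the actual MLIR API): the
default layout grid is a unit grid whose rank comes from the device grid, so
a 1-D tensor on a rank-2 device grid now gets a <1x1> grid instead of the
old rank-1 <1>.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Models the new default: a unit grid with the device grid's rank,
// independent of the tensor's own rank.
std::vector<std::int64_t> defaultTensorGrid(std::size_t deviceGridRank) {
  return std::vector<std::int64_t>(deviceGridRank, 1);
}

int main() {
  // Previously a 1-D tensor got grid {1} and everything else {1, 1};
  // now every tensor gets a unit grid of the device grid's rank.
  auto grid = defaultTensorGrid(/*deviceGridRank=*/2);
  assert(grid.size() == 2 && grid[0] == 1 && grid[1] == 1);
  return 0;
}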
nsmithtt authored Aug 28, 2024
1 parent 58c515a commit 8d96ba8
Showing 22 changed files with 51 additions and 30 deletions.
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -44,6 +44,10 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
     static GridAttr get(::mlir::MLIRContext *context, ArrayRef<int64_t> shape) {
       return GridAttr::get(context, shape, AffineMap::get(context));
     }
+
+    static GridAttr get(::mlir::MLIRContext *context, std::int64_t rank) {
+      return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
+    }
   }];
 }
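For illustration (ctx being a hypothetical MLIRContext *): GridAttr::get(ctx, 3) is equivalent to GridAttr::get(ctx, SmallVector<std::int64_t>{1, 1, 1}), i.e. a 1x1x1 unit grid.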

@@ -259,6 +263,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
     static LayoutAttr get(::mlir::MLIRContext *context,
                           RankedTensorType ty,
                           MemorySpace memorySpace,
+                          GridAttr grid,
                           Type elementType);
     LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid, ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
     LayoutAttr withGrid(::mlir::MLIRContext *context,
9 changes: 5 additions & 4 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -451,8 +451,7 @@ LayoutAttr LayoutAttr::get(
     ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals,
     OOBVal oobVal) {
   if (not grid) {
-    grid = tensorShape.size() == 1 ? GridAttr::get(context, {1})
-                                   : GridAttr::get(context, {1, 1});
+    grid = GridAttr::get(context, tensorShape.size());
   }
 
   auto linear = collapsedLinearAffineMap(context, tensorShape, grid.getShape(),
@@ -474,9 +473,11 @@
 }
 
 LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty,
-                           MemorySpace memorySpace, Type elementType) {
+                           MemorySpace memorySpace, GridAttr grid,
+                           Type elementType) {
   assert(ty);
-  return get(context, ty.getShape(), elementType, memorySpace, {}, {{0, -1}},
+  assert(grid);
+  return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}},
              OOBVal::Undef);
 }
21 changes: 14 additions & 7 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -433,15 +433,20 @@ inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {
 
 class TTIRLayoutTensorTypeConverter : public TypeConverter {
 public:
-  TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace) {
+  TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace,
+                                GridAttr deviceGrid) {
     addConversion([](Type type) { return type; });
-    addConversion([ctx, initMemorySpace](RankedTensorType type) -> Type {
+    addConversion([ctx, initMemorySpace,
+                   deviceGrid](RankedTensorType type) -> Type {
       auto layout = type.getEncoding();
       if (layout) {
         return type;
       }
+      std::int64_t deviceGridRank = deviceGrid.getShape().size();
+      // Default to single core grid
+      auto tensorGrid = GridAttr::get(ctx, deviceGridRank);
       // Default to initMemorySpace, the optimizer might decide otherwise
-      auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace);
+      auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace, tensorGrid);
       return RankedTensorType::get(type.getShape(), type.getElementType(),
                                    newLayout);
     });
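In effect, any tensor that does not already carry a layout encoding now defaults to a single-core unit grid of the device grid's rank; as the inline comment notes, the optimizer may later choose a different grid or memory space.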
@@ -526,8 +531,8 @@ static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
     return std::nullopt;
   }
 
-  auto desiredLayout =
-      rewriter.getAttr<LayoutAttr>(ty, desiredMemorySpace, desiredElementType);
+  auto desiredLayout = rewriter.getAttr<LayoutAttr>(
+      ty, desiredMemorySpace, currLayout.getGrid(), desiredElementType);
   auto output = rewriter.create<tensor::EmptyOp>(
       loc, ty.getShape(), ty.getElementType(), desiredLayout);
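Note that createToLayoutOp now threads the tensor's existing grid (currLayout.getGrid()) into the desired layout, so converting memory space or element type preserves the grid's rank.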

@@ -627,8 +632,10 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
 
   void runOnOperation() final {
     {
-      TTIRLayoutTensorTypeConverter typeConverter(&getContext(),
-                                                  initMemorySpace);
+      auto device = getCurrentScopeDevice(getOperation());
+      assert(device && "Device not found");
+      TTIRLayoutTensorTypeConverter typeConverter(
+          &getContext(), initMemorySpace, device.getGrid());
       RewritePatternSet patterns(&getContext());
       patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
       FrozenRewritePatternSet patternSet(std::move(patterns));
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTIR/test_layout.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-layout %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @bcast_one_dim(%arg0: tensor<2x64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<2x64x128xf32> {
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_concat.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_div.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_ge.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
 // CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>>
 module attributes {} {
4 changes: 2 additions & 2 deletions test/ttmlir/Dialect/TTNN/simple_mean.mlir
@@ -1,6 +1,6 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
-module attributes {} {
+module {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_multiply.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_subtract.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_sum.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> {
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> {
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x16xbf16>) -> tensor<16x64xbf16> {
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
8 changes: 8 additions & 0 deletions test/ttmlir/Translate/TTNN/1d_tensor.mlir
@@ -0,0 +1,8 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | ttmlir-translate --ttnn-to-flatbuffer
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+
+func.func @embedding_1d_tensor(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
+  %0 = tensor.empty() : tensor<32x128xf32>
+  %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
+  return %1 : tensor<32x128xf32>
+}
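This new test exercises the previously broken 1-D case end to end: the tensor<32xf32> operand now gets a layout grid matching the device grid's rank, and the module translates to a flatbuffer.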
