Skip to content

Commit

Permalink
Address code review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Dec 24, 2024
1 parent d1a36ee commit 3b43c5a
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 66 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,8 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
Expand Down
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,8 @@ def TTNN_RepeatOp : TTNN_Op<"repeat"> {
I64ArrayAttr:$shape);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_SliceOp: TTNN_NamedDPSOp<"slice"> {
Expand Down
39 changes: 16 additions & 23 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,6 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern

if (inputType.getRank() == outputType.getRank()) {
// unsqueeze is not needed, proceed to converting to broadcast

tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

Expand All @@ -766,41 +765,35 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
Value(adaptor.getOperand()), Value(outputTensor), dimArg);
adaptor.getOperand(), outputTensor, dimArg);
} else {

std::vector<int64_t> UnsqueezeShape;
for (unsigned int i = 0; i < inputType.getRank(); i++) {
UnsqueezeShape.push_back(inputType.getDimSize(i));
}

for (unsigned int i = inputType.getRank(); i < outputType.getRank();
i++) {
UnsqueezeShape.insert(UnsqueezeShape.begin(), 1);
if ((unsigned)UnsqueezeShape.size() == (unsigned)outputType.getRank()) {
break;
}
// This stablehlo operation cannot be represented by a single TTIR
// operation. It has to be split into ttir.reshape followed by a
// ttir.broadcast op.
SmallVector<int64_t> unsqueezeShape(outputType.getRank(), 1);
::llvm::ArrayRef<int64_t> broadcast_in_dim =
adaptor.getBroadcastDimensions();

for (int64_t i = 0; i < inputType.getRank(); i++) {
unsqueezeShape[broadcast_in_dim[i]] = inputType.getDimSize(i);
}

RankedTensorType unsqueezeOutputType =
RankedTensorType::get(UnsqueezeShape, outputType.getElementType());
RankedTensorType::get(unsqueezeShape, outputType.getElementType());

tensor::EmptyOp reshapeOutputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), unsqueezeOutputType.getShape(),
unsqueezeOutputType.getElementType());

std::vector<int32_t> new_shape_i32;
for (int64_t dim : UnsqueezeShape) {
new_shape_i32.push_back(static_cast<int32_t>(dim));
}

auto reshapeDim = rewriter.getI32ArrayAttr(new_shape_i32);
SmallVector<int32_t> reshapeDim(unsqueezeShape.begin(),
unsqueezeShape.end());
auto reshapeDimAttr = rewriter.getI32ArrayAttr(reshapeDim);

mlir::tt::ttir::ReshapeOp reshape =
rewriter.create<mlir::tt::ttir::ReshapeOp>(
srcOp.getLoc(),
getTypeConverter()->convertType(unsqueezeOutputType),
adaptor.getOperand(), reshapeOutputTensor, reshapeDim);
adaptor.getOperand(), reshapeOutputTensor, reshapeDimAttr);

tensor::EmptyOp broadcastOutputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
Expand All @@ -811,7 +804,7 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
srcOp,
getTypeConverter()->convertType(broadcastOutputTensor.getType()),
Value(reshape.getResult()), Value(broadcastOutputTensor), dimArg);
reshape.getResult(), broadcastOutputTensor, dimArg);
}

return success();
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,23 @@ ::mlir::OpFoldResult mlir::tt::ttir::ReshapeOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

// Verifier for ttir.broadcast.
//
// The op is only well-formed when its input and its output (DPS init) tensor
// have the same rank; any rank-expanding reshape is expected to have been
// applied before the broadcast. Returns success() when the ranks agree and an
// op error otherwise.
::mlir::LogicalResult mlir::tt::ttir::BroadcastOp::verify() {
  ::mlir::RankedTensorType operandTy = getInput().getType();
  ::mlir::RankedTensorType initTy = getOutput().getType();

  // Equal ranks are the only requirement enforced here.
  const bool ranksAgree = operandTy.getRank() == initTy.getRank();
  if (ranksAgree) {
    return success();
  }

  return emitOpError("Input tensor rank should match output tensor rank.");
}

//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,41 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// RepeatOp
//===----------------------------------------------------------------------===//

// RepeatOp verification
//
// A ttnn.repeat op is accepted when all of the following hold:
//   1. the input and result tensors have the same rank,
//   2. the `shape` attribute has exactly one entry per result dimension, and
//   3. for every dimension i, input_dim[i] * shape[i] == result_dim[i].
// The first violated condition is reported through emitOpError.
::mlir::LogicalResult mlir::tt::ttnn::RepeatOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  ::mlir::RankedTensorType outputType = getResult().getType();

  // Check that the input rank matches the rank of the output tensor
  if (inputType.getRank() != outputType.getRank()) {
    return emitOpError("Input tensor rank should match output tensor rank.");
  }

  auto shape = getShape();

  // Check that the shape size matches the rank of the output tensor. This is
  // a different failure from the rank mismatch above, so it gets its own
  // diagnostic that names the `shape` attribute.
  if (static_cast<int64_t>(shape.size()) != outputType.getRank()) {
    return emitOpError(
        "Shape attribute size should match output tensor rank.");
  }

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  // Each entry of `shape` is the repeat multiplier for the matching
  // dimension; the product with the input extent must equal the result
  // extent.
  size_t shape_size = shape.size();
  for (size_t i = 0; i < shape_size; i++) {
    int64_t dim_value = mlir::cast<IntegerAttr>(shape[i]).getInt();
    if (inputShape[i] * dim_value != outputShape[i]) {
      return emitOpError("Input shape does not repeat to output shape.");
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
Expand All @@ -293,6 +328,7 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
::mlir::LogicalResult mlir::tt::ttnn::ReshapeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getResult().getType();

auto shape = getShape();
int64_t shape_size = static_cast<int64_t>(shape.size());

Expand Down
39 changes: 17 additions & 22 deletions test/ttmlir/Dialect/TTNN/simple_broadcast.mlir
Original file line number Diff line number Diff line change
@@ -1,30 +1,25 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
module {
func.func @main(%arg0: tensor<1x16x1xf32>, %arg1: tensor<1x1x32xi32>) -> tensor<1x16x32xf32> {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.to_device"(%{{[0-9]+}}, %{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.repeat"([[VAL0]])
%0 = tensor.empty() : tensor<1x1x32xf32>
%1 = "ttir.typecast"(%arg1, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1x1x32xi32>, tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"(%{{[0-9]+}})
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{dimension = [0, 1, 2]}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
%3 = "ttir.broadcast"(%arg0, %2) <{dimension = [1, 2]}> : (tensor<1x16x1xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%4 = tensor.empty() : tensor<1x16x32xf32>
%5 = "ttir.broadcast"(%1, %4) <{dimension = [0, 1, 2]}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%6 = tensor.empty() : tensor<1x16x32xf32>
%7 = "ttir.multiply"(%3, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %7 : tensor<1x16x32xf32>
%3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %3 : tensor<1x16x32xf32>
}
}

module {
func.func @main(%arg0: tensor<1x10xf32>, %arg1: tensor<10x1xf32>) -> tensor<10x10xf32> {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.to_device"(%{{[0-9]+}}, %{{[0-9]+}})
module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.reshape"(%{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.repeat"([[VAL0]])
%0 = tensor.empty() : tensor<10x10xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{dimension = [0, 1]}> : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = tensor.empty() : tensor<10x10xf32>
%3 = "ttir.broadcast"(%arg1, %2) <{dimension = [0, 1]}> : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = tensor.empty() : tensor<10x10xf32>
%5 = "ttir.subtract"(%1, %3, %4) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
return %5 : tensor<10x10xf32>
%0 = tensor.empty() : tensor<1x1xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
%2 = tensor.empty() : tensor<512x512xf32>
%3 = "ttir.broadcast"(%1, %2) <{dimension = [1]}> : (tensor<1x1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%4 = tensor.empty() : tensor<512x512xf32>
%5 = "ttir.maximum"(%3, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %5 : tensor<512x512xf32>
}
}
37 changes: 16 additions & 21 deletions test/ttmlir/Silicon/TTNN/simple_broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,26 @@
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module {
func.func @main(%arg0: tensor<1x16x1xf32>, %arg1: tensor<1x1x32xi32>) -> tensor<1x16x32xf32> {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.to_device"(%{{[0-9]+}}, %{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.repeat"([[VAL0]])
%0 = tensor.empty() : tensor<1x1x32xf32>
%1 = "ttir.typecast"(%arg1, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<1x1x32xi32>, tensor<1x1x32xf32>) -> tensor<1x1x32xf32>
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"(%{{[0-9]+}})
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{dimension = [0, 1, 2]}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
%3 = "ttir.broadcast"(%arg0, %2) <{dimension = [1, 2]}> : (tensor<1x16x1xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%4 = tensor.empty() : tensor<1x16x32xf32>
%5 = "ttir.broadcast"(%1, %4) <{dimension = [0, 1, 2]}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%6 = tensor.empty() : tensor<1x16x32xf32>
%7 = "ttir.multiply"(%3, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %7 : tensor<1x16x32xf32>
%3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %3 : tensor<1x16x32xf32>
}
}

module {
func.func @main(%arg0: tensor<1x10xf32>, %arg1: tensor<10x1xf32>) -> tensor<10x10xf32> {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.to_device"(%{{[0-9]+}}, %{{[0-9]+}})
module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.reshape"(%{{[0-9]+}})
// CHECK: %{{[0-9]+}} = "ttnn.repeat"([[VAL0]])
%0 = tensor.empty() : tensor<10x10xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{dimension = [0, 1]}> : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%2 = tensor.empty() : tensor<10x10xf32>
%3 = "ttir.broadcast"(%arg1, %2) <{dimension = [0, 1]}> : (tensor<10x1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%4 = tensor.empty() : tensor<10x10xf32>
%5 = "ttir.subtract"(%1, %3, %4) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
return %5 : tensor<10x10xf32>
%0 = tensor.empty() : tensor<1x1xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xf32>, tensor<1x1xf32>) -> tensor<1x1xf32>
%2 = tensor.empty() : tensor<512x512xf32>
%3 = "ttir.broadcast"(%1, %2) <{dimension = [1]}> : (tensor<1x1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
%4 = tensor.empty() : tensor<512x512xf32>
%5 = "ttir.maximum"(%3, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32>
return %5 : tensor<512x512xf32>
}
}

0 comments on commit 3b43c5a

Please sign in to comment.