Minor fixes for stablehlo.gather op (#1337)
* Use updated operands during StableHLO to TTIR conversion
* Add check to ensure the input argument data type is bfloat16
* Add stablehlo runtime test for gather op
mmanzoorTT authored Nov 21, 2024
1 parent ff339a1 commit 548b063
Showing 6 changed files with 109 additions and 17 deletions.
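
The first two bullets follow a standard MLIR dialect-conversion rule: inside an OpConversionPattern, srcOp.getOperands() still returns the original values, whose types may not have been converted yet, while the adaptor carries the operands as already remapped by earlier rewrites and the type converter. A minimal sketch of the distinction, using hypothetical SourceOp/TargetOp names rather than this commit's actual ops:

    // Sketch only: SourceOp and TargetOp are placeholder op names.
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    struct SourceOpConversionPattern : OpConversionPattern<SourceOp> {
      using OpConversionPattern<SourceOp>::OpConversionPattern;

      LogicalResult
      matchAndRewrite(SourceOp srcOp, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        // srcOp.getOperands() would hand the new op stale, unconverted
        // values; adaptor.getOperands() yields the remapped operands.
        auto resultType =
            getTypeConverter()->convertType(srcOp.getResult().getType());
        rewriter.replaceOpWithNewOp<TargetOp>(srcOp, resultType,
                                              adaptor.getOperands());
        return success();
      }
    };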
10 changes: 5 additions & 5 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -850,7 +850,7 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern

     llvm::SmallVector<int64_t, 4> broadcastedShape;
     auto srcType =
-        getTypeConverter()->convertType(srcOp.getOperand().getType());
+        getTypeConverter()->convertType(adaptor.getOperand().getType());
     auto inputShape = mlir::cast<mlir::RankedTensorType>(srcType).getShape();
     auto outputShape = mlir::cast<mlir::RankedTensorType>(srcType).getShape();

@@ -996,8 +996,8 @@ class StableHLOToTTIRConcatOpConversionPattern
                                          "ConcatOp dimension is too large.");
     }

-    auto rankedTensorType =
-        mlir::dyn_cast<mlir::RankedTensorType>(srcOp.getOperand(0).getType());
+    auto rankedTensorType = mlir::dyn_cast<mlir::RankedTensorType>(
+        adaptor.getOperands()[0].getType());
     if (static_cast<int64_t>(adaptor.getDimension()) >=
         rankedTensorType.getRank()) {
       return rewriter.notifyMatchFailure(srcOp,
@@ -1185,8 +1185,8 @@ class StableHLOToTTIRGatherOpConversionPattern
     auto dimensionNumbers = srcOp.getDimensionNumbers();

     rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
-        srcOp, outputType, srcOp.getOperands()[0],
-        srcOp.getOperands()[1], // Start indices
+        srcOp, outputType, adaptor.getOperands()[0],
+        adaptor.getOperands()[1], // Start indices
         Value(outputTensor), dimensionNumbers.getOffsetDims(),
         dimensionNumbers.getCollapsedSliceDims(),
         dimensionNumbers.getOperandBatchingDims(),
8 changes: 8 additions & 0 deletions
@@ -16,6 +16,7 @@
 #include "mlir/Transforms/DialectConversion.h"

 #include <algorithm>
+#include <mlir/IR/BuiltinAttributes.h>

 using namespace mlir;
 using namespace mlir::tt;
@@ -407,6 +408,13 @@ struct GatherToEmbeddingConversionPattern
     // collapsed slice dims of the gather op
     auto collapsedSliceDims = op.getCollapsedSliceDims();

+    RankedTensorType operandType =
+        mlir::cast<RankedTensorType>(op->getOperand(0).getType());
+    if (!operandType.getElementType().isBF16()) {
+      return rewriter.notifyMatchFailure(
+          op, "only supports bfloat16 input tensor.");
+    }
+
     if (shape.size() > 1) {
       auto hiddenDim = shape[shape.size() - 1];
       // check if sliceSizes has more than one element
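
A note on how the new check surfaces: notifyMatchFailure only records why this pattern did not apply. If the pass also marks ttir.gather as illegal, an op that no pattern rewrites (for example, a gather with a non-bfloat16 operand) makes the whole conversion fail. A sketch of the usual MLIR driver wiring, which may differ from this pass's exact code:

    // Sketch, not the pass's verbatim code: an illegal op left
    // unrewritten makes applyFullConversion fail, producing the
    // "failed to legalize operation" error exercised in the
    // negative test below.
    ConversionTarget target(getContext());
    target.addIllegalOp<ttir::GatherOp>();
    if (failed(applyFullConversion(getOperation(), target,
                                   std::move(patterns))))
      signalPassFailure();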
17 changes: 17 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -8,6 +8,7 @@ module @jit_gather attributes {} {
     // CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
     return %0 : tensor<1x32x1024xf32>
   }
+
   func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> {
     %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32>
     // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
@@ -22,4 +23,20 @@ module @jit_gather attributes {} {
     return %0 : tensor<1x2x384xf32>
   }

+  func.func public @test_gather_3(%arg0: tensor<32128x512xbf16>, %arg1: tensor<1x15xi64>) -> tensor<1x15x512xbf16> {
+    // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : tensor<1x15x512xbf16>
+    // CHECK: %[[VAL:[0-9]+]] = "ttir.gather"(%arg0, %arg1, %[[EMPTY]])
+    // CHECK-SAME: collapsed_slice_dims = array<i64: 0>,
+    // CHECK-SAME: index_vector_dim = 2 : si64,
+    // CHECK-SAME: indices_are_sorted = false,
+    // CHECK-SAME: offset_dims = array<i64: 2>,
+    // CHECK-SAME: operand_batching_dims = array<i64>,
+    // CHECK-SAME: slice_sizes = array<i64: 1, 512>,
+    // CHECK-SAME: start_index_map = array<i64: 0>,
+    // CHECK-SAME: start_indices_batching_dims = array<i64>
+    // CHECK-SAME: (tensor<32128x512xbf16>, tensor<1x15xi32>, tensor<1x15x512xbf16>) -> tensor<1x15x512xbf16>
+    %0 = "stablehlo.gather"(%arg0, %arg1) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 512>}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16>
+    // CHECK: return %[[VAL]] : tensor<1x15x512xbf16>
+    return %0 : tensor<1x15x512xbf16>
+  }
 }
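
For reference, the gather configuration exercised in these tests (offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], slice_sizes = [1, hidden]) is exactly an embedding-table lookup, which is why the next file expects it to lower to ttnn.embedding. A scalar reference sketch of that semantics:

    #include <cstdint>
    #include <vector>

    // Reference semantics only: each start index selects one full row
    // of the [vocab, hidden] table, and the collapsed slice dim drops
    // the unit row dimension from the result.
    std::vector<float> embeddingGather(const std::vector<float> &table,
                                       const std::vector<int32_t> &indices,
                                       int64_t hidden) {
      std::vector<float> out(indices.size() * hidden);
      for (size_t i = 0; i < indices.size(); ++i)
        for (int64_t h = 0; h < hidden; ++h)
          out[i * hidden + h] = table[indices[i] * hidden + h];
      return out;
    }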
24 changes: 12 additions & 12 deletions test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
@@ -1,9 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
-  func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
+  func.func @gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<1x32x1024xf32>
+    %0 = tensor.empty() : tensor<1x32x1024xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
     %1 = "ttir.gather"(%operand, %start_indices, %0) {
       offset_dims = array<i64: 2>,
@@ -15,13 +15,13 @@ module attributes {} {
       slice_sizes = array<i64: 1, 1024>,
       indices_are_sorted = false,
       operand_constraints = [#any_device, #any_device, #any_device]
-    } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
-    return %1 : tensor<1x32x1024xf32>
+    } : (tensor<32000x1024xbf16>, tensor<1x32xi32>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
+    return %1 : tensor<1x32x1024xbf16>
   }

-  func.func @gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> {
+  func.func @gather_1(%operand: tensor<448x384xbf16>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<1x2x384xf32>
+    %0 = tensor.empty() : tensor<1x2x384xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
     %1 = "ttir.gather"(%operand, %start_indices, %0) <{
       offset_dims = array<i64: 2>,
@@ -33,13 +33,13 @@ module attributes {} {
       slice_sizes = array<i64: 1, 384>,
       indices_are_sorted = false,
       operand_constraints = [#any_device, #any_device, #any_device]
-    }> : (tensor<448x384xf32>, tensor<1x2x1xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32>
-    return %1 : tensor<1x2x384xf32>
+    }> : (tensor<448x384xbf16>, tensor<1x2x1xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16>
+    return %1 : tensor<1x2x384xbf16>
   }

-  func.func @gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> {
+  func.func @gather_2(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-    %0 = tensor.empty() : tensor<1x2x384xf32>
+    %0 = tensor.empty() : tensor<1x2x384xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
     %1 = "ttir.gather"(%operand, %start_indices, %0) <{
       offset_dims = array<i64: 2>,
@@ -51,7 +51,7 @@ module attributes {} {
       slice_sizes = array<i64: 1, 384>,
       indices_are_sorted = false,
       operand_constraints = [#any_device, #any_device, #any_device]
-    }> : (tensor<51864x384xf32>, tensor<1x2xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32>
-    return %1 : tensor<1x2x384xf32>
+    }> : (tensor<51864x384xbf16>, tensor<1x2xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16>
+    return %1 : tensor<1x2x384xbf16>
   }
 }
22 changes: 22 additions & 0 deletions
@@ -110,3 +110,25 @@ module attributes {} {
     return %1 : tensor<1x2x384xf32>
   }
 }

+// Verify that the conversion fails for data types other than bfloat16.
+// -----
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
+    %0 = tensor.empty() : tensor<1x32x1024xf32>
+    // CHECK: error: failed to legalize operation 'ttir.gather' that was explicitly marked illegal
+    %1 = "ttir.gather"(%operand, %start_indices, %0) {
+      offset_dims = array<i64: 2>,
+      collapsed_slice_dims = array<i64: 0>,
+      operand_batching_dims = array<i64: 0>,
+      start_indices_batching_dims = array<i64: 0>,
+      start_index_map = array<i64: 0>,
+      index_vector_dim = 1 : si64,
+      slice_sizes = array<i64: 1, 1024>,
+      indices_are_sorted = false,
+      operand_constraints = [#any_device, #any_device, #any_device]
+    } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
+    return %1 : tensor<1x32x1024xf32>
+  }
+}
45 changes: 45 additions & 0 deletions test/ttmlir/Silicon/StableHLO/gather_op.mlir
@@ -0,0 +1,45 @@
+// REQUIRES: stablehlo
+// RUN: rm -rf %t.ttnn
+// RUN: rm -rf %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// RUN: FileCheck --input-file=%t.mlir %s
+
+module @jit_gather attributes {} {
+  func.func public @test_gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> {
+    // CHECK-LABEL: func.func public @test_gather_0
+    // CHECK: ttnn.empty
+    // CHECK: ttnn.embedding
+    // CHECK-SAME: tensor<1x32xi32,
+    // CHECK-SAME: tensor<1x32x1024xbf16
+    // CHECK-SAME: tensor<32000x1024xbf16,
+    // CHECK-SAME: -> tensor<1x32x1024xbf16
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xbf16>, tensor<1x32xi32>) -> tensor<1x32x1024xbf16>
+    return %0 : tensor<1x32x1024xbf16>
+  }
+
+  func.func public @test_gather_1(%operand: tensor<51864x384xbf16>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xbf16> {
+    // CHECK-LABEL: func.func public @test_gather_1
+    // CHECK: ttnn.empty
+    // CHECK: ttnn.embedding
+    // CHECK-SAME: tensor<1x2xi32,
+    // CHECK-SAME: tensor<1x2x384xbf16
+    // CHECK-SAME: tensor<51864x384xbf16,
+    // CHECK-SAME: -> tensor<1x2x384xbf16
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xbf16>, tensor<1x2xi32>) -> tensor<1x2x384xbf16>
+    return %0 : tensor<1x2x384xbf16>
+  }
+
+  func.func public @test_gather_2(%operand: tensor<32128x512xbf16>, %start_indices: tensor<1x15xi64>) -> tensor<1x15x512xbf16> {
+    // CHECK-LABEL: func.func public @test_gather_2
+    // CHECK: ttnn.empty
+    // CHECK: ttnn.embedding
+    // CHECK-SAME: tensor<1x16xi32,
+    // CHECK-SAME: tensor<1x15x512xbf16
+    // CHECK-SAME: tensor<32128x512xbf16,
+    // CHECK-SAME: -> tensor<1x15x512xbf16
+    %0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 512>}> : (tensor<32128x512xbf16>, tensor<1x15xi64>) -> tensor<1x15x512xbf16>
+    return %0 : tensor<1x15x512xbf16>
+  }
+}
