Gather op implementation [#1015]
Gather op is lowered into embedding. Used TTIR pass from 38a4a46. Used
embedding fixes from e798a17. Blocked by tt-metal issue 14584.
ddilbazTT committed Nov 4, 2024
1 parent e6c60fd commit 0bc82d9
Showing 20 changed files with 333 additions and 45 deletions.
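
At a glance, the lowering path is stablehlo.gather → ttir.gather → ttir.embedding (TTIR decomposition) → ttnn.embedding (runtime). As a reference for why the restricted gather handled here is exactly an embedding row lookup, here is a standalone sketch; shapes are borrowed from test_gather_0 added in this commit, and the function is illustrative only, not part of the change.

// Reference semantics (not compiler code): a gather with
// offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0],
// slice_sizes = [1, hiddenDim] selects whole rows of the operand,
// i.e. an embedding lookup. Shapes follow test_gather_0:
// weight 32000x1024, indices 1x32 -> output 1x32x1024.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<float> embeddingLookup(const std::vector<float> &weight,
                                   int64_t hiddenDim,
                                   const std::vector<int32_t> &indices) {
  std::vector<float> out(indices.size() * hiddenDim);
  for (size_t i = 0; i < indices.size(); ++i) {
    // Each start index picks one 1 x hiddenDim slice; collapsing dim 0
    // leaves a hiddenDim-long vector per index.
    const float *row =
        weight.data() + static_cast<int64_t>(indices[i]) * hiddenDim;
    std::copy(row, row + hiddenDim, out.begin() + i * hiddenDim);
  }
  return out;
}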
24 changes: 23 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -674,11 +674,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
let results = (outs AnyRankedTensor);
let hasVerifier = 1;

-let extraClassDeclaration = [{
+let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}

def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
let summary = "Gather operation.";
let description = [{
Gather operation.
}];
let arguments = (ins AnyRankedTensor:$input, // operand
AnyRankedTensor:$start_indices, // start_indices
AnyRankedTensor:$output, // result
DenseI64ArrayAttr:$offset_dims, // offset_dims
DenseI64ArrayAttr:$collapsed_slice_dims, // collapsed_slice_dims
DenseI64ArrayAttr:$operand_batching_dims, // operand_batching_dims
DenseI64ArrayAttr:$start_indices_batching_dims, // start_indices_batching_dims
DenseI64ArrayAttr:$start_index_map, // start_index_map
SI64Attr:$index_vector_dim, // index_vector_dim
DenseI64ArrayAttr:$slice_sizes, // slice_sizes
BoolAttr:$indices_are_sorted, // indices_are_sorted (bool)
TT_OperandConstraintArrayAttr:$operand_constraints);
let results = (outs AnyRankedTensor:$result);
let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}
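
// For the embedding-like gathers handled by the TTIR decomposition below,
// these attributes take values as in the gather_op.mlir test added in this
// commit, e.g.:
//   offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0],
//   index_vector_dim = 2, slice_sizes = [1, 1024], indices_are_sorted = false.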

def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
7 changes: 6 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -424,17 +424,22 @@ def TTNN_MaxOp : TTNN_ReductionOp<"max"> {
}];
}

-def TTNN_EmbeddingOp : TTNN_Op<"embedding"> {
+def TTNN_EmbeddingOp : TTNN_NamedDPSOp<"embedding"> {
let summary = "Embedding op.";
let description = [{
Embedding operation.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
AnyRankedTensor:$weight);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -120,7 +120,7 @@ table ReductionOp {
table EmbeddingOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
-output: tt.target.TensorRef;
+out: tt.target.TensorRef;
}

table SoftmaxOp {
48 changes: 48 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -894,6 +894,48 @@ class StableHLOToTTIRSliceOpConversionPattern
}
};

class StableHLOToTTIRGatherOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::GatherOp> {
using OpConversionPattern<mlir::stablehlo::GatherOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::GatherOp srcOp,
mlir::stablehlo::GatherOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Create the output tensor type based on inputs
auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
// Create an empty output tensor with the computed shape
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
auto dimensionNumbers = srcOp.getDimensionNumbers();
rewriter.replaceOpWithNewOp<mlir::tt::ttir::GatherOp>(
srcOp, // The original operation to replace
outputType, // Result type
srcOp.getOperands()[0], // Input tensor
srcOp.getOperands()[1], // Start indices
Value(outputTensor), // Output tensor
dimensionNumbers.getOffsetDims(), // offset_dims attribute
dimensionNumbers
.getCollapsedSliceDims(), // collapsed_slice_dims attribute
dimensionNumbers
.getOperandBatchingDims(), // operand_batching_dims attribute
dimensionNumbers
.getStartIndicesBatchingDims(), // start_indices_batching_dims
// attribute
dimensionNumbers.getStartIndexMap(), // start_index_map attribute
dimensionNumbers.getIndexVectorDim(), // index_vector_dim attribute
srcOp.getSliceSizesAttr(), // slice_sizes attribute
false, // indices_are_sorted attribute
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
return success();
}
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
@@ -1029,6 +1071,11 @@ void addSliceOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
patterns.add<StableHLOToTTIRSliceOpConversionPattern>(typeConverter, ctx);
}

void addGatherOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<StableHLOToTTIRGatherOpConversionPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
@@ -1050,6 +1097,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addGatherOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
101 changes: 101 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
@@ -397,7 +397,107 @@ struct ConvolutionToConv2dPattern
adaptor.getOperandConstraints());

rewriter.replaceOp(op, output);
return success();
}
};
//===----------------------------------------------------------------------===//
// Gather passes
//===----------------------------------------------------------------------===//


struct GatherToEmbeddingConversionPattern
: public OpConversionPattern<ttir::GatherOp> {
using OpConversionPattern<ttir::GatherOp>::OpConversionPattern;

LogicalResult checkBasicLegality(ttir::GatherOp op,
PatternRewriter &rewriter) const {
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
auto startIndicesType =
mlir::cast<RankedTensorType>(startIndices.getType());
auto sliceSizes = op.getSliceSizes(); // slice sizes of the gather op
auto offsetDims = op.getOffsetDims();
auto collapsedSliceDims =
op.getCollapsedSliceDims(); // collapsed slice dims of the gather op
if (shape.size() > 1) {
auto hiddenDim = shape[shape.size() - 1];
assert(sliceSizes.size() > 1 &&
"sliceSizes should have at least 2 elements");
if (sliceSizes[0] != 1 || sliceSizes[1] != hiddenDim) {
return rewriter.notifyMatchFailure(op, "Did not satisfy sliceSizes");
}
}
if (offsetDims.size() != 1 ||
std::vector<int64_t>(offsetDims.begin(), offsetDims.end()) !=
std::vector<int64_t>{2}) {
return rewriter.notifyMatchFailure(op, "Did not satisfy offsetDims");
}
if (collapsedSliceDims.size() != 1 ||
std::vector<int64_t>(collapsedSliceDims.begin(),
collapsedSliceDims.end()) !=
std::vector<int64_t>{0}) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy collapsedSliceDims");
}
if (shape.size() == startIndicesType.getShape().size() &&
startIndicesType.getShape()[shape.size() - 1] != 1) {
return rewriter.notifyMatchFailure(op,
"Did not satisfy startIndicesType");
}
return success();
}
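// Worked example (shapes from the gather_op.mlir test added in this commit):
//   test_gather_0: output 1x32x1024 -> hiddenDim = 1024; slice_sizes [1, 1024]
//     satisfy sliceSizes[0] == 1 and sliceSizes[1] == hiddenDim; offset_dims
//     [2] and collapsed_slice_dims [0] match; start_indices 1x32 has rank 2 !=
//     output rank 3, so the trailing-dimension check does not apply.
//   test_gather_1: start_indices 1x2x1 has the same rank as output 1x2x384,
//     but its trailing dimension is 1, so it is accepted and reshaped away in
//     matchAndRewrite below.
//   A gather with slice_sizes = [1, 512] against a 1x32x1024 output would be
//   rejected with "Did not satisfy sliceSizes".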
ttir::ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc,
Value input,
::llvm::ArrayRef<int64_t> shapei64,
::mlir::ArrayAttr operandConstraints) const {
auto ty = mlir::cast<RankedTensorType>(input.getType());
auto output = rewriter.create<tensor::EmptyOp>(
loc, llvm::ArrayRef<int64_t>(shapei64), ty.getElementType());
std::vector<int32_t> shapei32(shapei64.begin(), shapei64.end());
auto shape_attr = rewriter.getI32ArrayAttr(shapei32);
return rewriter.create<ttir::ReshapeOp>(
loc, mlir::RankedTensorType::get(shapei64, ty.getElementType()), input,
output, shape_attr, operandConstraints);
}
LogicalResult
matchAndRewrite(ttir::GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult err = checkBasicLegality(op, rewriter);
if (not err.succeeded()) {
return err;
}
auto outputType = mlir::cast<RankedTensorType>(op.getResult().getType());
auto shape = outputType.getShape();
auto startIndices = op.getStartIndices(); // start indices of the gather op
auto startIndicesType =
mlir::cast<RankedTensorType>(startIndices.getType());
::mlir::Value input = op.getStartIndices();
if (shape.size() == startIndicesType.getShape().size() &&
startIndicesType.getShape()[shape.size() - 1] == 1) {
// reduce weight tensor dimension
// insert reshape op to remove the last dimension of start indices
// before gather/ embedding op
std::vector<int64_t> newShapeI64(startIndicesType.getShape().begin(),
startIndicesType.getShape().end() - 1);
ttir::ReshapeOp reshapeOp =
createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64,
op.getOperandConstraints());
assert(reshapeOp && "Failed to create reshape op");
reshapeOp->moveBefore(op);
input = reshapeOp.getResult();
}
ttir::EmbeddingOp embeddingOp = rewriter.create<ttir::EmbeddingOp>(
op.getLoc(), op.getResult().getType(),
input, // input - start indices
op.getOperands()[0], // weight - input tensor
op.getOutput(),
rewriter.getArrayAttr( // operand constraints
SmallVector<Attribute>(op.getNumOperands() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
assert(embeddingOp != nullptr && "Failed to create embedding op");
rewriter.replaceOp(op, embeddingOp);
return success();
}
};
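
Why the trailing-unit-dim reshape above is safe: dropping the last dimension of a 1x2x1 index tensor changes only shape metadata, not storage, so the embedding op can consume the reshaped indices directly. A minimal standalone sketch (illustrative values only, not compiler code):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Start indices as in test_gather_1, logical shape 1x2x1.
  std::vector<int32_t> indices3d = {7, 42};
  // After ttir.reshape to logical shape 1x2: same flat buffer, new metadata.
  std::vector<int32_t> indices2d = {7, 42};
  assert(indices3d == indices2d); // identical storage, different logical shape
  return 0;
}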
@@ -407,6 +507,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
TypeConverter &typeConverter) {
patterns.add<IndexToSliceConversionPattern>(typeConverter, ctx);
patterns.add<ConvolutionToConv2dPattern>(typeConverter, ctx);
patterns.add<GatherToEmbeddingConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
@@ -48,6 +48,7 @@ struct TTIRToTTIRDecompositionPass
// These are the ops we intend to remove entirely with this pass
target.addIllegalOp<ttir::IndexOp>();
target.addIllegalOp<ttir::ConvolutionOp>();
target.addIllegalOp<ttir::GatherOp>();

TypeConverter typeConverter;
// All types map 1:1.
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -212,7 +212,7 @@ class ToLayoutOpConversionPattern
bool shouldForceRowMajor(ttir::ToLayoutOp op) const {
for (mlir::Operation *user : op.getResult().getUsers()) {
if (isa<ttir::Conv2dOp>(user) || isa<ttir::MaxPool2dOp>(user) ||
-isa<ttir::SliceOp>(user)) {
+isa<ttir::SliceOp>(user) || isa<ttir::EmbeddingOp>(user)) {
return true;
}
}
@@ -317,7 +317,7 @@ class EmbeddingOpConversionPattern
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ttnn::EmbeddingOp>(
op, this->getTypeConverter()->convertType(op.getType()),
-adaptor.getInput(), adaptor.getWeight());
+adaptor.getInput(), adaptor.getOutput(), adaptor.getWeight());

return success();
}
10 changes: 6 additions & 4 deletions runtime/lib/ttnn/operations/embedding/embedding.cpp
@@ -18,14 +18,16 @@ void run(const ::tt::target::ttnn::EmbeddingOp *op, ProgramContext &context) {

// default params for embedding op
std::optional<int> padToken = std::nullopt;
-::tt::tt_metal::Layout layout = ::ttnn::ROW_MAJOR_LAYOUT;
+::tt::tt_metal::Layout layout = utils::isTilized(op->out())
+                                    ? ::ttnn::TILE_LAYOUT
+                                    : ::ttnn::ROW_MAJOR_LAYOUT;
auto embeddingsType = ::ttnn::operations::embedding::EmbeddingsType::GENERIC;
-::ttnn::DataType outputDataType = utils::getDataType(op->output());
+::ttnn::DataType outputDataType = utils::getDataType(op->out());
::ttnn::MemoryConfig outputMemoryConfig =
-utils::createMemoryConfig(op->output());
+utils::createMemoryConfig(op->out());
::ttnn::Tensor out =
::ttnn::embedding(input, weight, padToken, layout, embeddingsType,
outputDataType, outputMemoryConfig);
-tensorPool.insert_or_assign(op->output()->global_id(), out);
+tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::embedding
@@ -23,6 +23,12 @@ bool isOnDevice(const ::ttnn::Tensor &tensor) {
return tensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE;
}

bool isTilized(const ::tt::target::TensorRef *tensorRef) {
const ::tt::target::Dim2d *tileShape =
tensorRef->desc()->layout()->memory_desc()->tile_shape();
return tileShape->x() == 32 && tileShape->y() == 32;
}
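
// Note: TT-Metal's canonical tile is 32x32 elements, so a 32x32 tile_shape in
// the flatbuffer memory descriptor marks the tensor as tilized; the embedding
// runtime above uses this to pick TILE_LAYOUT over ROW_MAJOR_LAYOUT.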

::tt::target::MemorySpace
getMemorySpace(const ::tt::target::TensorRef *tensorRef) {
return tensorRef->desc()->layout()->memory_desc()->memory_space();
@@ -17,6 +17,8 @@ bool isOnHost(const ::ttnn::Tensor &tensor);

bool isOnDevice(const ::ttnn::Tensor &tensor);

bool isTilized(const ::tt::target::TensorRef *tensorRef);

bool inSystemMemory(const ::tt::target::TensorRef *tensorRef);

::tt::target::MemorySpace
25 changes: 25 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/gather_op.mlir
@@ -0,0 +1,25 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_gather attributes {} {
func.func public @test_gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1024>}> : (tensor<32000x1024xf32>, tensor<1x32xi32>) -> tensor<1x32x1024xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x32x1024xf32>
}
func.func public @test_gather_1(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x1xi32>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<448x384xf32>, tensor<1x2x1xi32>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

func.func public @test_gather_2(%operand: tensor<51864x384xf32>, %start_indices: tensor<1x2xi32>) -> tensor<1x2x384xf32> {
%0 = "stablehlo.gather"(%operand, %start_indices) <{dimension_numbers = #stablehlo.gather<offset_dims = [2], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = array<i64: 1, 384>}> : (tensor<51864x384xf32>, tensor<1x2xi32>) -> tensor<1x2x384xf32>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.gather"[[C:.*]]
return %0 : tensor<1x2x384xf32>
}

}
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
@@ -1,10 +1,10 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
-func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
-%0 = tensor.empty() : tensor<32x128xf32>
+func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
+%0 = tensor.empty() : tensor<32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
return %1 : tensor<32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
return %1 : tensor<32x128xbf16>
}
}
11 changes: 5 additions & 6 deletions test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
@@ -1,12 +1,11 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s | FileCheck %s
-// UNSUPPORTED: true
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
-func.func @forward(%arg0: tensor<1x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<1x32x128xf32> {
+func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
-%0 = tensor.empty() : tensor<1x32x128xf32>
+%0 = tensor.empty() : tensor<1x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xf32>, tensor<512x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
return %1 : tensor<1x32x128xf32>
%1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
return %1 : tensor<1x32x128xbf16>
}
}
