Fixing reduction ops
mtopalovicTT committed Dec 27, 2024
1 parent 6d04d25 commit 435c787
Showing 20 changed files with 119 additions and 77 deletions.
27 changes: 23 additions & 4 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -615,7 +615,7 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
BoolAttr:$keep_dim,
OptionalAttr<I32ArrayAttr>:$dim_arg);
OptionalAttr<AnyAttrOf<[SI32Attr, I32ArrayAttr]>>:$dim);

let results = (outs AnyRankedTensor:$result);

@@ -624,6 +624,26 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :

void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block);

SmallVector<int64_t> getReduceDims() {
mlir::Attribute reduceDimsAttr = getDim().value_or(mlir::Attribute{});
SmallVector<int64_t> reduceDimsVec;
if (!reduceDimsAttr) {
return reduceDimsVec;
}

if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(reduceDimsAttr)) {
reduceDimsVec.push_back(intAttr.getSInt());
} else {
auto arrayAttr = mlir::cast<mlir::ArrayAttr>(reduceDimsAttr);
for (auto dimAttr : arrayAttr) {
int64_t dim = mlir::cast<mlir::IntegerAttr>(dimAttr).getInt();
reduceDimsVec.push_back(dim);
}
}

return reduceDimsVec;
}

// Returns the indexing maps and iterator types for the reduction op.
// Indexing maps are identity maps with dropped dimensions corresponding to the
// reduction dimensions. Iterator types are parallel for non-reduction dimensions
@@ -635,10 +655,9 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
SmallVector<Attribute> iteratorTypes(
rank, builder.getAttr<IteratorTypeAttr>(IteratorType::Parallel));

auto reduceDims = getDimArgAttr();
SmallVector<int64_t> reduceDims = getReduceDims();
auto resultIndexingMap = indexingMaps.back();
for (auto reduceDim : reduceDims) {
int64_t reduceDimInt = mlir::cast<IntegerAttr>(reduceDim).getInt();
for (auto reduceDimInt : reduceDims) {
if (reduceDimInt < 0) {
reduceDimInt += rank;
}
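The user-visible change in the TTIROps.td hunks above (mirrored for TTNN below) is that the reduction ops' dim attribute (formerly dim_arg) now accepts either a single signed 32-bit integer or an array of 32-bit integers, and the new getReduceDims() helper normalizes both spellings into one dimension vector. A minimal TTIR sketch of the two forms; the array form matches the updated tests later in this diff, while the scalar si32 form is inferred from the SI32Attr alternative and is illustrative only:

func.func @reduce_scalar_dim(%arg0: tensor<128x10xf32>) -> tensor<128xf32> {
  %0 = tensor.empty() : tensor<128xf32>
  // Single reduce dimension as a signed scalar attribute (assumed printed form).
  %1 = "ttir.sum"(%arg0, %0) <{dim = 1 : si32, keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
  return %1 : tensor<128xf32>
}

func.func @reduce_array_dim(%arg0: tensor<128x10xf32>) -> tensor<128xf32> {
  %0 = tensor.empty() : tensor<128xf32>
  // The same reduction expressed with the array form of dim.
  %1 = "ttir.sum"(%arg0, %0) <{dim = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
  return %1 : tensor<128xf32>
}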
24 changes: 23 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -578,10 +578,32 @@ class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_Op<mnemo

let arguments = (ins AnyRankedTensor:$input,
BoolAttr:$keep_dim,
OptionalAttr<I32ArrayAttr>:$dim_arg);
OptionalAttr<AnyAttrOf<[SI32Attr, I32ArrayAttr]>>:$dim);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
SmallVector<int64_t> getReduceDims() {
mlir::Attribute reduceDimsAttr = getDim().value_or(mlir::Attribute{});
SmallVector<int64_t> reduceDimsVec;
if (!reduceDimsAttr) {
return reduceDimsVec;
}

if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(reduceDimsAttr)) {
reduceDimsVec.push_back(intAttr.getSInt());
} else {
auto arrayAttr = mlir::cast<mlir::ArrayAttr>(reduceDimsAttr);
for (auto dimAttr : arrayAttr) {
int64_t dim = mlir::cast<mlir::IntegerAttr>(dimAttr).getInt();
reduceDimsVec.push_back(dim);
}
}

return reduceDimsVec;
}
}];

let hasVerifier = 1;
}

@@ -16,16 +16,11 @@

namespace mlir::tt::ttnn::workarounds::decomposition {

// Extracts reduce dimensions' values from the dimArg attribute. In case when
// dimArg is not specified, returns empty vector.
llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg);

// Calculates the shape of the new Reduce op created in the workaround, based
// on the input shape and reducing dimensions.
llvm::SmallVector<int64_t>
calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg);
const llvm::SmallVector<int64_t> &reduceDims);

// This workaround addresses the following Metal issue:
// https://github.com/tenstorrent/tt-metal/issues/13361
@@ -70,7 +65,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {
RankedTensorType inputType,
RankedTensorType outputType) const {
llvm::SmallVector<int64_t> outputShapeVec =
calculateNewReduceShape(inputType, srcOp.getDimArg());
calculateNewReduceShape(inputType, srcOp.getReduceDims());

TTNNLayoutAttr newOutputLayoutAttr =
mlir::cast<TTNNLayoutAttr>(outputType.getEncoding())
@@ -81,7 +76,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern<ReduceOp> {

return rewriter.create<ReduceOp>(srcOp.getLoc(), newOutputType,
srcOp.getInput(), true /*keep_dim*/,
srcOp.getDimArg().value_or(nullptr));
srcOp.getDim().value_or(nullptr));
}

void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp,
@@ -108,11 +103,11 @@ class ReduceOpsAllDimsRewritePattern : public OpRewritePattern<ReduceOp> {

LogicalResult matchAndRewrite(ReduceOp srcOp,
PatternRewriter &rewriter) const override {
if (!srcOp.getDimArg() || srcOp.getDimArg()->empty()) {
llvm::SmallVector<int64_t> reduceDims = srcOp.getReduceDims();
if (reduceDims.empty()) {
return failure();
}

llvm::SmallVector<int64_t> reduceDims = getReduceDims(srcOp.getDimArg());
llvm::SmallSet<int64_t, 4> uniqueReduceDims(reduceDims.begin(),
reduceDims.end());

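As the helper names above suggest (createReduceOpWithKeepDim, replaceOpWithReshapeOp), ReduceOpsKeepDimRewritePattern recreates a reduction as one with keep_dim = true, whose output retains the reduced dimensions as size 1, and then reshapes that result back to the op's original output shape; calculateNewReduceShape now takes the already-normalized reduceDims vector instead of the raw dimArg attribute. A TTIR-style sketch of the shape effect only (illustrative shapes; the actual pattern operates on TTNN ops and the reshape step is omitted here):

func.func @keep_dim_shapes(%in: tensor<128x10xf32>) -> (tensor<128xf32>, tensor<128x1xf32>) {
  %dps0 = tensor.empty() : tensor<128xf32>
  %dps1 = tensor.empty() : tensor<128x1xf32>
  // keep_dim = false drops the reduced dimension from the result type.
  %r0 = "ttir.sum"(%in, %dps0) <{dim = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
  // keep_dim = true keeps it as size 1; the workaround emits this form and then
  // reshapes tensor<128x1xf32> back to tensor<128xf32>.
  %r1 = "ttir.sum"(%in, %dps1) <{dim = [1 : i32], keep_dim = true}> : (tensor<128x10xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
  return %r0, %r1 : tensor<128xf32>, tensor<128x1xf32>
}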
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -174,7 +174,7 @@ table ReductionOp {
type: ReductionOpType;
in: tt.target.TensorRef;
out: tt.target.TensorRef;
dim_arg: [int32];
dim: [int32];
keep_dim: bool;
}

2 changes: 1 addition & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -332,7 +332,7 @@ class ReductionOpConversionPattern : public OpConversionPattern<TTIROpTy> {
rewriter.replaceOpWithNewOp<TTNNOpTy>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getKeepDim(),
adaptor.getDimArg().value_or(nullptr));
adaptor.getDim().value_or(nullptr));
return success();
}
};
17 changes: 7 additions & 10 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1729,23 +1729,20 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block,
// Common verifier for all Reduce ops.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
if (!reduceDims) {
const llvm::SmallVector<int64_t> &reduceDims) {
if (reduceDims.empty()) {
return mlir::success();
}

int64_t inputTensorRank = inputType.getRank();

llvm::SmallSet<int64_t, 4> uniqueReduceDims;
for (mlir::Attribute reduceDim : *reduceDims) {
int64_t reduceDimInt = mlir::cast<mlir::IntegerAttr>(reduceDim).getInt();
for (int64_t reduceDimInt : reduceDims) {
if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) {
return reduceOp->emitOpError("Reduce dimensions are out of range");
}
uniqueReduceDims.insert(reduceDimInt);
}

if (uniqueReduceDims.size() != reduceDims->size()) {
if (uniqueReduceDims.size() != reduceDims.size()) {
return reduceOp->emitOpError("Reduce dimensions are not unique");
}

@@ -1770,7 +1767,7 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// MaxOp verification.
::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1786,7 +1783,7 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// MeanOp verification.
::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1802,5 +1799,5 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder,

// SumOp verification.
::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}
18 changes: 9 additions & 9 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -152,8 +152,8 @@ ::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() {
<< getStart() << ", end=" << getEnd() << ", step=" << getStep();
}

std::vector<int64_t> expectedShape = {1, 1, 1, numValues};
if (getType().getShape().vec() != expectedShape) {
llvm::SmallVector<int64_t> expectedShape = {1, 1, 1, numValues};
if (getType().getShape() != ArrayRef<int64_t>(expectedShape)) {
return emitOpError() << "Output tensor shape must be " << expectedShape
<< ", but got " << getType().getShape();
}
@@ -1274,13 +1274,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::PermuteOp::verify() {
// Common verifier for all Reduction ops.
static mlir::LogicalResult
verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &reduceDims) {
int64_t inputTensorRank = inputType.getRank();
const llvm::SmallVector<int64_t> &reduceDims) {
size_t inputTensorRank = inputType.getRank();

// TODO(mrakita): Only last two dimensions can be reduced, check for that
// too.
if (reduceDims && reduceDims->size() > 2 &&
static_cast<int64_t>(reduceDims->size()) != inputTensorRank) {
if (!reduceDims.empty() && reduceDims.size() > 2 &&
reduceDims.size() != inputTensorRank) {
return reduceOp->emitOpError("Reduce on more than two dimensions is not "
"currently supported by TTNN");
}
@@ -1294,7 +1294,7 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType,

// MaxOp verification.
::mlir::LogicalResult MaxOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1303,7 +1303,7 @@ ::mlir::LogicalResult MaxOp::verify() {

// MeanOp verification.
::mlir::LogicalResult MeanOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}

//===----------------------------------------------------------------------===//
Expand All @@ -1312,7 +1312,7 @@ ::mlir::LogicalResult MeanOp::verify() {

// SumOp verification.
::mlir::LogicalResult SumOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims());
}

} // namespace mlir::tt::ttnn
@@ -8,25 +8,10 @@

namespace mlir::tt::ttnn::workarounds::decomposition {

llvm::SmallVector<int64_t>
getReduceDims(const std::optional<mlir::ArrayAttr> &dimArg) {
llvm::SmallVector<int64_t, 4> reduceDims;
if (!dimArg) {
return reduceDims;
}

for (const mlir::Attribute &reduceDim : *dimArg) {
reduceDims.push_back(mlir::cast<mlir::IntegerAttr>(reduceDim).getInt());
}

return reduceDims;
}

llvm::SmallVector<int64_t>
calculateNewReduceShape(RankedTensorType inputType,
const std::optional<mlir::ArrayAttr> &dimArg) {
const llvm::SmallVector<int64_t> &reduceDims) {
llvm::SmallVector<int64_t> outputShapeVec(inputType.getShape());
llvm::SmallVector<int64_t> reduceDims = getReduceDims(dimArg);

if (reduceDims.empty()) {
// When reduce dimensions are not specified that means we are reducing over
16 changes: 9 additions & 7 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -307,8 +307,8 @@ createDistributionStrategy(FlatbufferObjectCache &cache,
// tensor is sliced at the fastest dimension.
if (meshShape[0] == 1 || meshShape[1] == 1) {
assert(type.getShape().size() > 0 && "expected non-zero tensor shape");
uint32_t target_dim = type.getShape().size() - 1;
auto strategy = ::tt::target::CreateShardTensor(*cache.fbb, target_dim);
uint32_t targetDim = type.getShape().size() - 1;
auto strategy = ::tt::target::CreateShardTensor(*cache.fbb, targetDim);
return ::tt::target::CreateDistributionStrategy(
*cache.fbb, ::tt::target::DistributedTensorConfig::ShardTensor,
strategy.Union());
@@ -730,11 +730,13 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
auto dim_arg =
arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());
SmallVector<int64_t> dims = op.getReduceDims();
SmallVector<int32_t> dims32(dims.begin(), dims.end());
auto dimArg =
op.getReduceDims().empty() ? 0 : toFlatbuffer<int32_t>(cache, dims32);

return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
dim_arg, op.getKeepDim());
dimArg, op.getKeepDim());
}

::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
@@ -1134,8 +1136,8 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createSliceOp(cache, sliceOp), debugString,
locInfo);
}
if (auto max_pool2dOp = dyn_cast<MaxPool2dOp>(op); max_pool2dOp) {
return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp),
if (auto maxPool2dOp = dyn_cast<MaxPool2dOp>(op); maxPool2dOp) {
return createOperation(cache, createMaxPool2dOp(cache, maxPool2dOp),
debugString, locInfo);
}
if (auto deallocateOp = dyn_cast<DeallocateOp>(op); deallocateOp) {
8 changes: 4 additions & 4 deletions runtime/lib/ttnn/operations/reduction/reduction.cpp
@@ -22,11 +22,11 @@ static void runReductionOp(
const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id());
DEBUG_ASSERT(in.is_allocated());

const auto *fbDimArg = op->dim_arg();
const auto *fbDim = op->dim();
std::optional<::ttnn::SmallVector<int>> dimArg =
fbDimArg ? std::make_optional(::ttnn::SmallVector<int>(fbDimArg->begin(),
fbDimArg->end()))
: std::nullopt;
fbDim ? std::make_optional(
::ttnn::SmallVector<int>(fbDim->begin(), fbDim->end()))
: std::nullopt;

::ttnn::Tensor out = ttnnOp(
in, dimArg, op->keep_dim(), outputMemoryConfig /* memory_config_arg */,
@@ -4,6 +4,6 @@
// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range
func.func public @test_reduce_add_invalid_dim_high(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> {
%0 = tensor.empty() : tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
}
@@ -4,6 +4,6 @@
// CHECK: error: 'ttir.sum' op Reduce dimensions are out of range
func.func public @test_reduce_add_invalid_dim_low(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> {
%0 = tensor.empty() : tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
}
@@ -4,6 +4,6 @@
// CHECK: error: 'ttir.sum' op Reduce dimensions are not unique
func.func public @test_reduce_add_repeating_dims(%arg0: tensor<128x10x32x4xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> {
%0 = tensor.empty() : tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1 : i32, 2 : i32, 3 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128xf32>) -> tensor<128xf32>
%1 = "ttir.sum"(%arg0, %0) <{dim = [1 : i32, 2 : i32, 3 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
}
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir
@@ -4,7 +4,7 @@ module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.max' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
%1 = "ttir.max"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir
@@ -4,7 +4,7 @@ module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.mean' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.mean"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
%1 = "ttir.mean"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir
@@ -4,7 +4,7 @@ module {
func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> {
%0 = tensor.empty() : tensor<128x1x1x1xbf16>
// CHECK: error: 'ttnn.sum' op Reduce on more than two dimensions is not currently supported by TTNN
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
%1 = "ttir.sum"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16>
return %1 : tensor<128x1x1x1xbf16>
}
}
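For contrast with the three negative tests above: a reduction over two dimensions is within the TTNN verifier's current limit (the TODO in verifyReduceOp notes that restricting this to the last two dimensions is not yet enforced). An illustrative positive counterpart, not part of this commit, with shapes chosen to match the tests:

module {
  func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x32x1x1xbf16> {
    %0 = tensor.empty() : tensor<128x32x1x1xbf16>
    // Reducing the last two dimensions with keep_dim = true keeps them as size 1.
    %1 = "ttir.sum"(%arg0, %0) <{dim = [2 : i32, 3 : i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x32x1x1xbf16>) -> tensor<128x32x1x1xbf16>
    return %1 : tensor<128x32x1x1xbf16>
  }
}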