Skip to content

Commit

Permalink
Revert "Implement reduction avg op end to end (#271)" (#282)
Browse files Browse the repository at this point in the history
This reverts commit 1ddbd0e due to a build break.
  • Loading branch information
dgolubovicTT authored Aug 2, 2024
1 parent 1ddbd0e commit 4c3e670
Show file tree
Hide file tree
Showing 9 changed files with 1 addition and 54 deletions.
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,6 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
}];
}

def TTIR_AvgOp : TTIR_ReductionOp<"avg"> {
let summary = "Average reduction op.";
let description = [{
Average reduction op.
}];
}

def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
let summary = "Softmax operation.";
let description = [{
Expand Down
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,6 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
}];
}

def TTNN_AvgOp : TTNN_ReductionOp<"avg"> {
let summary = "Average reduction op.";
let description = [{
Average reduction op.
}];
}

def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
let summary = "Eltwise ReLU.";
let description = [{
Expand Down
1 change: 0 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ table EltwiseOp {

enum ReductionOpType: uint32 {
Sum = 0,
Avg = 1,
}

table ReductionOp {
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseBinaryOpConversionPattern<ttir::GreaterEqualOp, ttnn::GreaterEqualOp>,
ElementwiseBinaryOpConversionPattern<ttir::ReluOp, ttnn::ReluOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::AvgOp, ttnn::AvgOp>,
SoftmaxOpConversionPattern,
MatmulOpConversionPattern
>(typeConverter, ctx);
Expand Down
1 change: 0 additions & 1 deletion lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Reduction ops
//
patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
patterns.add<DefaultOpConversionPattern<ttnn::AvgOp>>(typeConverter, ctx);
}

} // namespace mlir::tt
4 changes: 1 addition & 3 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,9 +534,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
TTIRLayoutOperandsRewriter<MultiplyOp>,
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<GreaterEqualOp>,
TTIRLayoutOperandsRewriter<ReluOp>,
TTIRLayoutOperandsRewriter<SumOp>,
TTIRLayoutOperandsRewriter<AvgOp>,
TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<SumOp>,
TTIRLayoutOperandsRewriter<SoftmaxOp>,
TTIRLayoutOperandsRewriter<MatmulOp>, TTIRLayoutFuncReturnRewriter>(
&getContext());
Expand Down
5 changes: 0 additions & 5 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
::tt::target::ttnn::ReductionOpType type;
if constexpr (std::is_same_v<ReductionOp, SumOp>) {
type = ::tt::target::ttnn::ReductionOpType::Sum;
} else if constexpr (std::is_same_v<ReductionOp, AvgOp>) {
type = ::tt::target::ttnn::ReductionOpType::Avg;
} else {
llvm_unreachable("unhandled ReductionOp");
}
Expand Down Expand Up @@ -211,9 +209,6 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto sumOp = dyn_cast<SumOp>(op); sumOp) {
return createOperation(cache, createReductionOp(cache, sumOp), debugString);
}
if (auto avgOp = dyn_cast<AvgOp>(op); avgOp) {
return createOperation(cache, createReductionOp(cache, avgOp), debugString);
}
if (auto softmaxOp = dyn_cast<SoftmaxOp>(op); softmaxOp) {
return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
debugString);
Expand Down
14 changes: 0 additions & 14 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,6 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
case ::tt::target::ttnn::ReductionOpType::Avg: {
auto &in = *liveTensors.at(op->in()->global_id());

const auto *dim_arg_fb_ptr = op->dim_arg();
std::optional<vector<int>> dim_arg =
dim_arg_fb_ptr ? std::make_optional(std::vector<int>(
dim_arg_fb_ptr->begin(), dim_arg_fb_ptr->end()))
: std::nullopt;

tensorPool.push_back(::ttnn::avg(in, dim_arg, op->keep_dim()));

liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
}
}

Expand Down
15 changes: 0 additions & 15 deletions test/ttmlir/Dialect/TTNN/simple_avg.mlir

This file was deleted.

0 comments on commit 4c3e670

Please sign in to comment.