Skip to content

Commit

Permalink
Add subtract op end to end (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnie-TT authored Jul 5, 2024
1 parent 3d5b44f commit d978515
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,13 @@ def TTIR_AddOp : TTIR_ElementwiseBinaryOp<"add"> {
}];
}

// ODS record for the TTIR elementwise binary subtract op (result = lhs - rhs,
// per element). Inherits operands/results/constraints from
// TTIR_ElementwiseBinaryOp; mirrors the neighboring Add/Multiply definitions.
def TTIR_SubtractOp : TTIR_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

def TTIR_MultiplyOp : TTIR_ElementwiseBinaryOp<"multiply"> {
let summary = "Eltwise multiply.";
let description = [{
Expand Down
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> {
}];
}

// ODS record for the TTNN elementwise binary subtract op — the TTNN-dialect
// counterpart of TTIR_SubtractOp that ttir.subtract lowers to. Inherits its
// interface from TTNN_ElementwiseBinaryOp, like the adjacent Add/Multiply ops.
def TTNN_SubtractOp : TTNN_ElementwiseBinaryOp<"subtract"> {
let summary = "Eltwise subtract.";
let description = [{
Eltwise subtract operation.
}];
}

def TTNN_MultiplyOp : TTNN_ElementwiseBinaryOp<"multiply"> {
let summary = "Eltwise multiply.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ table FullOp {
// Discriminator stored in serialized EltwiseOp tables to select which
// elementwise binary kernel the runtime dispatches (see runtime program.cpp).
// NOTE(review): these values are part of the flatbuffer wire format — only
// append new entries; never renumber or reorder existing ones.
enum EltwiseOpType: uint32 {
Add = 0,
Multiply = 1,
Subtract = 2,
}

table EltwiseOp {
Expand Down
9 changes: 8 additions & 1 deletion lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class ConvertTosaToTTIR
patterns.add<TosaToTTIREltwiseBinaryRewriter<tosa::AddOp, ttir::AddOp,
OperandConstraint::AnyDevice>,
TosaToTTIREltwiseBinaryRewriter<tosa::MulOp, ttir::MultiplyOp,
OperandConstraint::AnyDevice>,
TosaToTTIREltwiseBinaryRewriter<tosa::SubOp, ttir::SubtractOp,
OperandConstraint::AnyDevice>>(
&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
Expand Down Expand Up @@ -107,6 +109,9 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROp> {
} else if constexpr (std::is_same<TTIROp, ttir::AddOp>::value) {
kernelName = "add";
kernelKind = "eltwise";
} else if constexpr (std::is_same<TTIROp, ttir::SubtractOp>::value) {
kernelName = "subtract";
kernelKind = "eltwise";
} else {
return rewriter.notifyMatchFailure(op,
"Unsupported Tosa operation for TTIR");
Expand Down Expand Up @@ -259,7 +264,8 @@ class TTIRGeneric : public impl::TTIRGenericBase<TTIRGeneric> {
RewritePatternSet patterns(&getContext());
patterns.add<TTIRLinalgGenericRewriter, TTIRKernelGenericRewriter,
TTIRNamedToKernelRewriter<AddOp>,
TTIRNamedToKernelRewriter<MultiplyOp>>(&getContext());
TTIRNamedToKernelRewriter<MultiplyOp>,
TTIRNamedToKernelRewriter<SubtractOp>>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
Expand Down Expand Up @@ -583,6 +589,7 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
patterns.add<TTIRLayoutOperandsRewriter<GenericOp>,
TTIRLayoutOperandsRewriter<AddOp>,
TTIRLayoutOperandsRewriter<MultiplyOp>,
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<MatmulOp>,
TTIRLayoutFuncReturnRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class ConvertTTIRToTTNN
patterns
.add<TTIRToTTNNLayoutRewriter, TTIRToTTNNOpRewriter<ttir::AddOp, AddOp>,
TTIRToTTNNOpRewriter<ttir::MultiplyOp, MultiplyOp>,
TTIRToTTNNOpRewriter<ttir::SubtractOp, SubtractOp>,
TTIRToTTNNBinaryOpRewriter<ttir::MatmulOp, MatmulOp>,
TensorEmptyToFullRewriter>(&getContext());
// ANCHOR_END: adding_an_op_matmul_rewrite_pattern_set
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Add;
} else if constexpr (std::is_same_v<EltwiseOp, MultiplyOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Multiply;
} else if constexpr (std::is_same_v<EltwiseOp, SubtractOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Subtract;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
Expand Down Expand Up @@ -149,6 +151,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createEltwiseOp(cache, multiplyOp),
debugString);
}
if (auto subtractOp = dyn_cast<SubtractOp>(op); subtractOp) {
return createOperation(cache, createEltwiseOp(cache, subtractOp),
debugString);
}
if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
return createOperation(cache, createOp(cache, matmulOp), debugString);
}
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ConvertTTNNToEmitC
TTNNToEmitCOpaqueRewriter<FullOp>,
TTNNToEmitCOpaqueRewriter<ToMemoryConfigOp>,
TTNNToEmitCOpaqueRewriter<MultiplyOp>,
TTNNToEmitCOpaqueRewriter<SubtractOp>,
TTNNToEmitCOpaqueRewriter<MatmulOp>,
TTNNToEmitCOpaqueRewriter<CloseDeviceOp>>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
Expand Down
10 changes: 8 additions & 2 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::multiply(lhs, rhs));
// auto [iter, inserted] =
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
// assert(inserted && "Duplicate output tensor");
break;
}
case ::tt::target::ttnn::EltwiseOpType::Subtract: {
assert(op->ins()->size() == 2 && "Unsupported number of inputs");
auto &lhs = *liveTensors.at(op->ins()->Get(0)->global_id());
auto &rhs = *liveTensors.at(op->ins()->Get(1)->global_id());
tensorPool.push_back(::ttnn::subtract(lhs, rhs));
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
default:
Expand Down
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_subtract.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
// Lit test: checks that a ttir.subtract on two 64x128 f32 tensors lowers to a
// ttnn.subtract, bracketed by the expected open_device/full/to_memory_config/
// close_device ops emitted by the layout and TTIR-to-TTNN conversion passes.
// CHECK lines are matched in order by FileCheck — do not reorder them.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {torch.debug_module_name = "_lambda", tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]]
%1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<64x128xf32>
}
}

0 comments on commit d978515

Please sign in to comment.