Skip to content

Commit

Permalink
Handle different value types for constant op explicitly for TTNN back…
Browse files Browse the repository at this point in the history
…end (#1164)
  • Loading branch information
mmanzoorTT authored Nov 8, 2024
1 parent ce78615 commit f35a2cb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 7 deletions.
24 changes: 21 additions & 3 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,10 @@ class ConstantOpConversionPattern

if (valueAttr.isSplat()) {
Value device = getOrInsertDevice(rewriter, op);
float fillValue = valueAttr.getElementType().isInteger()
? static_cast<float>(valueAttr.getSplatValue<int>())
: valueAttr.getSplatValue<float>();
float fillValue =
valueAttr.getElementType().isInteger()
? getIntegerValue(valueAttr)
: valueAttr.getSplatValue<mlir::APFloat>().convertToFloat();
if (fillValue == 0) {
rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
op, this->getTypeConverter()->convertType(op.getType()), device);
Expand Down Expand Up @@ -536,6 +537,23 @@ class ConstantOpConversionPattern

return success();
}

float getIntegerValue(mlir::ElementsAttr valueAttr) const {
size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth();
switch (bitWidth) {
case 1:
return static_cast<float>(valueAttr.getSplatValue<bool>());
case 8:
return static_cast<float>(valueAttr.getSplatValue<int8_t>());
case 16:
return static_cast<float>(valueAttr.getSplatValue<int16_t>());
case 32:
return static_cast<float>(valueAttr.getSplatValue<int>());
case 64:
return static_cast<float>(valueAttr.getSplatValue<int64_t>());
}
assert(false && "Unsupported integer type.");
}
};

} // namespace
Expand Down
54 changes: 50 additions & 4 deletions test/ttmlir/Dialect/TTNN/simple_constant.mlir
Original file line number Diff line number Diff line change
@@ -1,27 +1,73 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @test_empty_int8() -> tensor<64x128xi8> {
%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
// CHECK: %{{[0-9]+}} = "ttnn.empty"
return %0 : tensor<64x128xi8>
}

func.func @test_empty_int16() -> tensor<64x128xi16> {
%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
// CHECK: %{{[0-9]+}} = "ttnn.empty"
return %0 : tensor<64x128xi16>
}

func.func @test_empty_int() -> tensor<64x128xi32> {
%0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
// CHECK: %{{[0-9]+}} = "ttnn.empty"
return %0 : tensor<64x128xi32>
}

func.func @test_empty_bfloat16() -> tensor<64x128xbf16> {
%0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
// CHECK: %{{[0-9]+}} = "ttnn.empty"
return %0 : tensor<64x128xbf16>
}

func.func @test_empty_float() -> tensor<64x128xf32> {
%0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
// CHECK: %{{[0-9]+}} = "ttnn.empty"
return %0 : tensor<64x128xf32>
}

func.func @test_full_int8() -> tensor<64x128xi8> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xi8
%0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi8>}> : () -> tensor<64x128xi8>
return %0 : tensor<64x128xi8>
}

func.func @test_full_int16() -> tensor<64x128xi16> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xi16
%0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi16>}> : () -> tensor<64x128xi16>
return %0 : tensor<64x128xi16>
}

func.func @test_full_int() -> tensor<64x128xi32> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xi32
%0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32>
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
return %0 : tensor<64x128xi32>
}

func.func @test_full_bfloat16() -> tensor<64x128xbf16> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xbf16
%0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16>
return %0 : tensor<64x128xbf16>
}

func.func @test_full_float() -> tensor<64x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.full"
// CHECK-SAME: fillValue = 1.000000e+00 : f32
// CHECK-SAME: tensor<64x128xf32
%0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
return %0 : tensor<64x128xf32>
}
}

0 comments on commit f35a2cb

Please sign in to comment.