From f35a2cbca5b7eed6d318729fddf2fa948bfba6ff Mon Sep 17 00:00:00 2001 From: Muhammad Asif Manzoor Date: Fri, 8 Nov 2024 11:21:51 -0500 Subject: [PATCH] Handle different value types for constant op explicitly for TTNN backend (#1164) --- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 24 +++++++-- test/ttmlir/Dialect/TTNN/simple_constant.mlir | 54 +++++++++++++++++-- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index f18d4034b..e7cf8916c 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -503,9 +503,10 @@ class ConstantOpConversionPattern if (valueAttr.isSplat()) { Value device = getOrInsertDevice(rewriter, op); - float fillValue = valueAttr.getElementType().isInteger() - ? static_cast(valueAttr.getSplatValue()) - : valueAttr.getSplatValue(); + float fillValue = + valueAttr.getElementType().isInteger() + ? getIntegerValue(valueAttr) + : valueAttr.getSplatValue().convertToFloat(); if (fillValue == 0) { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), device); @@ -536,6 +537,23 @@ class ConstantOpConversionPattern return success(); } + + float getIntegerValue(mlir::ElementsAttr valueAttr) const { + size_t bitWidth = valueAttr.getElementType().getIntOrFloatBitWidth(); + switch (bitWidth) { + case 1: + return static_cast(valueAttr.getSplatValue()); + case 8: + return static_cast(valueAttr.getSplatValue()); + case 16: + return static_cast(valueAttr.getSplatValue()); + case 32: + return static_cast(valueAttr.getSplatValue()); + case 64: + return static_cast(valueAttr.getSplatValue()); + } + assert(false && "Unsupported integer type."); + } }; } // namespace diff --git a/test/ttmlir/Dialect/TTNN/simple_constant.mlir b/test/ttmlir/Dialect/TTNN/simple_constant.mlir index fc85c13ef..88df7aad2 100644 --- a/test/ttmlir/Dialect/TTNN/simple_constant.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_constant.mlir @@ -1,27 +1,73 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s #any_device = #tt.operand_constraint module attributes {} { + func.func @test_empty_int8() -> tensor<64x128xi8> { + %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8> + // CHECK: %{{[0-9]+}} = "ttnn.empty" + return %0 : tensor<64x128xi8> + } + + func.func @test_empty_int16() -> tensor<64x128xi16> { + %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi16>}> : () -> tensor<64x128xi16> + // CHECK: %{{[0-9]+}} = "ttnn.empty" + return %0 : tensor<64x128xi16> + } + func.func @test_empty_int() -> tensor<64x128xi32> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi32>}> : () -> tensor<64x128xi32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %{{[0-9]+}} = "ttnn.empty" return %0 : tensor<64x128xi32> } + func.func @test_empty_bfloat16() -> tensor<64x128xbf16> { + %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16> + // CHECK: %{{[0-9]+}} = "ttnn.empty" + return %0 : tensor<64x128xbf16> + } + func.func @test_empty_float() -> tensor<64x128xf32> { %0 = "ttir.constant"() <{value = dense<0.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] + // CHECK: %{{[0-9]+}} = "ttnn.empty" return %0 : tensor<64x128xf32> } + func.func @test_full_int8() -> tensor<64x128xi8> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xi8 + %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi8>}> : () -> tensor<64x128xi8> + return %0 : tensor<64x128xi8> + } + + func.func @test_full_int16() -> tensor<64x128xi16> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xi16 + %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi16>}> : () -> tensor<64x128xi16> + return %0 : tensor<64x128xi16> + } + func.func @test_full_int() -> tensor<64x128xi32> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xi32 %0 = "ttir.constant"() <{value = dense<1> : tensor<64x128xi32>}> : () -> tensor<64x128xi32> - // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<64x128xi32> } + func.func @test_full_bfloat16() -> tensor<64x128xbf16> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xbf16 + %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xbf16>}> : () -> tensor<64x128xbf16> + return %0 : tensor<64x128xbf16> + } + func.func @test_full_float() -> tensor<64x128xf32> { + // CHECK: %{{[0-9]+}} = "ttnn.full" + // CHECK-SAME: fillValue = 1.000000e+00 : f32 + // CHECK-SAME: tensor<64x128xf32 %0 = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<64x128xf32>}> : () -> tensor<64x128xf32> - // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] return %0 : tensor<64x128xf32> } }