From 924b22c18664b1cbdc787adc2c163105216fc475 Mon Sep 17 00:00:00 2001
From: Muhammad Asif Manzoor
Date: Mon, 18 Nov 2024 10:01:35 -0500
Subject: [PATCH] Handle stableHLO 64 bit floating point type (#1284)

---
 .../StableHLOToTTIR/StableHLOToTTIRPass.cpp     | 24 ++++++++---
 .../StableHLOToTTIR/StableHLOToTTIRPatterns.cpp | 15 ++++++-
 .../StableHLOToTTIR/binary/concat_op.mlir       |  4 +-
 .../StableHLOToTTIR/constant_op.mlir            | 30 +++++++++++++
 .../StableHLO/Constant/constant_f64.mlir        | 43 +++++++++++++++++++
 5 files changed, 107 insertions(+), 9 deletions(-)
 create mode 100644 test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir

diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
index 587d63bc4..8fc353842 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -43,18 +44,29 @@ class StablehloTypeConverter : public TypeConverter {

   // TTNN doesn't support either scalars or boolean data. This transformation
   // converts boolean to bfloat16 and scalars to 1-D tensors.
+  // This transformation also converts 64 bit float/integer types to 32 bit
+  // types.
   addConversion([&](RankedTensorType type) -> RankedTensorType {
     bool changed = false;
     Type elementType = type.getElementType();
     llvm::ArrayRef<int64_t> shape = type.getShape();
+    size_t bitWidth = type.getElementTypeBitWidth();
+    MLIRContext *context = elementType.getContext();
     // Convert the element type to bfloat16 if the input is boolean.
-    if (type.getElementTypeBitWidth() == 1) {
-      elementType = BFloat16Type::get(elementType.getContext());
-      changed = true;
-    } else if (type.getElementTypeBitWidth() == 64 &&
-               isa<IntegerType>(type.getElementType())) {
-      elementType = IntegerType::get(elementType.getContext(), 32);
+    if (bitWidth == 1) {
+      elementType = BFloat16Type::get(context);
       changed = true;
+    } else if (bitWidth == 64) {
+      // Convert 64 bit integer element type to 32 bit integer.
+      if (isa<IntegerType>(type.getElementType())) {
+        elementType = IntegerType::get(context, 32);
+        changed = true;
+      }
+      // Convert 64 bit float element type to 32 bit float.
+      else if (isa<FloatType>(type.getElementType())) {
+        elementType = FloatType::getF32(context);
+        changed = true;
+      }
     }
     // Create shape of 1-D tensor in case of scalar input.
     if (shape.size() == 0) {
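The converter change above implements three element-type demotions (i1 -> bfloat16, i64 -> i32, and now f64 -> f32) plus the rank-0-to-1-D rewrite. As a minimal sketch of those rules, assuming only public MLIR APIs (`demoteForTTNN` is a hypothetical name, not part of this patch):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper mirroring the demotion rules of the type converter.
static RankedTensorType demoteForTTNN(RankedTensorType type) {
  MLIRContext *ctx = type.getContext();
  Type elementType = type.getElementType();
  size_t bitWidth = type.getElementTypeBitWidth();

  if (bitWidth == 1) {
    // TTNN has no boolean tensors, so i1 maps to bfloat16.
    elementType = BFloat16Type::get(ctx);
  } else if (bitWidth == 64 && isa<IntegerType>(elementType)) {
    // 64 bit integers map to 32 bit integers.
    elementType = IntegerType::get(ctx, 32);
  } else if (bitWidth == 64 && isa<FloatType>(elementType)) {
    // 64 bit floats map to 32 bit floats (the new case in this patch).
    elementType = FloatType::getF32(ctx);
  }

  // TTNN has no scalars, so rank-0 tensors become 1-D tensors of size 1.
  if (type.getShape().empty())
    return RankedTensorType::get({1}, elementType);
  return RankedTensorType::get(type.getShape(), elementType);
}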
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 083e370a0..ceb353964 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -376,13 +376,16 @@ class StableHLOToTTIRConstantOpConversionPattern
   // converted to bfloat16 tensors.
   // 3. Integer tensor: TTNN does not support 64 bit integer. So they are
   // converted to 32 bit tensor.
+  // 4. Float tensor: TTNN does not support 64 bit float. So they are converted
+  // to 32 bit tensor.
   mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const {
     Type elementType = valueAttr.getElementType();
     size_t bitWidth = elementType.getIntOrFloatBitWidth();
     bool isTensor = !valueAttr.getShapedType().getShape().empty();
     bool isIntTensor = isTensor && isa<IntegerType>(elementType) &&
                        bitWidth != 1 && bitWidth != 64;
-    bool isFloatTensor = isTensor && isa<FloatType>(elementType);
+    bool isFloatTensor = isTensor && isa<FloatType>(elementType) &&
+                         bitWidth != 1 && bitWidth != 64;

     if (isTensor && (isIntTensor || isFloatTensor)) {
       return valueAttr;
@@ -413,6 +416,16 @@
       }
     }
     if (isa<FloatType>(elementType)) {
+      // Convert 64 bit floating point numbers to 32 bit floating point numbers.
+      if (bitWidth == 64) {
+        std::vector<mlir::APFloat> floatValues;
+        for (mlir::APFloat value : valueAttr.getValues<mlir::APFloat>()) {
+          float fl = static_cast<float>(value.convertToDouble());
+          mlir::APFloat input = mlir::APFloat(fl);
+          floatValues.emplace_back(input);
+        }
+        return mlir::DenseElementsAttr::get(valueType, floatValues);
+      }
       // In case of float values llvm has a bug where not all float types are
       // supported for iterating in DenseElementsAttr, so we have to use a
       // different constructor.
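Read in isolation, the f64 path added to getValueAttr above rebuilds the dense attribute element by element. A self-contained sketch of that rewrite, assuming standard MLIR headers (`narrowF64ToF32` and its signature are illustrative, not part of the patch; the patch's `valueType` plays the role of `f32Type` here):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative helper: rebuild a dense f64 attribute with f32 elements.
static DenseElementsAttr narrowF64ToF32(DenseElementsAttr attr,
                                        ShapedType f32Type) {
  llvm::SmallVector<APFloat> values;
  values.reserve(attr.getNumElements());
  for (APFloat v : attr.getValues<APFloat>()) {
    // convertToDouble() is exact for f64 input; the cast to float performs
    // IEEE round-to-nearest narrowing and preserves infinities and NaNs.
    values.push_back(APFloat(static_cast<float>(v.convertToDouble())));
  }
  // APFloat values built from a C++ float carry IEEE single-precision
  // semantics, matching the f32 element type of the result.
  return DenseElementsAttr::get(f32Type, values);
}

The tests that follow exercise exactly this path: splats, multi-element tensors, and special values such as infinity.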
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir
index 6286f47e7..51cfd214b 100644
--- a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir
@@ -51,7 +51,7 @@ module @jit_concat attributes {} {
       dimension = 1 : i64
     } : (tensor<64x32xf64>, tensor<64x64xf64>) -> tensor<64x96xf64>
     // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x32xf64>, tensor<64x64xf64>, tensor<64x96xf64>) -> tensor<64x96xf64>
+    // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<64x96xf32>) -> tensor<64x96xf32>
     return %0 : tensor<64x96xf64>
   }

@@ -69,7 +69,7 @@ module @jit_concat attributes {} {
       dimension = 3 : i64
     } : (tensor<3x2x4x5xf64>, tensor<3x2x4x3xf64>) -> tensor<3x2x4x8xf64>
     // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2x4x5xf64>, tensor<3x2x4x3xf64>, tensor<3x2x4x8xf64>) -> tensor<3x2x4x8xf64>
+    // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2x4x5xf32>, tensor<3x2x4x3xf32>, tensor<3x2x4x8xf32>) -> tensor<3x2x4x8xf32>
     return %0 : tensor<3x2x4x8xf64>
   }

diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
index b00a1bec6..eb0aa5b95 100644
--- a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -93,6 +93,36 @@ module @jit_constant attributes {} {
     return %0 : tensor<2x2xf32>
   }

+  func.func public @test_f64_scalar() -> tensor<f64> {
+    // CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+    %0 = stablehlo.constant dense<0.3> : tensor<f64>
+    // CHECK: return %[[VAL]] : tensor<1xf32>
+    return %0 : tensor<f64>
+  }
+
+  func.func public @test_f64_splat() -> tensor<64xf64> {
+    // CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
+    %0 = stablehlo.constant dense<0.3> : tensor<64xf64>
+    // CHECK: return %[[VAL]] : tensor<64xf32>
+    return %0 : tensor<64xf64>
+  }
+
+  func.func public @test_f64_multiple() -> tensor<2x2xf64> {
+    // The ugly regex after `dense` is necessary because a double opening
+    // square bracket indicates a substitution block in FileCheck syntax.
+    // CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
+    %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf64>
+    // CHECK: return %[[VAL]] : tensor<2x2xf32>
+    return %0 : tensor<2x2xf64>
+  }
+
+  func.func public @test_f64_inf() -> tensor<f64> {
+    // CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<0xFF800000> : tensor<1xf32>}> : () -> tensor<1xf32>
+    %0 = stablehlo.constant dense<0xFFF0000000000000> : tensor<f64>
+    // CHECK: return %[[VAL]] : tensor<1xf32>
+    return %0 : tensor<f64>
+  }
+
   func.func public @test_int8_scalar() -> tensor<i8> {
     // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
     %0 = stablehlo.constant dense<3> : tensor<i8>
diff --git a/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir b/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir
new file mode 100644
index 000000000..cc3917816
--- /dev/null
+++ b/test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir
@@ -0,0 +1,43 @@
+// REQUIRES: stablehlo
+// RUN: rm -rf %t.ttnn
+// RUN: rm -rf %t.mlir
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// RUN: FileCheck --input-file=%t.mlir %s
+
+module @jit_constant attributes {} {
+  func.func public @test_f64_scalar() -> tensor<f64> {
+    // CHECK-LABEL: func.func public @test_f64_scalar
+    // CHECK: ttnn.full
+    // CHECK-SAME: fillValue = 3.000000e+00 : f32
+    // CHECK-SAME: -> tensor<1xf32
+    %0 = stablehlo.constant dense<3.0> : tensor<f64>
+    return %0 : tensor<f64>
+  }
+
+  func.func public @test_f64_scalar_empty() -> tensor<f64> {
+    // CHECK-LABEL: func.func public @test_f64_scalar_empty
+    // CHECK: ttnn.empty
+    // CHECK-SAME: -> tensor<1xf32
+    %0 = stablehlo.constant dense<0.0> : tensor<f64>
+    return %0 : tensor<f64>
+  }
+
+  func.func public @test_f64_empty() -> tensor<64x128xf64> {
+    // CHECK-LABEL: func.func public @test_f64_empty
+    // CHECK: ttnn.empty
+    // CHECK-SAME: -> tensor<64x128xf32
+    %0 = stablehlo.constant dense<0.0> : tensor<64x128xf64>
+    return %0 : tensor<64x128xf64>
+  }
+
+  func.func public @test_f64_splat() -> tensor<64x128xf64> {
+    // CHECK-LABEL: func.func public @test_f64_splat
+    // CHECK: ttnn.full
+    // CHECK-SAME: fillValue = 3.000000e+00 : f32
+    // CHECK-SAME: -> tensor<64x128xf32
+    %0 = stablehlo.constant dense<3.0> : tensor<64x128xf64>
+    return %0 : tensor<64x128xf64>
+  }
+}
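A closing note on the special-value coverage: test_f64_inf in constant_op.mlir above expects the f32 bit pattern 0xFF800000 because the per-element double-to-float cast in the conversion pattern preserves infinities. A standalone check of that arithmetic fact (illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

int main() {
  // f64 -inf has the bit pattern 0xFFF0000000000000.
  double negInf = -std::numeric_limits<double>::infinity();
  // The same narrowing step the conversion pattern applies per element.
  float narrowed = static_cast<float>(negInf);

  uint32_t bits;
  std::memcpy(&bits, &narrowed, sizeof(bits));
  assert(bits == 0xFF800000u); // f32 -inf, the value the CHECK line expects
  return 0;
}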