Handle StableHLO 64-bit floating point type (#1284)
mmanzoorTT authored Nov 18, 2024
1 parent 41a087a commit 924b22c
Showing 5 changed files with 107 additions and 9 deletions.
24 changes: 18 additions & 6 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPass.cpp
@@ -12,6 +12,7 @@
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Pass/Pass.h>

@@ -43,18 +44,29 @@ class StablehloTypeConverter : public TypeConverter {

// TTNN supports neither scalars nor boolean data. This transformation
// converts booleans to bfloat16 and scalars to 1-D tensors. It also
// converts 64-bit float/integer types to 32-bit types.
addConversion([&](RankedTensorType type) -> RankedTensorType {
bool changed = false;
Type elementType = type.getElementType();
llvm::ArrayRef<int64_t> shape = type.getShape();
size_t bitWidth = type.getElementTypeBitWidth();
MLIRContext *context = elementType.getContext();
// Convert the element type to bfloat16 if the input is boolean.
if (type.getElementTypeBitWidth() == 1) {
elementType = BFloat16Type::get(elementType.getContext());
changed = true;
} else if (type.getElementTypeBitWidth() == 64 &&
isa<IntegerType>(type.getElementType())) {
elementType = IntegerType::get(elementType.getContext(), 32);
if (bitWidth == 1) {
elementType = BFloat16Type::get(context);
changed = true;
} else if (bitWidth == 64) {
// Convert 64-bit integer element type to 32-bit integer.
if (isa<IntegerType>(type.getElementType())) {
elementType = IntegerType::get(context, 32);
changed = true;
}
// Convert 64-bit float element type to 32-bit float.
else if (isa<FloatType>(type.getElementType())) {
elementType = FloatType::getF32(context);
changed = true;
}
}
// Create a 1-D tensor shape in case of scalar input.
if (shape.size() == 0) {
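The net effect of the new branch is easiest to see in isolation. Below is a minimal standalone sketch of the element-type rules this hunk implements (not the commit's code; convertElementType is a hypothetical helper written against the MLIR C++ API for illustration, and it assumes the element type is an integer or float type):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

// Hypothetical helper mirroring the rules above: i1 -> bf16, i64 -> i32,
// f64 -> f32; every other integer/float element type is left unchanged.
static Type convertElementType(Type elementType) {
  MLIRContext *context = elementType.getContext();
  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
  if (bitWidth == 1)
    return BFloat16Type::get(context); // boolean -> bfloat16
  if (bitWidth == 64) {
    if (llvm::isa<IntegerType>(elementType))
      return IntegerType::get(context, 32); // i64 -> i32
    if (llvm::isa<FloatType>(elementType))
      return FloatType::getF32(context); // f64 -> f32
  }
  return elementType;
}

Keeping the element-type rules in one place like this leaves the shape handling (scalar -> 1-D tensor, below) orthogonal to the bit-width handling, which is the structure the hunk above converges on.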
15 changes: 14 additions & 1 deletion lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -376,13 +376,16 @@ class StableHLOToTTIRConstantOpConversionPattern
// converted to bfloat16 tensors.
// 3. Integer tensor: TTNN does not support 64-bit integers, so they are
// converted to 32-bit tensors.
// 4. Float tensor: TTNN does not support 64-bit floats, so they are
// converted to 32-bit tensors.
mlir::ElementsAttr getValueAttr(mlir::ElementsAttr valueAttr) const {
Type elementType = valueAttr.getElementType();
size_t bitWidth = elementType.getIntOrFloatBitWidth();
bool isTensor = !valueAttr.getShapedType().getShape().empty();
bool isIntTensor = isTensor && isa<IntegerType>(elementType) &&
bitWidth != 1 && bitWidth != 64;
bool isFloatTensor = isTensor && isa<FloatType>(elementType);
bool isFloatTensor = isTensor && isa<FloatType>(elementType) &&
bitWidth != 1 && bitWidth != 64;

if (isTensor && (isIntTensor || isFloatTensor)) {
return valueAttr;
Expand Down Expand Up @@ -413,6 +416,16 @@ class StableHLOToTTIRConstantOpConversionPattern
}
}
if (isa<FloatType>(elementType)) {
// Convert 64-bit floating-point numbers to 32-bit floating-point numbers.
if (bitWidth == 64) {
std::vector<mlir::APFloat> floatValues;
for (mlir::APFloat value : valueAttr.getValues<mlir::APFloat>()) {
float fl = static_cast<float>(value.convertToDouble());
mlir::APFloat input = mlir::APFloat(fl);
floatValues.emplace_back(input);
}
return mlir::DenseElementsAttr::get(valueType, floatValues);
}
// For float values, LLVM has a bug where not all float types are
// supported for iteration in DenseElementsAttr, so we have to use a
// different constructor.
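The new loop narrows each constant by round-tripping it through host floating point. A minimal sketch of that step (assuming IEEE-double input semantics; narrowToF32 is an illustrative name, not from the commit):

#include "llvm/ADT/APFloat.h"

// Narrow a 64-bit APFloat to 32 bits via host double/float, as in the loop
// above. Rounding is to nearest-even, so values that are not exactly
// representable in f32 lose precision; zeros and infinities survive intact,
// which the test_f64_inf case below relies on.
static llvm::APFloat narrowToF32(const llvm::APFloat &value) {
  float narrowed = static_cast<float>(value.convertToDouble());
  return llvm::APFloat(narrowed);
}

An equivalent route would be APFloat::convert with APFloat::IEEEsingle() and rmNearestTiesToEven, which additionally reports whether precision was lost.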
4 changes: 2 additions & 2 deletions test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir
@@ -51,7 +51,7 @@ module @jit_concat attributes {} {
dimension = 1 : i64
} : (tensor<64x32xf64>, tensor<64x64xf64>) -> tensor<64x96xf64>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x32xf64>, tensor<64x64xf64>, tensor<64x96xf64>) -> tensor<64x96xf64>
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<64x96xf32>) -> tensor<64x96xf32>
return %0 : tensor<64x96xf64>
}

@@ -69,7 +69,7 @@ module @jit_concat attributes {} {
dimension = 3 : i64
} : (tensor<3x2x4x5xf64>, tensor<3x2x4x3xf64>) -> tensor<3x2x4x8xf64>
// CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2x4x5xf64>, tensor<3x2x4x3xf64>, tensor<3x2x4x8xf64>) -> tensor<3x2x4x8xf64>
// CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2x4x5xf32>, tensor<3x2x4x3xf32>, tensor<3x2x4x8xf32>) -> tensor<3x2x4x8xf32>
return %0 : tensor<3x2x4x8xf64>
}

30 changes: 30 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -93,6 +93,36 @@ module @jit_constant attributes {} {
return %0 : tensor<2x2xf32>
}

func.func public @test_f64_scalar() -> tensor<f64> {
// CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = stablehlo.constant dense<0.3> : tensor<f64>
// CHECK: return %[[VAL]] : tensor<1xf32>
return %0 : tensor<f64>
}

func.func public @test_f64_splat() -> tensor<64xf64> {
// CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
%0 = stablehlo.constant dense<0.3> : tensor<64xf64>
// CHECK: return %[[VAL]] : tensor<64xf32>
return %0 : tensor<64xf64>
}

func.func public @test_f64_multiple() -> tensor<2x2xf64> {
// The ugly regex after `dense` is necessary because double opening square
// brackets indicate a substitution block in FileCheck syntax.
// CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
%0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf64>
// CHECK: return %[[VAL]] : tensor<2x2xf32>
return %0 : tensor<2x2xf64>
}

func.func public @test_f64_inf() -> tensor<f64> {
// CHECK: %[[VAL:[0-9]+]] = "ttir.constant"() <{value = dense<0xFF800000> : tensor<1xf32>}> : () -> tensor<1xf32>
%0 = stablehlo.constant dense<0xFFF0000000000000> : tensor<f64>
// CHECK: return %[[VAL]] : tensor<1xf32>
return %0 : tensor<f64>
}

func.func public @test_int8_scalar() -> tensor<i8> {
// CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
%0 = stablehlo.constant dense<3> : tensor<i8>
43 changes: 43 additions & 0 deletions test/ttmlir/Silicon/StableHLO/Constant/constant_f64.mlir
@@ -0,0 +1,43 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module @jit_constant attributes {} {
func.func public @test_f64_scalar() -> tensor<f64> {
// CHECK-LABEL: func.func public @test_f64_scalar
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<1xf32
%0 = stablehlo.constant dense<3.0> : tensor<f64>
return %0 : tensor<f64>
}

func.func public @test_f64_scalar_empty() -> tensor<f64> {
// CHECK-LABEL: func.func public @test_f64_scalar_empty
// CHECK: ttnn.empty
// CHECK-SAME: -> tensor<1xf32
%0 = stablehlo.constant dense<0.0> : tensor<f64>
return %0 : tensor<f64>
}

func.func public @test_f64_empty() -> tensor<64x128xf64> {
// CHECK-LABEL: func.func public @test_f64_empty
// CHECK: ttnn.empty
// CHECK-SAME: -> tensor<64x128xf32
%0 = stablehlo.constant dense<0.0> : tensor<64x128xf64>
return %0 : tensor<64x128xf64>
}

func.func public @test_f64_splat() -> tensor<64x128xf64> {
// CHECK-LABEL: func.func public @test_f64_splat
// CHECK: ttnn.full
// CHECK-SAME: fillValue = 3.000000e+00 : f32
// CHECK-SAME: -> tensor<64x128xf32
%0 = stablehlo.constant dense<3.0> : tensor<64x128xf64>
return %0 : tensor<64x128xf64>
}
}
