From 46d74887c5ee283d716c17a1ec0ca713b9a18952 Mon Sep 17 00:00:00 2001
From: Lewis Panos
Date: Mon, 23 Dec 2024 21:58:31 +0000
Subject: [PATCH] Adding more softmax tests

---
 .../TTIR/Fusion/softmax/test_fuse_softmax.mlir | 32 ++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/test/ttmlir/Dialect/TTIR/Fusion/softmax/test_fuse_softmax.mlir b/test/ttmlir/Dialect/TTIR/Fusion/softmax/test_fuse_softmax.mlir
index 7584396f4..dc59d262e 100644
--- a/test/ttmlir/Dialect/TTIR/Fusion/softmax/test_fuse_softmax.mlir
+++ b/test/ttmlir/Dialect/TTIR/Fusion/softmax/test_fuse_softmax.mlir
@@ -1,6 +1,6 @@
 // RUN: ttmlir-opt --ttir-fusion %s | FileCheck %s
 module attributes {} {
-  func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
+  func.func @softmax_pattern_with_explicit_broadcast(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
     // CHECK: %[[C:.*]] = "ttir.softmax"[[C:.*]]
     %dps1 = tensor.empty() : tensor<1x32x128x128xf32>
     %1 = "ttir.exp"(%arg0, %dps1) {operandSegmentSizes = array<i32: 1, 1>} : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
@@ -13,3 +13,33 @@
     return %4 : tensor<1x32x128x128xf32>
   }
 }
+
+// Same softmax pattern, but the reduced sum keeps a size-1 trailing dim and the
+// divide relies on implicit broadcasting instead of an explicit ttir.broadcast.
+module attributes {} {
+  func.func @softmax_pattern_with_implicit_broadcast(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
+    // CHECK: %[[C:.*]] = "ttir.softmax"[[C:.*]]
+    %dps1 = tensor.empty() : tensor<1x32x128x128xf32>
+    %1 = "ttir.exp"(%arg0, %dps1) {operandSegmentSizes = array<i32: 1, 1>} : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    %dps2 = tensor.empty() : tensor<1x32x128x1xf32>
+    %2 = "ttir.sum"(%1, %dps2) {keep_dim = true, dim_arg = [3 : i32]} : (tensor<1x32x128x128xf32>, tensor<1x32x128x1xf32>) -> tensor<1x32x128x1xf32>
+    %dps3 = tensor.empty() : tensor<1x32x128x128xf32>
+    %3 = "ttir.div"(%1, %2, %dps3) {operandSegmentSizes = array<i32: 2, 1>} : (tensor<1x32x128x128xf32>, tensor<1x32x128x1xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    return %3 : tensor<1x32x128x128xf32>
+  }
+}
+
+// Softmax pattern where the reduction drops the dim (keep_dim = false) and the
+// shape is restored via reshape + broadcast before the divide.
+module attributes {} {
+  func.func @softmax_pattern_with_fusable_keepdim_reduce_and_broadcast(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
+    // CHECK: %[[C:.*]] = "ttir.softmax"[[C:.*]]
+    %dps1 = tensor.empty() : tensor<1x32x128x128xf32>
+    %1 = "ttir.exp"(%arg0, %dps1) {operandSegmentSizes = array<i32: 1, 1>} : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    %dps2 = tensor.empty() : tensor<1x32x128xf32>
+    %2 = "ttir.sum"(%1, %dps2) {keep_dim = false, dim_arg = [3 : i32]} : (tensor<1x32x128x128xf32>, tensor<1x32x128xf32>) -> tensor<1x32x128xf32>
+    %dps3 = tensor.empty() : tensor<1x32x128x1xf32>
+    %3 = "ttir.reshape"(%2, %dps3) {shape = [1: i32, 32: i32, 128: i32, 1: i32]} : (tensor<1x32x128xf32>, tensor<1x32x128x1xf32>) -> tensor<1x32x128x1xf32>
+    %dps4 = tensor.empty() : tensor<1x32x128x128xf32>
+    %4 = "ttir.broadcast"(%3, %dps4) {dimension = [3 : i64]} : (tensor<1x32x128x1xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    %dps5 = tensor.empty() : tensor<1x32x128x128xf32>
+    %5 = "ttir.div"(%1, %4, %dps5) {operandSegmentSizes = array<i32: 2, 1>} : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    return %5 : tensor<1x32x128x128xf32>
+  }
+}