Fixing all tests
mtopalovicTT committed Dec 27, 2024
1 parent c8f0156 commit b664a07
Showing 12 changed files with 57 additions and 57 deletions.
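All twelve files make the same mechanical rename: the reduction ops' dimension attribute changes from dim_arg to dim, both in the ttir op payloads and in the FileCheck directives that match the lowered ttnn ops. As a minimal before/after sketch (op syntax mirrors the simple_max.mlir hunks below; the shapes are borrowed from the 2-to-1-dim test case, so this is illustrative rather than a specific hunk):

```mlir
// Before: the reduction axes were spelled dim_arg.
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
// After: the same op, with the attribute renamed to dim.
%1 = "ttir.sum"(%arg0, %0) <{dim = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
```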
20 changes: 10 additions & 10 deletions test/ttmlir/Conversion/StableHLOToTTIR/reduce_add_op.mlir
@@ -4,7 +4,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32x4xf32>
@@ -15,7 +15,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_4to2dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32, 3 : i32]
+// CHECK-SAME: dim = [1 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32xf32>
@@ -26,7 +26,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_4to1dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32, 3 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -37,7 +37,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_4to0dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -48,7 +48,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128x4xf32>
@@ -59,7 +59,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_3to1dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -70,7 +70,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_3to0dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -81,7 +81,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -92,7 +92,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_2to0dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -103,7 +103,7 @@ module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_1to0dim(%arg0: tensor<128xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.sum"
-// CHECK-SAME: dim_arg = [0 : i32]
+// CHECK-SAME: dim = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128xf32>
// CHECK-SAME: -> tensor<1xf32>
20 changes: 10 additions & 10 deletions test/ttmlir/Conversion/StableHLOToTTIR/reduce_maximum_op.mlir
@@ -4,7 +4,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_4to3dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32x4xf32>
@@ -15,7 +15,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_4to2dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128x32xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32, 3 : i32]
+// CHECK-SAME: dim = [1 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128x32xf32>
@@ -26,7 +26,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_4to1dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32, 3 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -37,7 +37,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_4to0dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32, 2 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -48,7 +48,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128x4xf32>
@@ -59,7 +59,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_3to1dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -70,7 +70,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_3to0dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32, 2 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -81,7 +81,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<128xf32>
@@ -92,7 +92,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_2to0dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [0 : i32, 1 : i32]
+// CHECK-SAME: dim = [0 : i32, 1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xf32>
// CHECK-SAME: -> tensor<1xf32>
@@ -103,7 +103,7 @@ module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_1to0dim(%arg0: tensor<128xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: tensor.empty
// CHECK: "ttir.max"
-// CHECK-SAME: dim_arg = [0 : i32]
+// CHECK-SAME: dim = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128xf32>
// CHECK-SAME: -> tensor<1xf32>
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_mean.mlir
@@ -3,7 +3,7 @@ module {
func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
%0 = tensor.empty() : tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
%1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
%1 = "ttir.mean"(%arg0, %0) <{dim = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
return %1 : tensor<512x32xbf16>
}
}
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_sum.mlir
@@ -3,7 +3,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
%0 = tensor.empty() : tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
%1 = "ttir.sum"(%arg0, %0) <{dim = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
return %1 : tensor<512x32xbf16>
}
}
14 changes: 7 additions & 7 deletions test/ttmlir/Silicon/StableHLO/reduce_add_op.mlir
@@ -13,7 +13,7 @@
module @jit_reduce_add attributes {} {
func.func public @test_reduce_add_4to0dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.sum"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x32x4xf32,
// CHECK-SAME: -> tensor<1x1x1x1xf32,
@@ -27,7 +27,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK: "ttnn.sum"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<128x1x4xf32,
@@ -41,7 +41,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_3to1dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: "ttnn.sum"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<128x1x1xf32,
@@ -55,7 +55,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_3to0dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.sum"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<1x1x1xf32,
@@ -69,7 +69,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: "ttnn.sum"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<128x1xf32,
@@ -83,7 +83,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_2to0dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.sum"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<1x1xf32,
@@ -97,7 +97,7 @@ module @jit_reduce_add attributes {} {

func.func public @test_reduce_add_1to0dim(%arg0: tensor<128xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.sum"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128xf32,
// CHECK-SAME: -> tensor<1xf32,
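One pattern worth calling out in these Silicon checks: for the full reductions (the *to0dim cases) the lowered ttnn.sum is expected to carry no dimension attribute at all, which is what the CHECK-NOT directives assert; only keep_dim survives. A shape-only sketch of the expected IR, with TTNN tensor layout encodings and operand details omitted for brevity (the real dumps carry a layout in every tensor type):

```mlir
// Full reduction of a 128x10 input: no dim attribute on the lowered op,
// and keep_dim = true preserves unit axes, giving a 1x1 result.
%1 = "ttnn.sum"(%arg0) <{keep_dim = true}> : (tensor<128x10xf32>) -> tensor<1x1xf32>
```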
14 changes: 7 additions & 7 deletions test/ttmlir/Silicon/StableHLO/reduce_maximum_op.mlir
@@ -13,7 +13,7 @@
module @jit_reduce_maximum attributes {} {
func.func public @test_reduce_maximum_4to0dim(%arg0: tensor<128x10x32x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.max"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x32x4xf32,
// CHECK-SAME: -> tensor<1x1x1x1xf32,
@@ -27,7 +27,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_3to2dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128x4xf32> {
// CHECK: "ttnn.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<128x1x4xf32,
@@ -41,7 +41,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_3to1dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: "ttnn.max"
-// CHECK-SAME: dim_arg = [1 : i32, 2 : i32]
+// CHECK-SAME: dim = [1 : i32, 2 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<128x1x1xf32,
@@ -55,7 +55,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_3to0dim(%arg0: tensor<128x10x4xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.max"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10x4xf32,
// CHECK-SAME: -> tensor<1x1x1xf32,
@@ -69,7 +69,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_2to1dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<128xf32> {
// CHECK: "ttnn.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<128x1xf32,
@@ -83,7 +83,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_2to0dim(%arg0: tensor<128x10xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.max"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<1x1xf32,
@@ -97,7 +97,7 @@ module @jit_reduce_maximum attributes {} {

func.func public @test_reduce_maximum_1to0dim(%arg0: tensor<128xf32>, %cst_0: tensor<f32>) -> tensor<f32> {
// CHECK: "ttnn.max"
-// CHECK-NOT: dim_arg
+// CHECK-NOT: dim
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128xf32,
// CHECK-SAME: -> tensor<1xf32,
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/perf_unit/test_perf_max.mlir
@@ -4,6 +4,6 @@
func.func @max(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> {
%0 = tensor.empty() : tensor<1x1x512xbf16>
// CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]]
%1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
%1 = "ttir.max"(%arg0, %0) <{dim = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
return %1 : tensor<1x1x512xbf16>
}
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sum.mlir
@@ -4,6 +4,6 @@
func.func @sum(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> {
%0 = tensor.empty() : tensor<1x1x512xbf16>
// CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
%1 = "ttir.sum"(%arg0, %0) <{dim = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
return %1 : tensor<1x1x512xbf16>
}
8 changes: 4 additions & 4 deletions test/ttmlir/Silicon/TTNN/simple_max.mlir
@@ -13,27 +13,27 @@ module {
func.func public @reduce_not_keep_dim(%arg0: tensor<128x10xf32>) -> tensor<128xf32> {
%0 = tensor.empty() : tensor<128xf32>
// CHECK: "ttnn.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<128x1xf32,
// CHECK: "ttnn.reshape"
// CHECK-SAME: shape = [128 : i32]
// CHECK-SAME: tensor<128x1xf32,
// CHECK-SAME: -> tensor<128xf32,
%1 = "ttir.max"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
%1 = "ttir.max"(%arg0, %0) <{dim = [1 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32>
return %1 : tensor<128xf32>
}

func.func public @reduce_keep_dim(%arg0: tensor<128x10xf32>) -> tensor<128x1xf32> {
%0 = tensor.empty() : tensor<128x1xf32>
// CHECK: "ttnn.max"
-// CHECK-SAME: dim_arg = [1 : i32]
+// CHECK-SAME: dim = [1 : i32]
// CHECK-SAME: keep_dim = true
// CHECK-SAME: tensor<128x10xf32,
// CHECK-SAME: -> tensor<128x1xf32,
// CHECK-NOT: "ttnn.reshape"
%1 = "ttir.max"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = true}> : (tensor<128x10xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
%1 = "ttir.max"(%arg0, %0) <{dim = [1 : i32], keep_dim = true}> : (tensor<128x10xf32>, tensor<128x1xf32>) -> tensor<128x1xf32>
return %1 : tensor<128x1xf32>
}
}
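These simple_max.mlir hunks also pin down the keep_dim contract of the lowering: ttnn.max always reduces with keep_dim = true, and when the TTIR op requested keep_dim = false a trailing ttnn.reshape is expected to drop the unit axis (conversely, CHECK-NOT: "ttnn.reshape" guards the keep_dim = true path). A shape-only sketch of the expected lowering for reduce_not_keep_dim, with TTNN layout encodings and operand details simplified:

```mlir
// keep_dim = true reduction over axis 1 produces 128x1 ...
%1 = "ttnn.max"(%arg0) <{dim = [1 : i32], keep_dim = true}> : (tensor<128x10xf32>) -> tensor<128x1xf32>
// ... then a reshape collapses 128x1 to the requested 128.
%2 = "ttnn.reshape"(%1) <{shape = [128 : i32]}> : (tensor<128x1xf32>) -> tensor<128xf32>
```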
(The remaining 3 changed files are not shown.)
