Skip to content

Commit

Permalink
TTNN relu_bw sweep migration (#10729)
Browse files Browse the repository at this point in the history
* #10147: Migrated backward relu

* #10147: Corrected ttnn ops
  • Loading branch information
npetrovic-tenstorrent authored Jul 26, 2024
1 parent eb8e4e7 commit 30065c6
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 124 deletions.
4 changes: 0 additions & 4 deletions tests/tt_eager/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,10 +1186,6 @@
"tt_op": tt_lib_ops.abs_bw,
"pytorch_op": pytorch_ops.abs_bw,
},
"relu-bw": {
"tt_op": tt_lib_ops.relu_bw,
"pytorch_op": pytorch_ops.relu_bw,
},
"gt-bw": {
"tt_op": tt_lib_ops.gt_bw,
"pytorch_op": pytorch_ops.gt_bw,
Expand Down

This file was deleted.

This file was deleted.

20 changes: 0 additions & 20 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3067,26 +3067,6 @@ def abs_bw(
return tt2torch_tensor(t2)


@setup_host_and_device
def relu_bw(
    x,
    y,
    *args,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    """Sweep wrapper for the tt_lib backward-relu op.

    Converts the two torch inputs to TT tensors (per-input layout,
    memory config and dtype), runs ``ttl.tensor.relu_bw`` and returns
    the first output converted back to a torch tensor.
    # NOTE(review): which of x/y is the gradient vs. the forward input
    # is determined by ttl.tensor.relu_bw's signature — confirm there.
    """
    first = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    second = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])

    # relu_bw returns a list of tensors; the sweep compares only the first.
    result = ttl.tensor.relu_bw(first, second, output_mem_config)[0]
    return tt2torch_tensor(result)


@setup_host_and_device
def gt_bw(
x,
Expand Down
4 changes: 4 additions & 0 deletions tests/ttnn/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,4 +773,8 @@
"tt_op": ttnn_ops.addalpha_bw,
"pytorch_op": pytorch_ops.addalpha_bw,
},
"relu-bw": {
"tt_op": ttnn_ops.relu_bw,
"pytorch_op": pytorch_ops.relu_bw,
},
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- relu-bw:
shape:
start-shape: [1, 1, 32, 32]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 32]
num-shapes: 2
num-samples: 128
num-dims: [2, 3, 4]
args-sampling-strategy: "all"
datagen:
function: gen_rand
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_scalar_args
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: backward_relu_sweep.csv
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- relu-bw:
shape:
start-shape: [1, 1, 32, 32]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 32]
num-shapes: 2
num-samples: 128
num-dims: [2, 3, 4]
args-sampling-strategy: "all"
datagen:
function: gen_rand
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_scalar_args
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: backward_relu_sweep.csv
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
19 changes: 19 additions & 0 deletions tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3707,3 +3707,22 @@ def rsqrt_bw(
t2 = ttnn.rsqrt_bw(t0, t1, memory_config=output_mem_config)[0]

return ttnn_tensor_to_torch(t2)


def relu_bw(
    x,
    y,
    *args,
    device,
    dtype,
    layout,
    input_mem_config,
    output_mem_config,
    **kwargs,
):
    """Sweep wrapper for ``ttnn.relu_bw``.

    Converts the two torch inputs to ttnn tensors using the per-input
    layout, memory config and dtype, invokes the backward op, and
    returns the first output tensor converted back to torch.
    # NOTE(review): which of x/y is the gradient vs. the forward input
    # follows ttnn.relu_bw's argument order — confirm against ttnn docs.
    """
    first = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
    second = setup_ttnn_tensor(y, device, layout[1], input_mem_config[1], dtype[1])

    # ttnn.relu_bw returns a list; the sweep harness compares only element 0.
    result = ttnn.relu_bw(first, second, memory_config=output_mem_config)[0]
    return ttnn_tensor_to_torch(result)

0 comments on commit 30065c6

Please sign in to comment.