Skip to content

Commit

Permalink
Add sweeps for complex bw_ops: polar, recip, add, mul (#11364)
Browse files Browse the repository at this point in the history
* #10147: Added sweeps for complex_polar_bw to ttnn

* #10147: Added complex_recip_bw sweeps to ttnn

* #10147: Added complex_mul_bw sweeps to ttnn

* #10147: Add complexadd_bw sweeps to ttnn

* #10147: Added support for dims 2, 3, 4 in complex_backward_mul and complex_backward_add YAML files

* #10147: Reformating

* #10147: Add num-dims: [2, 3, 4] to complex ops: abs, recip, polar

* #10147: Reformating
  • Loading branch information
amalbasaTT authored Aug 14, 2024
1 parent 99428fa commit 18dd2d1
Show file tree
Hide file tree
Showing 19 changed files with 435 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,67 @@ def cos_bw(x, y, *args, **kwargs):
return in_data.grad


def complex_polar_bw(x, y, *args, **kwargs):
grad_data = x
in_data = y
in_data.requires_grad = True

in_data.retain_grad()
pyt_y = torch.polar(in_data.real, in_data.imag)
pyt_y.backward(gradient=grad_data)

grad_real = torch.real(in_data.grad)
grad_imag = torch.imag(in_data.grad)

return torch.complex(grad_real, grad_imag)


def complex_recip_bw(x, y, *args, **kwargs):
grad_data = x
in_data = y
in_data.requires_grad = True

in_data.retain_grad()
pyt_y = torch.reciprocal(in_data)
pyt_y.backward(gradient=grad_data)

return in_data.grad


def complex_mul_bw(x, y, z, *args, **kwargs):
grad_data = x
in_data = y
other_data = z

in_data.requires_grad = True
other_data.requires_grad = True

in_data.retain_grad()
other_data.retain_grad()

pyt_y = in_data * other_data
pyt_y.backward(gradient=grad_data)

return [in_data.grad, other_data.grad]


def complex_add_bw(x, y, z, *args, **kwargs):
grad_data = x
in_data = y
other_data = z

in_data.requires_grad = True
other_data.requires_grad = True

in_data.retain_grad()
other_data.retain_grad()

pyt_y = in_data + other_data
pyt_y.backward(gradient=grad_data)

return [in_data.grad, other_data.grad]


def global_avg_pool2d(x, *args, **kwargs):
output_size = (1, 1)
x = x.to(torch.float32)
Expand Down
20 changes: 20 additions & 0 deletions tests/ttnn/python_api_testing/sweep_tests/op_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,10 @@
"tt_op": ttnn_ops.eltwise_unary_div_no_nan,
"pytorch_op": pytorch_ops.unary_div_no_nan,
},
"complex-polar-bw": {
"tt_op": ttnn_ops.complex_polar_bw,
"pytorch_op": pytorch_ops.complex_polar_bw,
},
"complex-conj": {
"tt_op": ttnn_ops.complex_conj,
"pytorch_op": pytorch_ops.complex_conj,
Expand All @@ -941,4 +945,20 @@
"tt_op": ttnn_ops.complex_recip,
"pytorch_op": pytorch_ops.complex_recip,
},
"complex-polar-bw": {
"tt_op": ttnn_ops.complex_polar_bw,
"pytorch_op": pytorch_ops.complex_polar_bw,
},
"complex-recip-bw": {
"tt_op": ttnn_ops.complex_recip_bw,
"pytorch_op": pytorch_ops.complex_recip_bw,
},
"complex-mul-bw": {
"tt_op": ttnn_ops.complex_mul_bw,
"pytorch_op": pytorch_ops.complex_mul_bw,
},
"complex-add-bw": {
"tt_op": ttnn_ops.complex_add_bw,
"pytorch_op": pytorch_ops.complex_add_bw,
},
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
test-list:
- complex-polar-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-shapes: 2
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_polar_bw_sweep.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- complex-add-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 3
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc_list
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16", "BFLOAT8_B"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_add_sweep.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- complex-mul-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 3
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc_list
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16", "BFLOAT8_B"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_mul_sweep.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
test-list:
- complex-recip-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-shapes: 2
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_recip_sweep.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- complex-add-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 3
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc_list
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16", "BFLOAT8_B"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_add_sweep.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
---
test-list:
- complex-mul-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 3
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc_list
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16", "BFLOAT8_B"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_mul_sweep.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
test-list:
- complex-polar-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-shapes: 2
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_polar_bw_sweep.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
test-list:
- complex-recip-bw:
shape:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-shapes: 2
num-samples: 128
args-sampling-strategy: "all"
env:
# TT_PCI_DMA_BUF_SIZE: "1048576"
datagen:
function: gen_rand_complex
args:
low: -100
high: 100
comparison:
function: comp_pcc
args-gen: gen_dtype_layout_device
args:
data-layout: ["TILE"]
data-type: ["BFLOAT16"]
buffer-type: ["DRAM", "L1"]
out-buffer-type: ["DRAM", "L1"]
output-file: complex_bw_recip_sweep.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test-list:
start-shape: [1, 1, 32, 64]
end-shape: [6, 12, 256, 256]
interval: [1, 1, 32, 64]
num-dims: [2, 3, 4]
num-shapes: 1
num-samples: 128
args-sampling-strategy: "all"
Expand Down
Loading

0 comments on commit 18dd2d1

Please sign in to comment.