#8865: Port changed ttnn ops to dispatch profiling infra (#10772)
* #8865: Port changed ttnn ops to dispatch profiling infra
* #8865: Update reference times for dispatch profiling infra
nemanjagrujic authored Jul 31, 2024
1 parent d8ba4eb commit 0e021e2
Showing 3 changed files with 116 additions and 94 deletions.
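
The diff below mostly renames ops from the tt_lib namespace to ttnn, switches positional arguments to keyword arguments, and adds broadcast variants backed by dedicated shape functions. For readers unfamiliar with the profiling infra, here is a minimal, self-contained sketch of the registration pattern these entries follow. The run_profiled_op harness and the operator.mul stand-in are hypothetical; only the entry keys ("op", "name", "shape_func", "num_repeats") and bcast_h_shape_func_1 are taken from the diff itself.

# Minimal sketch of the op-registration pattern used by ops_for_profiling.py.
# Assumption: the harness function run_profiled_op and the operator.mul
# stand-in are hypothetical; only the dict keys and bcast_h_shape_func_1
# mirror the entries visible in this diff.
import operator


def bcast_h_shape_func_1(input_shape):
    # Second operand broadcasts over H: [N, C, 1, W] paired with [N, C, H, W].
    return input_shape, [input_shape[-4], input_shape[-3], 1, input_shape[-1]]


all_binary_ops = [
    {"op": operator.mul, "name": "ttnn.mul"},
    {"op": operator.mul, "name": "ttnn.mul_bcast_h", "shape_func": bcast_h_shape_func_1},
]


def run_profiled_op(entry, input_shape):
    # Default: both operands share the input shape unless a shape_func overrides it.
    shape_a, shape_b = entry.get("shape_func", lambda s: (s, s))(input_shape)
    repeats = entry.get("num_repeats", 1)
    print(f"{entry['name']}: a={shape_a} b={shape_b} repeats={repeats}")


for entry in all_binary_ops:
    run_profiled_op(entry, [1, 1, 32, 32])
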
89 changes: 54 additions & 35 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -82,6 +82,16 @@ def bcast_hw_shape_func_11(input_shape):
return input_shape, input_shape_1


def bcast_h_shape_func_1(input_shape):
input_shape_1 = [input_shape[-4], input_shape[-3], 1, input_shape[-1]]
return input_shape, input_shape_1


def bcast_w_shape_func_1(input_shape):
input_shape_1 = [input_shape[-4], input_shape[-3], input_shape[-2], 1]
return input_shape, input_shape_1


def complex_add(x, y):
tt_lib.tensor.complex_add(
x, y, tt_lib.tensor.MemoryConfig(tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, tt_lib.tensor.BufferType.DRAM)
@@ -151,15 +161,15 @@ def unary_pow_bw(x, y):


def clamp_bw(x, y):
ttnn.clamp_bw(x, y, 0.1, 0.9)
ttnn.clamp_bw(x, y, min=0.1, max=0.9)


def clamp_min_bw(x, y):
ttnn.clamp_min_bw(x, y, 0.1)
ttnn.clamp_bw(x, y, min=0.1)


def clamp_max_bw(x, y):
ttnn.clamp_max_bw(x, y, 0.9)
ttnn.clamp_bw(x, y, max=0.9)


def gelu_bw_none(x, y):
@@ -207,7 +217,7 @@ def unary_eq_bw(x, y):


def logiteps_bw(x, y):
ttnn.logiteps_bw(x, y, 0.0001)
ttnn.logiteps_bw(x, y, eps=0.0001)


def fmod_bw(x, y):
@@ -418,23 +428,23 @@ def angle_bw(x, y):


def celu_bw(x, y):
ttnn.celu_bw(x, y, 1)
ttnn.celu_bw(x, y, alpha=1)


def hardshrink_bw(x, y):
ttnn.hardshrink_bw(x, y, 0.5)
ttnn.hardshrink_bw(x, y, lambd=0.5)


def leaky_relu_bw(x, y):
ttnn.leaky_relu_bw(x, y, 0.3)
ttnn.leaky_relu_bw(x, y, negative_slope=0.3)


def softshrink_bw(x, y):
ttnn.softshrink_bw(x, y, 0.5)
ttnn.softshrink_bw(x, y, lambd=0.5)


def unary_div_bw(x, y):
ttnn.div_bw(x, y, 3, round_mode="None")
ttnn.div_bw(x, y, 3.0, round_mode="None")


all_binary_ops = [
@@ -450,6 +460,16 @@ def unary_div_bw(x, y):
"op": ttnn.mul,
"name": "ttnn.mul",
},
{
"op": ttnn.mul,
"name": "ttnn.mul_bcast_h",
"shape_func": bcast_h_shape_func_1,
},
{
"op": ttnn.mul,
"name": "ttnn.mul_bcast_w",
"shape_func": bcast_w_shape_func_1,
},
{
"op": ttnn.mul,
"name": "ttnn.mul_bcast_hw",
@@ -646,7 +666,7 @@ def unary_div_bw(x, y):
},
{
"op": ttnn.embedding,
"name": "tt_lib.tensor.embeddings",
"name": "ttnn.embedding",
"layout": "ROW_MAJOR",
"shape_func": embeddings_shape_func,
},
@@ -1170,11 +1190,11 @@ def leaky_relu(x):


def softshrink(x):
ttnn.softshrink(x, 70)
ttnn.softshrink(x, lambd=70)


def hardshrink(x):
ttnn.hardshrink(x, 1)
ttnn.hardshrink(x, lambd=1)


def elu(x):
@@ -1194,7 +1214,7 @@ def bias_gelu_unary(x):


def logit(x):
ttnn.logit(x, 0.0001)
ttnn.logit(x, eps=0.0001)


def logical_andi(x):
@@ -1309,14 +1329,6 @@ def empty(x):
ttnn.empty(shape=x.get_legacy_shape(), dtype=x.get_dtype(), layout=x.get_layout(), device=x.device())


def tril(x):
ttnn.tril(x, 1)


def triu(x):
ttnn.triu(x, 1)


def sum_dim_2(x):
ttnn.sum(x, dim=2)

@@ -1951,6 +1963,7 @@ def clone(x):
{
"op": tilize,
"name": "ttnn.tilize",
"layout": "ROW_MAJOR",
},
{
"op": tt_lib.tensor.untilize,
@@ -1968,6 +1981,7 @@ def clone(x):
{
"op": ttnn.tilize_with_zero_padding,
"name": "ttnn.tilize_with_zero_padding",
"layout": "ROW_MAJOR",
},
{
"op": pad,
@@ -2022,12 +2036,12 @@ def clone(x):
"name": "ttnn.empty",
},
{
"op": tril,
"op": ttnn.tril,
"name": "ttnn.tril",
"num_repeats": 3,
},
{
"op": triu,
"op": ttnn.triu,
"name": "ttnn.triu",
"num_repeats": 3,
},
@@ -2181,16 +2195,16 @@ def clone(x):
"name": "tt_lib.tensor.mean_hw",
},
{
"op": tt_lib.tensor.var_hw,
"name": "tt_lib.tensor.var_hw",
"op": ttnn.var_hw,
"name": "ttnn.var_hw",
},
{
"op": logical_noti,
"name": "tt_lib.tensor.logical_noti",
},
{
"op": tt_lib.tensor.std_hw,
"name": "tt_lib.tensor.std_hw",
"op": ttnn.std_hw,
"name": "ttnn.std_hw",
},
{
"op": ttnn.normalize_hw,
@@ -2534,18 +2548,23 @@ def div_bw(x, y, z):
ttnn.div_bw(x, y, z, round_mode="None")


def add_bw(x, y, z):
ttnn.add_bw(x, y, z)


def primary_moreh_norm_backward(x, y, z):
tt_lib.operations.primary.moreh_norm_backward(x, y, z, p=2.0)


def fused_linear(x, weight, bias):
def linear(x, weight, bias):
ttnn.linear(x, weight, bias=bias)


def fused_linear_shape_func(input_shape):
x_shape = [1, 1, input_shape[-2], input_shape[-1]]
weight_shape = [1, 1, input_shape[-2], input_shape[-1]]
bias_shape = [1, 1, 32, input_shape[-1]]
def linear_shape_func(input_shape):
N = input_shape[-1]
x_shape = [1, input_shape[-2], N]
weight_shape = [N, N]
bias_shape = [1, N]
return x_shape, weight_shape, bias_shape
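
The reworked shape function above replaces the old padded 4-D shapes with a 3-D activation, a square [N, N] weight, and a [1, N] bias. As a quick, ttnn-free sanity check of that contract, the numpy snippet below replays the shape function and confirms the operands compose; the y = x @ W + b convention is an assumption for illustration, not something stated in the diff.

# Shape-only sanity check (assumption: the activation's last dimension
# multiplies the weight and the bias broadcasts over rows, y = x @ W + b).
# This sketch never calls ttnn; it only validates the returned shapes.
import numpy as np


def linear_shape_func(input_shape):
    N = input_shape[-1]
    x_shape = [1, input_shape[-2], N]
    weight_shape = [N, N]
    bias_shape = [1, N]
    return x_shape, weight_shape, bias_shape


x_shape, weight_shape, bias_shape = linear_shape_func([1, 1, 32, 64])
x = np.zeros(x_shape)
w = np.zeros(weight_shape)
b = np.zeros(bias_shape)
y = x @ w + b  # [1, 32, 64] @ [64, 64] + [1, 64] -> [1, 32, 64]
assert y.shape == (1, 32, 64)
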


@@ -2634,7 +2653,7 @@ def fused_linear_shape_func(input_shape):
"name": "ttnn.min_bw",
},
{
"op": ttnn.add_bw,
"op": add_bw,
"name": "ttnn.add_bw",
},
# {
@@ -2726,9 +2745,9 @@ def fused_linear_shape_func(input_shape):
"name": "tt_lib.tensor.moreh_norm_backward",
},
{
"op": fused_linear,
"op": linear,
"name": "ttnn.linear",
"shape_func": fused_linear_shape_func,
"shape_func": linear_shape_func,
},
{
"op": ttnn.ge_bw,