#7712: Update ttnn docs for elu, erf-alike, log-alike ops
VirdhatchaniKN committed May 10, 2024
1 parent b3945f9 commit 607f564
Showing 9 changed files with 112 additions and 25 deletions.
17 changes: 14 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/elu.py
@@ -17,15 +17,26 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
"alpha": [-0.5, 0, 0.5],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
alpha,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT:
return True, "Not Supported"
return False, None


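For context, a sweep module's skip() is invoked once per parameter permutation before the op under test runs, and a (True, reason) result drops that permutation. A minimal sketch of how a runner might consume it (the run_sweep helper below is illustrative, not the actual sweep-runner API):

    import itertools

    def run_sweep(parameters, skip, run):
        # Cartesian product over every value list in the parameters dict.
        keys = list(parameters)
        for values in itertools.product(*(parameters[k] for k in keys)):
            permutation = dict(zip(keys, values))
            should_skip, reason = skip(**permutation)
            if should_skip:
                print(f"skip {permutation}: {reason}")
                continue
            run(**permutation)

With the new explicit signature, skip() receives the same keyword arguments as the test itself, so layout- and dtype-dependent exclusions like the ROW_MAJOR_LAYOUT one above stay in one place.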
16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/erf.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/erfc.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/erfinv.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT or input_dtype == ttnn.bfloat8_b:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/log10.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT or input_dtype == ttnn.bfloat8_b:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/log1p.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT or input_dtype == ttnn.bfloat8_b:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/log2.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT or input_dtype == ttnn.bfloat8_b:
return True, "Not Supported"
return False, None


16 changes: 13 additions & 3 deletions tests/ttnn/sweep_tests/sweeps/neg.py
@@ -16,14 +16,24 @@
"batch_sizes": [(1,)],
"height": [384, 1024],
"width": [1024, 4096],
"input_dtype": [ttnn.bfloat16],
"input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
"input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
"layout": [ttnn.TILE_LAYOUT],
"layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
}


def skip(**_) -> Tuple[bool, Optional[str]]:
def skip(
batch_sizes,
height,
width,
input_dtype,
input_memory_config,
output_memory_config,
layout,
) -> Tuple[bool, Optional[str]]:
if layout == ttnn.ROW_MAJOR_LAYOUT:
return True, "Not Supported"
return False, None


8 changes: 7 additions & 1 deletion ttnn/ttnn/operations/math.py
@@ -53,11 +53,15 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
         return torch_function(input_tensor)
 
     def _math_op_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
+        if operation_name in ["ttnn.erfinv", "ttnn.log2", "ttnn.log10", "ttnn.log1p"]:
+            supported_dtypes = (ttnn.bfloat16,)
+        else:
+            supported_dtypes = (ttnn.bfloat16, ttnn.bfloat8_b)
         ttnn.validate_input_tensor(
             operation_name,
             input_tensor,
             ranks=(2, 3, 4),
-            dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
+            dtypes=supported_dtypes,
             layouts=(ttnn.TILE_LAYOUT,),
             can_be_on_device=True,
             can_be_on_cpu=False,
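In user terms, the branch above makes ttnn.erfinv, ttnn.log2, ttnn.log10, and ttnn.log1p validate as bfloat16-only, while the remaining math ops keep accepting bfloat8_b. A hedged usage sketch (the exact exception raised by ttnn.validate_input_tensor may differ):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)
    x = ttnn.from_torch(
        torch.rand(32, 32, dtype=torch.bfloat16),
        dtype=ttnn.bfloat8_b,  # block-float input
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )
    ttnn.erfc(x)   # still accepted: erfc keeps bfloat8_b support
    ttnn.log2(x)   # now rejected by validation: log2 is bfloat16-only
    ttnn.close_device(device)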
@@ -102,6 +106,8 @@ def math_op_function(
             >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
             >>> output = ttnn.{(name)}(tensor)
+        {math_op_function.__doc__}
+        """
     setattr(THIS_MODULE, name, math_op_function)
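The setattr call is the dynamic-registration pattern used throughout this module: each generated wrapper is bound onto the module under the op's name, and the newly appended __doc__ line folds the wrapper's own documentation into the generated docstring. A stripped-down sketch of the same idea, with illustrative names only (register_math_op is not the actual ttnn internal):

    import sys

    THIS_MODULE = sys.modules[__name__]

    def register_math_op(name, torch_function):
        def math_op_function(input_tensor):
            return torch_function(input_tensor)
        math_op_function.__name__ = name
        # Compose the docstring from a template plus per-op details.
        math_op_function.__doc__ = f"{name}(input_tensor) -> Tensor\n\nApplies {name} elementwise."
        setattr(THIS_MODULE, name, math_op_function)  # exposes e.g. module.erf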

