diff --git a/docs/source/ttnn/ttnn/ttnn/logical_not.rst b/docs/source/ttnn/ttnn/ttnn/logical_not.rst
index de46c16d9dd..af05192bba8 100644
--- a/docs/source/ttnn/ttnn/ttnn/logical_not.rst
+++ b/docs/source/ttnn/ttnn/ttnn/logical_not.rst
@@ -1,6 +1,6 @@
 .. _ttnn.logical_not:
 
 ttnn.logical_not
-######################
-logical_not
+#################
+
 .. autofunction:: ttnn.logical_not
diff --git a/tests/ttnn/sweep_tests/sweeps/logical_not.py b/tests/ttnn/sweep_tests/sweeps/logical_not.py
index 6daf6c2f432..f60f6cc2c3e 100644
--- a/tests/ttnn/sweep_tests/sweeps/logical_not.py
+++ b/tests/ttnn/sweep_tests/sweeps/logical_not.py
@@ -16,14 +16,24 @@
     "batch_sizes": [(1,)],
     "height": [384, 1024],
     "width": [1024, 4096],
-    "input_dtype": [ttnn.bfloat16],
+    "input_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
     "input_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
     "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG],
-    "layout": [ttnn.TILE_LAYOUT],
+    "layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
 }
 
 
-def skip(**_) -> Tuple[bool, Optional[str]]:
+def skip(
+    batch_sizes,
+    height,
+    width,
+    input_dtype,
+    input_memory_config,
+    output_memory_config,
+    layout,
+) -> Tuple[bool, Optional[str]]:
+    if input_dtype == ttnn.bfloat8_b:
+        return True, "Not Supported"
     return False, None
 
 
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 875f1a57585..32ce0a3423a 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -42,11 +42,15 @@ def _golden_function(input_tensor: ttnn.Tensor, **_):
         return torch_function(input_tensor)
 
     def _unary_validate_input_tensors(operation_name, input_tensor, *args, **kwargs):
+        if operation_name in ["ttnn.logical_not"]:
+            supported_dtypes = (ttnn.bfloat16,)
+        else:
+            supported_dtypes = (ttnn.bfloat16, ttnn.bfloat8_b)
         ttnn.validate_input_tensor(
             operation_name,
             input_tensor,
             ranks=(2, 3, 4),
-            dtypes=(ttnn.bfloat16, ttnn.bfloat8_b),
+            dtypes=supported_dtypes,
             layouts=(ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT),
             can_be_on_device=True,
             can_be_on_cpu=False,
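The net effect of the three hunks above is a per-operation dtype gate: ttnn.logical_not validates against bfloat16 only, while the sweep adds bfloat8_b and ROW_MAJOR_LAYOUT to the parameter grid and skips the bfloat8_b combinations with a reason string. The following is a minimal, self-contained sketch of that gating pattern in plain Python (no device or ttnn import required); the names OP_DTYPE_OVERRIDES, supported_dtypes_for, and should_skip, and the string stand-ins for the ttnn dtypes, are illustrative assumptions and not part of the ttnn API.

# Hypothetical sketch of the dtype-gating pattern shown in the diff above.
# String constants stand in for ttnn.bfloat16 / ttnn.bfloat8_b so the example
# runs without ttnn installed.
from typing import Optional, Tuple

BFLOAT16 = "bfloat16"
BFLOAT8_B = "bfloat8_b"

# Per-operation override table; operations not listed here accept both dtypes.
OP_DTYPE_OVERRIDES = {
    "ttnn.logical_not": (BFLOAT16,),
}
DEFAULT_DTYPES = (BFLOAT16, BFLOAT8_B)


def supported_dtypes_for(operation_name: str) -> Tuple[str, ...]:
    # Mirrors the branch added to _unary_validate_input_tensors: look up a
    # narrower dtype set for listed ops, otherwise fall back to the default.
    return OP_DTYPE_OVERRIDES.get(operation_name, DEFAULT_DTYPES)


def should_skip(operation_name: str, input_dtype: str) -> Tuple[bool, Optional[str]]:
    # Mirrors the sweep-test skip(): unsupported dtypes are skipped, not failed.
    if input_dtype not in supported_dtypes_for(operation_name):
        return True, "Not Supported"
    return False, None


if __name__ == "__main__":
    print(should_skip("ttnn.logical_not", BFLOAT8_B))  # (True, 'Not Supported')
    print(should_skip("ttnn.logical_not", BFLOAT16))   # (False, None)

Keeping the exception list in one place, as this sketch does with a table and the diff does with an explicit operation-name check, makes it straightforward to widen or narrow bfloat8_b coverage per operation without touching the shared validation call.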