diff --git a/tests/ttnn/sweep_tests/sweeps/mul.py b/tests/ttnn/sweep_tests/sweeps/mul.py index 7a68c4ead69..184d9ece529 100644 --- a/tests/ttnn/sweep_tests/sweeps/mul.py +++ b/tests/ttnn/sweep_tests/sweeps/mul.py @@ -17,8 +17,8 @@ "height": [384, 1024], "width": [1024, 4096], "broadcast": [None, "h", "w", "hw"], - "input_a_dtype": [ttnn.bfloat16], - "input_b_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], "input_a_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], "input_b_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG], @@ -27,9 +27,13 @@ } -def skip(*, broadcast, input_b_layout, **_) -> Tuple[bool, Optional[str]]: +def skip(*, broadcast, input_a_layout, input_b_layout, input_a_dtype, input_b_dtype, **_) -> Tuple[bool, Optional[str]]: if broadcast in {"w", "hw"} and input_b_layout == ttnn.ROW_MAJOR_LAYOUT: return True, "Broadcasting along width is not supported for row major layout" + if input_a_layout == ttnn.ROW_MAJOR_LAYOUT or input_b_layout == ttnn.ROW_MAJOR_LAYOUT: + return True, "Row major layout not supported" + if input_a_dtype != input_b_dtype: + return True, "Input tensors should be of same dtype" return False, None diff --git a/tests/ttnn/sweep_tests/sweeps/sub.py b/tests/ttnn/sweep_tests/sweeps/sub.py index 7a68c4ead69..184d9ece529 100644 --- a/tests/ttnn/sweep_tests/sweeps/sub.py +++ b/tests/ttnn/sweep_tests/sweeps/sub.py @@ -17,8 +17,8 @@ "height": [384, 1024], "width": [1024, 4096], "broadcast": [None, "h", "w", "hw"], - "input_a_dtype": [ttnn.bfloat16], - "input_b_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], "input_a_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], "input_b_layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG], @@ -27,9 +27,13 @@ } -def skip(*, broadcast, input_b_layout, **_) -> Tuple[bool, Optional[str]]: +def skip(*, broadcast, input_a_layout, input_b_layout, input_a_dtype, input_b_dtype, **_) -> Tuple[bool, Optional[str]]: if broadcast in {"w", "hw"} and input_b_layout == ttnn.ROW_MAJOR_LAYOUT: return True, "Broadcasting along width is not supported for row major layout" + if input_a_layout == ttnn.ROW_MAJOR_LAYOUT or input_b_layout == ttnn.ROW_MAJOR_LAYOUT: + return True, "Row major layout not supported" + if input_a_dtype != input_b_dtype: + return True, "Input tensors should be of same dtype" return False, None