From f49299c58c1357f12caeb285b5a040789b398e6c Mon Sep 17 00:00:00 2001
From: Shwetank Singh
Date: Sun, 15 Dec 2024 07:59:40 +0000
Subject: [PATCH] #0: testing

---
 .../demos/convnet_mnist/tests/test_performance.py |  2 +-
 .../conv/conv2d/prepare_conv2d_weights.cpp        | 14 ++++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/models/demos/convnet_mnist/tests/test_performance.py b/models/demos/convnet_mnist/tests/test_performance.py
index e0aa4b052ab2..d6ccf0717d72 100644
--- a/models/demos/convnet_mnist/tests/test_performance.py
+++ b/models/demos/convnet_mnist/tests/test_performance.py
@@ -119,7 +119,7 @@ def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
     subdir = "ttnn_convnet_mnist"
     num_iterations = 1
     margin = 0.03
-    expected_perf = 1753.5 if is_grayskull() else 2705.5
+    expected_perf = 1800 if is_grayskull() else 2800.5
 
     command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py"
     cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 7641ad3b4a06..d22ff4a6a513 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -30,11 +30,11 @@ void validate_bias_tensor(const ttnn::Tensor& bias_tensor) {
 
 void validate_weights_format(const std::string& weights_format) {
     TT_FATAL(weights_format.size() == 4, "weights_format must have exactly 4 characters");
-    TT_ASSERT(weights_format.find("O") != string::npos, "weights_format must contain \"O\"");
-    TT_ASSERT(weights_format.find("I") != string::npos, "weights_format must contain \"I\"");
-    TT_ASSERT(weights_format.find("H") != string::npos, "weights_format must contain \"H\"");
-    TT_ASSERT(weights_format.find("W") != string::npos, "weights_format must contain \"W\"");
-    TT_ASSERT(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\"");
+    TT_FATAL(weights_format.find("O") != string::npos, "weights_format must contain \"O\"");
+    TT_FATAL(weights_format.find("I") != string::npos, "weights_format must contain \"I\"");
+    TT_FATAL(weights_format.find("H") != string::npos, "weights_format must contain \"H\"");
+    TT_FATAL(weights_format.find("W") != string::npos, "weights_format must contain \"W\"");
+    TT_FATAL(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\"");
 }
 
 template <typename T>
@@ -43,9 +43,7 @@ bool check_non_tile_mul_width(
     const Conv2dConfig& conv_config,
     const uint32_t in_channels
 ){
-    ShardOrientation shard_orientation =
-        conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
-    auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x;
+    auto num_cores_c = conv_config.transpose_shards ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x;
     auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2;
     bool is_non_tile_mul_width = (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 &&
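
Reviewer note (not part of the patch): a minimal standalone sketch of why the
check_non_tile_mul_width hunk is behavior-preserving. The old code routed
conv_config.transpose_shards through an intermediate ShardOrientation before
picking a grid dimension; the new code branches on transpose_shards directly.
CoreCoord, ShardOrientation, and the grid values below are hypothetical
stand-ins for the types returned by device->compute_with_storage_grid_size().

    #include <cassert>
    #include <cstdint>
    #include <initializer_list>

    // Hypothetical stand-ins for the device grid types referenced in the patch.
    struct CoreCoord { uint32_t x; uint32_t y; };
    enum class ShardOrientation { ROW_MAJOR, COL_MAJOR };

    int main() {
        const CoreCoord grid{8, 7};  // e.g. an 8x7 compute-with-storage grid

        for (bool transpose_shards : {false, true}) {
            // Old form: derive an intermediate ShardOrientation, then branch on it.
            const ShardOrientation orientation =
                transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
            const uint32_t old_num_cores_c =
                orientation == ShardOrientation::COL_MAJOR ? grid.y : grid.x;

            // New form: branch on transpose_shards directly.
            const uint32_t new_num_cores_c = transpose_shards ? grid.y : grid.x;

            // Both forms select the same grid dimension for every input.
            assert(old_num_cores_c == new_num_cores_c);
        }
        return 0;
    }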