Skip to content

Commit

Permalink
#0: testing
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Dec 15, 2024
1 parent 46a3d00 commit f49299c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion models/demos/convnet_mnist/tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf):
subdir = "ttnn_convnet_mnist"
num_iterations = 1
margin = 0.03
expected_perf = 1753.5 if is_grayskull() else 2705.5
expected_perf = 1800 if is_grayskull() else 2800.5

command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
Expand Down
14 changes: 6 additions & 8 deletions ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ void validate_bias_tensor(const ttnn::Tensor& bias_tensor) {

void validate_weights_format(const std::string& weights_format) {
TT_FATAL(weights_format.size() == 4, "weights_format must have exactly 4 characters");
TT_ASSERT(weights_format.find("O") != string::npos, "weights_format must contain \"O\"");
TT_ASSERT(weights_format.find("I") != string::npos, "weights_format must contain \"I\"");
TT_ASSERT(weights_format.find("H") != string::npos, "weights_format must contain \"H\"");
TT_ASSERT(weights_format.find("W") != string::npos, "weights_format must contain \"W\"");
TT_ASSERT(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\"");
TT_FATAL(weights_format.find("O") != string::npos, "weights_format must contain \"O\"");
TT_FATAL(weights_format.find("I") != string::npos, "weights_format must contain \"I\"");
TT_FATAL(weights_format.find("H") != string::npos, "weights_format must contain \"H\"");
TT_FATAL(weights_format.find("W") != string::npos, "weights_format must contain \"W\"");
TT_FATAL(weights_format == "OIHW", "Conv2d weights format must be \"OIHW\"");
}

template <typename T>
Expand All @@ -43,9 +43,7 @@ bool check_non_tile_mul_width(
const Conv2dConfig& conv_config,
const uint32_t in_channels
){
ShardOrientation shard_orientation =
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
auto num_cores_c = shard_orientation == ShardOrientation::COL_MAJOR ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x;
auto num_cores_c = conv_config.transpose_shards ? device->compute_with_storage_grid_size().y : device->compute_with_storage_grid_size().x;
auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 1 : 2;
bool is_non_tile_mul_width =
(conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && conv_config.act_block_h_override == 0 &&
Expand Down

0 comments on commit f49299c

Please sign in to comment.