diff --git a/tt-train/tests/model/gpt2s_test.cpp b/tt-train/tests/model/gpt2s_test.cpp index 5791f4628c31..3d1d91101ede 100644 --- a/tt-train/tests/model/gpt2s_test.cpp +++ b/tt-train/tests/model/gpt2s_test.cpp @@ -34,27 +34,27 @@ TEST(GPT2SBatch64Test, Matmul) { {{{64, 12, 1024, 1024}, {64, 12, 1024, 64}, false, false}, ExpectedResult::OK}, {{{768, 65536}, {65536, 96}, false, false}, ExpectedResult::OK}, {{{65536, 768}, {65536, 96}, true, false}, ExpectedResult::OK}, - {{{65536, 96}, {1, 1, 96, 768}, false, false}, ExpectedResult::ERROR}, - {{{65536, 96}, {1, 1, 768, 96}, false, true}, ExpectedResult::ERROR}, + {{{65536, 96}, {1, 1, 96, 768}, false, false}, ExpectedResult::OK}, + {{{65536, 96}, {1, 1, 768, 96}, false, true}, ExpectedResult::OK}, {{{3072, 65536}, {65536, 768}, false, false}, ExpectedResult::OK}, {{{65536, 3072}, {65536, 768}, true, false}, ExpectedResult::OK}, - {{{65536, 768}, {1, 1, 768, 3072}, false, false}, ExpectedResult::ERROR}, - {{{65536, 768}, {1, 1, 3072, 768}, false, true}, ExpectedResult::ERROR}, + {{{65536, 768}, {1, 1, 768, 3072}, false, false}, ExpectedResult::OK}, + {{{65536, 768}, {1, 1, 3072, 768}, false, true}, ExpectedResult::OK}, {{{768, 65536}, {65536, 3072}, false, false}, ExpectedResult::OK}, {{{65536, 768}, {65536, 3072}, true, false}, ExpectedResult::OK}, - {{{65536, 3072}, {1, 1, 3072, 768}, false, false}, ExpectedResult::ERROR}, - {{{65536, 3072}, {1, 1, 768, 3072}, false, true}, ExpectedResult::ERROR}, - {{{65536, 3072}, {3072, 768}, false, false}, ExpectedResult::ERROR}, - {{{65536, 3072}, {768, 3072}, false, true}, ExpectedResult::ERROR}, + {{{65536, 3072}, {1, 1, 3072, 768}, false, false}, ExpectedResult::OK}, + {{{65536, 3072}, {1, 1, 768, 3072}, false, true}, ExpectedResult::OK}, + {{{65536, 3072}, {3072, 768}, false, false}, ExpectedResult::OK}, + {{{65536, 3072}, {768, 3072}, false, true}, ExpectedResult::OK}, {{{768, 65536}, {65536, 768}, false, false}, ExpectedResult::OK}, {{{65536, 768}, {65536, 768}, true, false}, ExpectedResult::OK}, - {{{65536, 768}, {1, 1, 768, 768}, false, false}, ExpectedResult::ERROR}, - {{{768, 65536}, {1, 1, 768, 768}, true, false}, ExpectedResult::ERROR}, + {{{65536, 768}, {1, 1, 768, 768}, false, false}, ExpectedResult::OK}, + {{{768, 65536}, {1, 1, 768, 768}, true, false}, ExpectedResult::OK}, {{{768, 65536}, {65536, 2304}, false, false}, ExpectedResult::OK}, {{{65536, 768}, {65536, 2304}, true, false}, ExpectedResult::OK}, - {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::ERROR}, - {{{65536, 768}, {50304, 768}, false, true}, ExpectedResult::ERROR}, - {{{65536, 50304}, {50304, 768}, false, false}, ExpectedResult::ERROR}, + {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::OK}, + {{{65536, 768}, {50304, 768}, false, true}, ExpectedResult::OK}, + {{{65536, 50304}, {50304, 768}, false, false}, ExpectedResult::OK}, }; auto run_matmul = [](auto& a, auto& b, bool transpose_a, bool transpose_b) { diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index e232f73b4458..fa2df7b9454b 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -186,7 +186,7 @@ inline uint32_t get_estimated_size_of_cbs( // out CB: per_core_M * per_core_N // Ignore optional intermediate CB because not needed when need to create a program config. uint32_t in0_size = per_core_M * in0_block_w * 2 * in0_single_tile_size; - uint32_t in1_size = per_core_M * in0_block_w * 2 * in1_single_tile_size; + uint32_t in1_size = per_core_N * in0_block_w * 2 * in1_single_tile_size; uint32_t out_size = per_core_M * per_core_N * output_single_tile_size; uint32_t interm_size = per_core_M * per_core_N * interm_single_tile_size; return in0_size + in1_size + out_size + interm_size; @@ -279,8 +279,8 @@ inline std::vector get_multi_dim_per_core_factor( return {per_core_M, per_core_N, in0_block_w}; } - std::vector m_factors; - std::vector n_factors; + std::vector m_factors = {per_core_M, 1}; + std::vector n_factors = {per_core_N, 1}; for (uint32_t per_core_factor_m = per_core_M / 2; per_core_factor_m > 1; per_core_factor_m--) { if (per_core_M % per_core_factor_m == 0) { m_factors.push_back(per_core_factor_m); @@ -288,7 +288,7 @@ inline std::vector get_multi_dim_per_core_factor( } for (uint32_t per_core_factor_n = per_core_N / 2; per_core_factor_n > 1; per_core_factor_n--) { if (per_core_N % per_core_factor_n == 0) { - m_factors.push_back(per_core_factor_n); + n_factors.push_back(per_core_factor_n); } } // Insert into ordered map, over write entry if new one is closer to a square (smallest ratio closest to 1). @@ -325,7 +325,7 @@ inline std::vector get_multi_dim_per_core_factor( uint32_t per_core_factor_m = std::get<0>(it->second); uint32_t per_core_factor_n = std::get<1>(it->second); - uint32_t size = get_estimated_size_of_cbs( + size = get_estimated_size_of_cbs( per_core_factor_m, per_core_factor_n, per_core_factor_k,