diff --git a/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py b/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py
index c4989099cc48..bd14fd7f5047 100644
--- a/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py
+++ b/tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py
@@ -14,6 +14,11 @@
 from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import TILE_HEIGHT, TILE_WIDTH

+torch.set_printoptions(
+    profile="full",
+    sci_mode=False,  # Disable scientific notation
+)
+

 def to_cpu(npu_tensor, shape, *, cpu_layout=ttnn.ROW_MAJOR_LAYOUT):
     if npu_tensor is None:
@@ -35,16 +40,26 @@ def to_npu(
     return ttnn.from_torch(cpu_tensor, npu_dtype, device=device, layout=npu_layout)


-@pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding")
-@pytest.mark.parametrize("num_iters_of_each_case", [2])
+# @pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding")
+@pytest.mark.parametrize("num_iters_of_each_case", [1])
 @pytest.mark.parametrize("range_of_padding", [(0, 21, 10)])  # [0, 10, 20]
 @pytest.mark.parametrize("range_of_n", [(1, 4)])
 @pytest.mark.parametrize("range_of_c", [(1, 4)])
 @pytest.mark.parametrize("range_of_ht", [(1, 4)])
 @pytest.mark.parametrize("range_of_wt", [(1, 4)])
-@pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
-@pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
-@pytest.mark.parametrize("num_parameters", [32, 128])
+@pytest.mark.parametrize("max_norm", [2.0])
+@pytest.mark.parametrize("norm_type", [2.0])
+@pytest.mark.parametrize("num_parameters", [1])
+
+# @pytest.mark.parametrize("num_iters_of_each_case", [2])
+# @pytest.mark.parametrize("range_of_padding", [(0, 21, 10)])  # [0, 10, 20]
+# @pytest.mark.parametrize("range_of_n", [(1, 4)])
+# @pytest.mark.parametrize("range_of_c", [(1, 4)])
+# @pytest.mark.parametrize("range_of_ht", [(1, 4)])
+# @pytest.mark.parametrize("range_of_wt", [(1, 4)])
+# @pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
+# @pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
+# @pytest.mark.parametrize("num_parameters", [32, 128])
 def test_moreh_clip_grad_norm(
     num_iters_of_each_case,
     num_parameters,
@@ -60,7 +75,7 @@ def test_moreh_clip_grad_norm(
     torch.manual_seed(2023)
     random.seed(2023)

-    cpu_dtype = torch.float32
+    cpu_dtype = torch.bfloat16
     npu_dtype = ttnn.bfloat16

     cpu_inputs = []
@@ -77,14 +92,21 @@ def test_moreh_clip_grad_norm(
         padding_w = random.randrange(*range_of_padding)

         input_shape = (
-            n,
-            c,
-            ht * TILE_HEIGHT - padding_h,
-            wt * TILE_WIDTH - padding_w,
+            1,
+            1,
+            1 * TILE_HEIGHT,
+            1 * TILE_WIDTH,
         )
+        # input_shape = (
+        #     n,
+        #     c,
+        #     ht * TILE_HEIGHT - padding_h,
+        #     wt * TILE_WIDTH - padding_w,
+        # )

         param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype))
-        grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
+        grad = torch.ones(input_shape, dtype=cpu_dtype) / 10
+        # grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
         param.grad = grad
         cpu_inputs.append(param)
@@ -93,25 +115,30 @@
     cpu_total_norm = torch.nn.utils.clip_grad_norm_(cpu_inputs, max_norm, norm_type)
     npu_total_norm = ttnn.operations.moreh.clip_grad_norm(npu_inputs, max_norm, norm_type)
-
-    expected_total_norm = cpu_total_norm
-    actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1])
-
-    rtol = atol = 0.1
-    # Check total_norm
-    pass_total_norm, out_total_norm = comp_allclose_and_pcc(
-        actual_total_norm.double(), expected_total_norm.double(), rtol=rtol, atol=atol
-    )
-    logger.debug(f"total_norm's {out_total_norm}")
-    assert pass_total_norm
-
-    # Check inputs
-    for i in range(num_parameters):
-        expected_input_i = cpu_inputs[i].grad.double()
-        actual_input_i = to_cpu(npu_inputs[i], input_shapes[i]).double()
-        pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
-        logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
-        assert pass_input_i
+    to_cpu_result = ttnn.to_torch(npu_total_norm)
+    to_cpu_result_reshape = to_cpu_result.reshape(-1, 32, 32)
+    print("cpu_total_norm", cpu_total_norm)
+    print("to_cpu_result_reshape.shape", to_cpu_result_reshape.shape)
+    print("to_cpu_result_reshape", to_cpu_result_reshape)
+
+    # expected_total_norm = cpu_total_norm
+    # actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1])
+
+    # rtol = atol = 0.1
+    # # Check total_norm
+    # pass_total_norm, out_total_norm = comp_allclose_and_pcc(
+    #     actual_total_norm, expected_total_norm, rtol=rtol, atol=atol
+    # )
+    # logger.debug(f"total_norm's {out_total_norm}")
+    # assert pass_total_norm
+
+    # # Check inputs
+    # for i in range(num_parameters):
+    #     expected_input_i = cpu_inputs[i].grad
+    #     actual_input_i = to_cpu(npu_inputs[i], input_shapes[i])
+    #     pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
+    #     logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
+    #     assert pass_input_i


 # @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
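For reference: with the debug configuration above (a single 1x1x32x32 bfloat16 gradient filled with 0.1, max_norm=2.0, norm_type=2.0), the values the disabled assertions used to check can be computed by hand. A minimal torch-only sketch of the expected reference values:

    import torch

    grad = torch.ones(1, 1, 32, 32, dtype=torch.bfloat16) / 10

    # total_norm = (sum |g|^2)^(1/2) = sqrt(1024 * 0.01) = 3.2,
    # up to bfloat16 rounding of 0.1 (stored as ~0.1001)
    total_norm = grad.double().norm(2.0)

    # clip_coef = max_norm / (total_norm + 1e-6) = 2.0 / 3.2 = 0.625 < 1.0,
    # so clipping applies and each grad element should come back as ~0.0625
    clip_coef = 2.0 / (total_norm + 1e-6)
    print(total_norm.item(), clip_coef.item())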
diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp
index 04bfb4735324..a607ce766ed1 100644
--- a/ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp
+++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp
@@ -1018,6 +1018,8 @@ ALWI void power_tile_to_cb(
     constexpr uint32_t onetile = 1;
     constexpr uint32_t dst0 = 0;

+    // DUONG: to check the value of dst before the log_tile API is called, comment out lines 1054-1055
+
     // x^p
     tile_regs_acquire();
     cb_wait_front(cb_x, onetile);
@@ -1053,49 +1055,59 @@ ALWI void power_tile_to_cb(
     log_tile(dst0);
     tile_regs_commit();

-    tile_regs_wait();
-    pack_tile_with_dt(dst0, cb_logx);
-    tile_regs_release();
+    // tile_regs_wait();
+    // pack_tile_with_dt(dst0, cb_logx);
+    // tile_regs_release();

     cb_pop_front(cb_x, onetile);
     cb_push_back(cb_logx, onetile);

-    // exp(log(x) * decimal)
-    tile_regs_acquire();
-    cb_wait_front(cb_logx, onetile);
-    cb_reserve_back(cb_exp_lxmd, onetile);
+    // // exp(log(x) * decimal)
+    // tile_regs_acquire();
+    // cb_wait_front(cb_logx, onetile);
+    // cb_reserve_back(cb_exp_lxmd, onetile);

-    mul_tiles_init_with_dt(cb_logx, cb_decimal);
-    mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);
+    // mul_tiles_init_with_dt(cb_logx, cb_decimal);
+    // mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);

-    exp_tile_init();
-    exp_tile(dst0);
-    tile_regs_commit();
+    // exp_tile_init();
+    // exp_tile(dst0);
+    // tile_regs_commit();

-    tile_regs_wait();
-    pack_tile_with_dt(dst0, cb_exp_lxmd);
-    tile_regs_release();
+    // tile_regs_wait();
+    // pack_tile_with_dt(dst0, cb_exp_lxmd);
+    // tile_regs_release();

-    cb_pop_front(cb_logx, onetile);
-    cb_push_back(cb_exp_lxmd, onetile);
+    // cb_pop_front(cb_logx, onetile);
+    // cb_push_back(cb_exp_lxmd, onetile);

-    // x^p * exp(log(x) * decimal)(==(x + decimal)^p)
-    tile_regs_acquire();
-    cb_wait_front(cb_xpow, onetile);
-    cb_wait_front(cb_exp_lxmd, onetile);
-    cb_reserve_back(cb_correct_xpow, onetile);
+    // // x^p * exp(log(x) * decimal)(==(x + decimal)^p)
+    // tile_regs_acquire();
+    // cb_wait_front(cb_xpow, onetile);
+    // cb_wait_front(cb_exp_lxmd, onetile);
+    // cb_reserve_back(cb_correct_xpow, onetile);

-    mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
-    mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
-    tile_regs_commit();
+    // mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
+    // mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
+    // tile_regs_commit();

+    // tile_regs_wait();
+    // pack_tile_with_dt(dst0, cb_correct_xpow);
+    // tile_regs_release();
+
+    // cb_pop_front(cb_xpow, onetile);
+    // cb_pop_front(cb_exp_lxmd, onetile);
+    // cb_push_back(cb_correct_xpow, onetile);
+
+    // TESTING START
+
+    cb_reserve_back(cb_correct_xpow, onetile);
     tile_regs_wait();
     pack_tile_with_dt(dst0, cb_correct_xpow);
     tile_regs_release();
-
-    cb_pop_front(cb_xpow, onetile);
-    cb_pop_front(cb_exp_lxmd, onetile);
     cb_push_back(cb_correct_xpow, onetile);
+
+    // TESTING END
 }

 ALWI void power_tile_with_abs_x_to_cb(
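A note on what was disabled above: power_tile_to_cb computes a non-integer power in two factors. cb_xpow holds x^p from the integer-exponent power op, and the commented-out stages fold in the fractional part as exp(log(x) * decimal), i.e. x^decimal, so the product in cb_correct_xpow is x^(p + decimal) (the original "(x + decimal)^p" comment is easiest to read as x^(p + decimal)). With the TESTING block in place, the raw log_tile result is packed into cb_correct_xpow instead, so downstream consumers now see a logarithm rather than the corrected power. A Python sketch of the math the unmodified function implements, assuming x > 0 and a non-negative exponent:

    import math

    # cb_xpow ~ x ** p (integer exponent); cb_exp_lxmd ~ exp(log(x) * decimal)
    def power_ref(x, p, decimal):
        return (x ** p) * math.exp(math.log(x) * decimal)

    # e.g. an exponent of 2.2 splits into p = 2, decimal = 0.2:
    assert abs(power_ref(1.5, 2, 0.2) - 1.5 ** 2.2) < 1e-9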
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp
index c45e48d828f2..882864cfd433 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm.cpp
@@ -90,31 +90,31 @@ Tensor MorehClipGradNorm::invoke(
             norm_type);
     }

-    // max_norm / (total_norm + 1e-6)
-    auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm));
-    // min(clip_coef, 1.0f)
-    Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
-    auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
-    scalar.deallocate();
-
-    // Run Step 3
-    // Inplace update inputs(inputs *= clip_coef_clamped)
-    uint32_t start_input_idx{0};
-    num_inputs = total_num_inputs;
-    for (uint32_t i = 0; i < num_iter; ++i) {
-        const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);
-
-        auto input_tensors = std::vector(
-            inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);
-
-        ttnn::prim::moreh_clip_grad_norm_step3(
-            input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);
-
-        if (i < (num_iter - 1)) {
-            start_input_idx += num_inputs_at_this_iter;
-            num_inputs -= num_inputs_at_this_iter;
-        }
-    }
+    // // max_norm / (total_norm + 1e-6)
+    // auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm));
+    // // min(clip_coef, 1.0f)
+    // Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
+    // auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
+    // scalar.deallocate();
+
+    // // Run Step 3
+    // // Inplace update inputs(inputs *= clip_coef_clamped)
+    // uint32_t start_input_idx{0};
+    // num_inputs = total_num_inputs;
+    // for (uint32_t i = 0; i < num_iter; ++i) {
+    //     const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);
+
+    //     auto input_tensors = std::vector(
+    //         inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);
+
+    //     ttnn::prim::moreh_clip_grad_norm_step3(
+    //         input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);
+
+    //     if (i < (num_iter - 1)) {
+    //         start_input_idx += num_inputs_at_this_iter;
+    //         num_inputs -= num_inputs_at_this_iter;
+    //     }
+    // }

     return output_total_norm;
 }
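With the whole block above disabled, invoke() returns the raw total norm and leaves the gradients untouched, matching the test above, which only prints npu_total_norm. One thing worth flagging before re-enabling it: the comment promises max_norm / (total_norm + 1e-6), but the code builds (total_norm + 1e-6) * (1 / max_norm), the reciprocal of that, while the Step 3 comment says inputs *= clip_coef_clamped; unless the step3 kernel divides rather than multiplies, the inversion deserves a second look. For reference, the torch semantics the op is meant to reproduce, as a plain-Python sketch:

    def clip_grads_ref(grads, total_norm, max_norm):
        # torch.nn.utils.clip_grad_norm_: scale all grads by min(max_norm / (total_norm + 1e-6), 1.0)
        clip_coef = max_norm / (total_norm + 1e-6)
        clip_coef_clamped = min(clip_coef, 1.0)  # only ever scales down
        return [g * clip_coef_clamped for g in grads]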
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp
index 61590ae399e5..dc0631e401fc 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp
@@ -58,9 +58,8 @@ void MAIN {
     for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
         // Comput cb_xabs and mask(optional)
         // |x|
-        ACQ();
+        tile_regs_acquire();
         cb_wait_front(cb_x, onetile);  // comes from the reader
-        cb_reserve_back(cb_xabs, onetile);

         copy_tile_init();
         copy_tile(cb_x, 0, dst0);
@@ -83,61 +82,70 @@ void MAIN {
         abs_tile_init();
         abs_tile(dst0);
+        cb_pop_front(cb_x, onetile);
+        tile_regs_commit();

+        tile_regs_wait();
+        cb_reserve_back(cb_xabs, onetile);
         pack_tile(dst0, cb_xabs);
-        cb_pop_front(cb_x, onetile);
         cb_push_back(cb_xabs, onetile);
-        REL();
+        tile_regs_release();

         // |x + decimal|^p
         power_tile_to_cb(cb_xabs, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_correct_xpow, p, p_is_negative);

         if (tile_idx == 0) {
-            ACQ();
+            tile_regs_acquire();
             cb_wait_front(cb_correct_xpow, onetile);
             cb_reserve_back(cb_xpowadd, onetile);

             copy_tile_init();
             copy_tile(cb_correct_xpow, 0, dst0);
+            tile_regs_commit();
+
+            tile_regs_wait();
             pack_tile(dst0, cb_xpowadd);

             cb_pop_front(cb_correct_xpow, onetile);
             cb_push_back(cb_xpowadd, onetile);
-            REL();
+            tile_regs_release();
         } else {
-            ACQ();
+            tile_regs_acquire();
             cb_wait_front(cb_correct_xpow, onetile);
             cb_wait_front(cb_xpowadd, onetile);
             cb_reserve_back(cb_xpowadd, onetile);

             add_tiles_init();
             add_tiles(cb_correct_xpow, cb_xpowadd, 0, 0, dst0);
+            tile_regs_commit();
+
+            tile_regs_wait();
             pack_tile(dst0, cb_xpowadd);

             cb_pop_front(cb_correct_xpow, onetile);
             cb_pop_front(cb_xpowadd, onetile);
             cb_push_back(cb_xpowadd, onetile);
-            REL();
+            tile_regs_release();
         }
     }

     // Compute cb_y
-    ACQ();
+    tile_regs_acquire();
     cb_wait_front(cb_xpowadd, onetile);
     cb_reserve_back(cb_y, onetile);

     reduce_init_delta();
     reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0);
     reduce_revert_delta();
+    tile_regs_commit();
+
+    tile_regs_wait();
     pack_tile(dst0, cb_y);

     cb_pop_front(cb_xpowadd, onetile);
     cb_push_back(cb_y, onetile);
-    REL();
+    tile_regs_release();

     cb_pop_front(cb_decimal, onetile);
     cb_pop_front(cb_one, onetile);
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/moreh_clip_grad_norm_step1_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/moreh_clip_grad_norm_step1_program_factory.cpp
index 10897c883305..ce4d7be33252 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/moreh_clip_grad_norm_step1_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/moreh_clip_grad_norm_step1_program_factory.cpp
@@ -51,6 +51,7 @@ MorehClipGradNormStep1Operation::ProgramFactory::create(
     }

     auto [p, decimal, p_is_negative] = get_p_decimal_p_is_negative(norm_type);
+    std::cout << "step1: p: " << p << " ;decimal: " << decimal << " ;p_is_negative: " << p_is_negative << std::endl;

     ////////////////////////////////////////////////////////////////////////////
     //                         Core Setup
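The step1 kernel changes above replace the deprecated ACQ()/REL() pair with the four-phase tile_regs_acquire/commit/wait/release protocol and, in the |x| block, move cb_reserve_back after the commit so CB space is reserved only once a result is ready to pack. Functionally, step1 still reduces each input tensor to a single scalar, the sum of |x|^norm_type over all elements. A reference sketch, assuming get_p_decimal_p_is_negative splits norm_type into integer part p, fraction decimal, and a sign flag (so the new printout for norm_type=2.0 should read p=2, decimal=0, p_is_negative=0):

    import torch

    def step1_ref(grad, norm_type):
        # per-tensor partial result: sum of |g|^norm_type over all elements
        return grad.double().abs().pow(norm_type).sum()

    # debug config: 1024 elements of 0.1, norm_type=2.0 -> 1024 * 0.01 = 10.24
    g = torch.full((1, 1, 32, 32), 0.1, dtype=torch.float64)
    assert abs(step1_ref(g, 2.0).item() - 10.24) < 1e-6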
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp
index ee089be698a8..e46916d6b523 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp
@@ -40,38 +40,39 @@ void MAIN {
     // Compute cb_x
     for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
         if (tile_idx == 0) {
-            ACQ();
+            tile_regs_acquire();
             cb_wait_front(cb_input, onetile);  // comes from the reader
-            cb_reserve_back(cb_x, onetile);

             copy_tile_init();
             copy_tile(cb_input, 0, dst0);
+            cb_pop_front(cb_input, onetile);
+            tile_regs_commit();

+            tile_regs_wait();
+            cb_reserve_back(cb_x, onetile);
             pack_tile(dst0, cb_x);
-
-            cb_pop_front(cb_input, onetile);
             cb_push_back(cb_x, onetile);
-            REL();
+            tile_regs_release();
         } else {
-            ACQ();
+            tile_regs_acquire();
             cb_wait_front(cb_input, onetile);  // comes from the reader
             cb_wait_front(cb_x, onetile);
-            cb_reserve_back(cb_x, onetile);
-
             add_tiles_init();
             add_tiles(cb_input, cb_x, 0, 0, dst0);
+            cb_pop_front(cb_x, onetile);
+            cb_pop_front(cb_input, onetile);
+            tile_regs_commit();

+            tile_regs_wait();
+            cb_reserve_back(cb_x, onetile);
             pack_tile(dst0, cb_x);
-
-            cb_pop_front(cb_input, onetile);
-            cb_pop_front(cb_x, onetile);
             cb_push_back(cb_x, onetile);
-            REL();
+            tile_regs_release();
         }
     }

+    // Maybe a bug here: produces a very incorrect result
     // x^p
     power_tile_to_cb(cb_x, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_y, p, p_is_negative);
-
 }  // void MAIN
 }  // namespace NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_program_factory.cpp
index cbc83d3e7fa2..0841272dd364 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/moreh_clip_grad_norm_step2_program_factory.cpp
@@ -41,6 +41,7 @@ MorehClipGradNormStep2Operation::ProgramFactory::create(
     const auto num_tiles = tmp_pow_sum.volume() / tt::constants::TILE_HW;

     auto [p, decimal, p_is_negative] = get_p_decimal_p_is_negative(1.0f / norm_type);
+    std::cout << "step2: p: " << p << " ;decimal: " << decimal << " ;p_is_negative: " << p_is_negative << std::endl;

     ////////////////////////////////////////////////////////////////////////////
     //                         Core Setup
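Step 2 folds the per-tensor partial sums into total_norm = (sum_i sum |g_i|^norm_type)^(1 / norm_type). Note that the factory passes 1.0f / norm_type to get_p_decimal_p_is_negative, so for norm_type=2.0 the exponent 0.5 presumably splits into p=0 and decimal=0.5; in that case the result of power_tile_to_cb comes entirely from the exp(log(x) * decimal) path, which is precisely the path the moreh_common.hpp changes disable, consistent with the "maybe a bug" comment sitting on this call. A plain-Python sketch of the intended computation:

    def step2_ref(partial_sums, norm_type):
        # total_norm = (sum of the per-tensor |g|^norm_type sums)^(1 / norm_type)
        return sum(partial_sums) ** (1.0 / norm_type)

    # debug config: one partial sum of 10.24, norm_type=2.0 -> total_norm = 3.2
    assert abs(step2_ref([10.24], 2.0) - 3.2) < 1e-9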