#15690: Debug commit
DuongQLee committed Dec 11, 2024
1 parent 2f59d5e commit 8b7b70c
Showing 7 changed files with 156 additions and 106 deletions.
87 changes: 57 additions & 30 deletions tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py
@@ -14,6 +14,11 @@

from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import TILE_HEIGHT, TILE_WIDTH

torch.set_printoptions(
profile="full",
sci_mode=False, # Disable scientific notation
)


def to_cpu(npu_tensor, shape, *, cpu_layout=ttnn.ROW_MAJOR_LAYOUT):
if npu_tensor is None:
@@ -35,16 +40,26 @@ def to_npu(
return ttnn.from_torch(cpu_tensor, npu_dtype, device=device, layout=npu_layout)


@pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding")
@pytest.mark.parametrize("num_iters_of_each_case", [2])
# @pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding")
@pytest.mark.parametrize("num_iters_of_each_case", [1])
@pytest.mark.parametrize("range_of_padding", [(0, 21, 10)]) # [0, 10, 20]
@pytest.mark.parametrize("range_of_n", [(1, 4)])
@pytest.mark.parametrize("range_of_c", [(1, 4)])
@pytest.mark.parametrize("range_of_ht", [(1, 4)])
@pytest.mark.parametrize("range_of_wt", [(1, 4)])
@pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
@pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
@pytest.mark.parametrize("num_parameters", [32, 128])
@pytest.mark.parametrize("max_norm", [2.0])
@pytest.mark.parametrize("norm_type", [2.0])
@pytest.mark.parametrize("num_parameters", [1])

# @pytest.mark.parametrize("num_iters_of_each_case", [2])
# @pytest.mark.parametrize("range_of_padding", [(0, 21, 10)]) # [0, 10, 20]
# @pytest.mark.parametrize("range_of_n", [(1, 4)])
# @pytest.mark.parametrize("range_of_c", [(1, 4)])
# @pytest.mark.parametrize("range_of_ht", [(1, 4)])
# @pytest.mark.parametrize("range_of_wt", [(1, 4)])
# @pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
# @pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
# @pytest.mark.parametrize("num_parameters", [32, 128])
def test_moreh_clip_grad_norm(
num_iters_of_each_case,
num_parameters,
@@ -60,7 +75,7 @@ def test_moreh_clip_grad_norm(
torch.manual_seed(2023)
random.seed(2023)

cpu_dtype = torch.float32
cpu_dtype = torch.bfloat16
npu_dtype = ttnn.bfloat16

cpu_inputs = []
@@ -77,14 +92,21 @@
padding_w = random.randrange(*range_of_padding)

input_shape = (
n,
c,
ht * TILE_HEIGHT - padding_h,
wt * TILE_WIDTH - padding_w,
1,
1,
1 * TILE_HEIGHT,
1 * TILE_WIDTH,
)
# input_shape = (
# n,
# c,
# ht * TILE_HEIGHT - padding_h,
# wt * TILE_WIDTH - padding_w,
# )

param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype))
grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
grad = torch.ones(input_shape, dtype=cpu_dtype) / 10
# grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
param.grad = grad

cpu_inputs.append(param)
@@ -93,25 +115,30 @@

cpu_total_norm = torch.nn.utils.clip_grad_norm_(cpu_inputs, max_norm, norm_type)
npu_total_norm = ttnn.operations.moreh.clip_grad_norm(npu_inputs, max_norm, norm_type)

expected_total_norm = cpu_total_norm
actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1])

rtol = atol = 0.1
# Check total_norm
pass_total_norm, out_total_norm = comp_allclose_and_pcc(
actual_total_norm.double(), expected_total_norm.double(), rtol=rtol, atol=atol
)
logger.debug(f"total_norm's {out_total_norm}")
assert pass_total_norm

# Check inputs
for i in range(num_parameters):
expected_input_i = cpu_inputs[i].grad.double()
actual_input_i = to_cpu(npu_inputs[i], input_shapes[i]).double()
pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
assert pass_input_i
to_cpu_result = ttnn.to_torch(npu_total_norm)
to_cpu_result_reshape = to_cpu_result.reshape(-1, 32, 32)
print("cpu_total_norm", cpu_total_norm)
print("to_cpu_result_reshape.shape", to_cpu_result_reshape.shape)
print("to_cpu_result_reshape", to_cpu_result_reshape)

# expected_total_norm = cpu_total_norm
# actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1])

# rtol = atol = 0.1
# # Check total_norm
# pass_total_norm, out_total_norm = comp_allclose_and_pcc(
# actual_total_norm, expected_total_norm, rtol=rtol, atol=atol
# )
# logger.debug(f"total_norm's {out_total_norm}")
# assert pass_total_norm

# # Check inputs
# for i in range(num_parameters):
# expected_input_i = cpu_inputs[i].grad
# actual_input_i = to_cpu(npu_inputs[i], input_shapes[i])
# pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
# logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
# assert pass_input_i


# @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
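For reference, the CPU baseline the test still exercises is torch.nn.utils.clip_grad_norm_, whose result is printed as cpu_total_norm above. A minimal Python sketch of the math behind it (reference_clip_grad_norm_ and grads are illustrative names, not repo code):

    import torch

    def reference_clip_grad_norm_(parameters, max_norm, norm_type=2.0):
        # Collect gradients of all parameters that have one.
        grads = [p.grad for p in parameters if p.grad is not None]
        # total_norm = (sum_i ||g_i||_p^p)^(1/p)
        total_norm = torch.norm(
            torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type
        )
        # Scale factor: max_norm / (total_norm + eps), clamped to at most 1.
        clip_coef = max_norm / (total_norm + 1e-6)
        clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
        # In-place update, mirroring "inputs *= clip_coef_clamped" in Step 3.
        for g in grads:
            g.mul_(clip_coef_clamped)
        return total_norm
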
68 changes: 40 additions & 28 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp
@@ -1018,6 +1018,8 @@ ALWI void power_tile_to_cb(
constexpr uint32_t onetile = 1;
constexpr uint32_t dst0 = 0;

// DUONG: to check the value of dst before log_tile api is called, please comment out lines 1054-1055

// x^p
tile_regs_acquire();
cb_wait_front(cb_x, onetile);
Expand Down Expand Up @@ -1053,49 +1055,59 @@ ALWI void power_tile_to_cb(
log_tile(dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile_with_dt(dst0, cb_logx);
tile_regs_release();
// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_logx);
// tile_regs_release();

cb_pop_front(cb_x, onetile);
cb_push_back(cb_logx, onetile);

// exp(log(x) * decimal)
tile_regs_acquire();
cb_wait_front(cb_logx, onetile);
cb_reserve_back(cb_exp_lxmd, onetile);
// // exp(log(x) * decimal)
// tile_regs_acquire();
// cb_wait_front(cb_logx, onetile);
// cb_reserve_back(cb_exp_lxmd, onetile);

mul_tiles_init_with_dt(cb_logx, cb_decimal);
mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);
// mul_tiles_init_with_dt(cb_logx, cb_decimal);
// mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);

exp_tile_init();
exp_tile(dst0);
tile_regs_commit();
// exp_tile_init();
// exp_tile(dst0);
// tile_regs_commit();

tile_regs_wait();
pack_tile_with_dt(dst0, cb_exp_lxmd);
tile_regs_release();
// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_exp_lxmd);
// tile_regs_release();

cb_pop_front(cb_logx, onetile);
cb_push_back(cb_exp_lxmd, onetile);
// cb_pop_front(cb_logx, onetile);
// cb_push_back(cb_exp_lxmd, onetile);

// x^p * exp(log(x) * decimal)(==(x + decimal)^p)
tile_regs_acquire();
cb_wait_front(cb_xpow, onetile);
cb_wait_front(cb_exp_lxmd, onetile);
cb_reserve_back(cb_correct_xpow, onetile);
// // x^p * exp(log(x) * decimal)(==(x + decimal)^p)
// tile_regs_acquire();
// cb_wait_front(cb_xpow, onetile);
// cb_wait_front(cb_exp_lxmd, onetile);
// cb_reserve_back(cb_correct_xpow, onetile);

mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
tile_regs_commit();
// mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
// mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
// tile_regs_commit();

// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_correct_xpow);
// tile_regs_release();

// cb_pop_front(cb_xpow, onetile);
// cb_pop_front(cb_exp_lxmd, onetile);
// cb_push_back(cb_correct_xpow, onetile);

// TESTING START

cb_reserve_back(cb_correct_xpow, onetile);
tile_regs_wait();
pack_tile_with_dt(dst0, cb_correct_xpow);
tile_regs_release();

cb_pop_front(cb_xpow, onetile);
cb_pop_front(cb_exp_lxmd, onetile);
cb_push_back(cb_correct_xpow, onetile);

// TESTING END
}

ALWI void power_tile_with_abs_x_to_cb(
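power_tile_to_cb above rebuilds a fractional power x^(p + decimal) as x^p * exp(decimal * log x); this debug commit stops the pipeline right after log_tile so dst can be inspected, commenting out the exp(log(x) * decimal) and final multiply stages. A quick host-side check of the identity the full pipeline relies on (plain Python; the values are illustrative, with 2.2 taken from the original norm_type parametrization):

    import math

    norm_type = 2.2
    p_int, decimal = int(norm_type), norm_type - int(norm_type)

    x = 1.7
    xpow = x ** p_int                       # x^p                   -> cb_xpow
    logx = math.log(x)                      # log(x)                -> cb_logx
    exp_lxmd = math.exp(logx * decimal)     # exp(log(x) * decimal) -> cb_exp_lxmd
    correct_xpow = xpow * exp_lxmd          # product               -> cb_correct_xpow

    assert abs(correct_xpow - x ** norm_type) < 1e-9
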
@@ -90,31 +90,31 @@ Tensor MorehClipGradNorm::invoke(
norm_type);
}

// max_norm / (total_norm + 1e-6)
auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm));
// min(clip_coef, 1.0f)
Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
scalar.deallocate();

// Run Step 3
// Inplace update inputs(inputs *= clip_coef_clamped)
uint32_t start_input_idx{0};
num_inputs = total_num_inputs;
for (uint32_t i = 0; i < num_iter; ++i) {
const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);

auto input_tensors = std::vector<Tensor>(
inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);

ttnn::prim::moreh_clip_grad_norm_step3(
input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);

if (i < (num_iter - 1)) {
start_input_idx += num_inputs_at_this_iter;
num_inputs -= num_inputs_at_this_iter;
}
}
// // max_norm / (total_norm + 1e-6)
// auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm));
// // min(clip_coef, 1.0f)
// Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
// auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
// scalar.deallocate();

// // Run Step 3
// // Inplace update inputs(inputs *= clip_coef_clamped)
// uint32_t start_input_idx{0};
// num_inputs = total_num_inputs;
// for (uint32_t i = 0; i < num_iter; ++i) {
// const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);

// auto input_tensors = std::vector<Tensor>(
// inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);

// ttnn::prim::moreh_clip_grad_norm_step3(
// input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);

// if (i < (num_iter - 1)) {
// start_input_idx += num_inputs_at_this_iter;
// num_inputs -= num_inputs_at_this_iter;
// }
// }

return output_total_norm;
}
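The disabled Step 3 above scales every gradient in place by clip_coef_clamped, walking the input list in chunks of at most max_num_inputs. A Python sketch of that chunking pattern, assuming each element behaves like a torch tensor (scale_in_chunks is an illustrative name):

    def scale_in_chunks(grads, clip_coef_clamped, max_num_inputs):
        # Process at most max_num_inputs tensors per call, as the loop above does.
        start, remaining = 0, len(grads)
        while remaining > 0:
            count = min(remaining, max_num_inputs)
            for g in grads[start:start + count]:
                g.mul_(clip_coef_clamped)   # inputs *= clip_coef_clamped
            start += count
            remaining -= count
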
@@ -58,9 +58,8 @@ void MAIN {
for (uint32_t tile_idx = 0; tile_idx < num_tiles; tile_idx++) {
// Compute cb_xabs and mask (optional)
// |x|
ACQ();
tile_regs_acquire();
cb_wait_front(cb_x, onetile); // comes from the reader
cb_reserve_back(cb_xabs, onetile);

copy_tile_init();
copy_tile(cb_x, 0, dst0);
@@ -83,61 +82,70 @@

abs_tile_init();
abs_tile(dst0);
cb_pop_front(cb_x, onetile);
tile_regs_commit();

tile_regs_wait();
cb_reserve_back(cb_xabs, onetile);
pack_tile(dst0, cb_xabs);

cb_pop_front(cb_x, onetile);
cb_push_back(cb_xabs, onetile);
REL();
tile_regs_release();

// |x + decimal|^p
power_tile_to_cb(cb_xabs, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_correct_xpow, p, p_is_negative);

if (tile_idx == 0) {
ACQ();
tile_regs_acquire();
cb_wait_front(cb_correct_xpow, onetile);
cb_reserve_back(cb_xpowadd, onetile);

copy_tile_init();
copy_tile(cb_correct_xpow, 0, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, cb_xpowadd);

cb_pop_front(cb_correct_xpow, onetile);
cb_push_back(cb_xpowadd, onetile);
REL();
tile_regs_release();
} else {
ACQ();
tile_regs_acquire();
cb_wait_front(cb_correct_xpow, onetile);
cb_wait_front(cb_xpowadd, onetile);
cb_reserve_back(cb_xpowadd, onetile);

add_tiles_init();
add_tiles(cb_correct_xpow, cb_xpowadd, 0, 0, dst0);
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, cb_xpowadd);

cb_pop_front(cb_correct_xpow, onetile);
cb_pop_front(cb_xpowadd, onetile);
cb_push_back(cb_xpowadd, onetile);
REL();
tile_regs_release();
}
}

// Compute cb_y
ACQ();
tile_regs_acquire();
cb_wait_front(cb_xpowadd, onetile);
cb_reserve_back(cb_y, onetile);

reduce_init_delta<false>();
reduce_tile(cb_xpowadd, cb_one, 0, 0, dst0);
reduce_revert_delta();
tile_regs_commit();

tile_regs_wait();
pack_tile(dst0, cb_y);

cb_pop_front(cb_xpowadd, onetile);
cb_push_back(cb_y, onetile);
REL();
tile_regs_release();

cb_pop_front(cb_decimal, onetile);
cb_pop_front(cb_one, onetile);
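The Step 1 kernel's MAIN loop above accumulates sum(|x|^p) for one parameter: per tile it takes |x| (cb_xabs), raises it to the fractional power via power_tile_to_cb (cb_correct_xpow), adds it into a running tile (cb_xpowadd), and finally reduces that tile to a single value (cb_y). A torch sketch of the same ordering, with masking omitted (step1_partial_norm is an illustrative name):

    import torch

    def step1_partial_norm(tiles, p_int, decimal):
        # tiles: one 32x32 torch tensor per tile of a single gradient
        acc = None
        for t in tiles:
            xabs = t.abs()                              # |x|, into cb_xabs
            # the device kernel splits this into |x|^p_int * exp(decimal * log|x|)
            xpow = xabs ** (p_int + decimal)            # cb_correct_xpow
            acc = xpow if acc is None else acc + xpow   # running sum in cb_xpowadd
        return acc.sum()                                # reduce_tile -> cb_y
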
@@ -51,6 +51,7 @@ MorehClipGradNormStep1Operation::ProgramFactory::create(
}

auto [p, decimal, p_is_negative] = get_p_decimal_p_is_negative(norm_type);
std::cout << "step1: p: " << p << " ;decimal: " << decimal << " ;p_is_negative: " << p_is_negative << std::endl;

////////////////////////////////////////////////////////////////////////////
// Core Setup
Expand Down