#15690: Fix clip_grad_norm step 3 and refactor some deprecated api
DuongQLee committed Dec 11, 2024
1 parent 8b7b70c commit 3f2551e
Showing 8 changed files with 129 additions and 178 deletions.
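For context, "step 3" here is the final in-place gradient scaling of norm-based clipping, which this commit re-enables on the ttnn side. A minimal reference sketch of the three steps, assuming the usual torch.nn.utils.clip_grad_norm_ semantics that the updated test compares against (the function and variable names below are illustrative, not the ttnn API):

import torch

def clip_grad_norm_reference(params, max_norm, norm_type):
    # Step 1: per-tensor sum of |grad|^norm_type
    pow_sums = [p.grad.abs().pow(norm_type).sum() for p in params]
    # Step 2: total norm = (sum of the per-tensor sums)^(1/norm_type)
    total_norm = torch.stack(pow_sums).sum().pow(1.0 / norm_type)
    # Step 3 (the part fixed by this commit): scale every gradient in place
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = clip_coef.clamp(max=1.0)
    for p in params:
        p.grad.mul_(clip_coef_clamped)
    return total_norm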
169 changes: 66 additions & 103 deletions tests/ttnn/unit_tests/operations/test_moreh_clip_grad_norm.py
@@ -14,18 +14,6 @@

from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import TILE_HEIGHT, TILE_WIDTH

torch.set_printoptions(
profile="full",
sci_mode=False, # Disable scientific notation
)


def to_cpu(npu_tensor, shape, *, cpu_layout=ttnn.ROW_MAJOR_LAYOUT):
if npu_tensor is None:
return None
cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(shape).to_torch()
return cpu_tensor


def to_npu(
cpu_tensor,
@@ -40,26 +28,15 @@ def to_npu(
return ttnn.from_torch(cpu_tensor, npu_dtype, device=device, layout=npu_layout)


# @pytest.mark.skip(reason="assertion fails during binary op input shape comparison because of different padding")
@pytest.mark.parametrize("num_iters_of_each_case", [1])
@pytest.mark.parametrize("num_iters_of_each_case", [2])
@pytest.mark.parametrize("range_of_padding", [(0, 21, 10)]) # [0, 10, 20]
@pytest.mark.parametrize("range_of_n", [(1, 4)])
@pytest.mark.parametrize("range_of_c", [(1, 4)])
@pytest.mark.parametrize("range_of_ht", [(1, 4)])
@pytest.mark.parametrize("range_of_wt", [(1, 4)])
@pytest.mark.parametrize("max_norm", [2.0])
@pytest.mark.parametrize("norm_type", [2.0])
@pytest.mark.parametrize("num_parameters", [1])

# @pytest.mark.parametrize("num_iters_of_each_case", [2])
# @pytest.mark.parametrize("range_of_padding", [(0, 21, 10)]) # [0, 10, 20]
# @pytest.mark.parametrize("range_of_n", [(1, 4)])
# @pytest.mark.parametrize("range_of_c", [(1, 4)])
# @pytest.mark.parametrize("range_of_ht", [(1, 4)])
# @pytest.mark.parametrize("range_of_wt", [(1, 4)])
# @pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
# @pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
# @pytest.mark.parametrize("num_parameters", [32, 128])
@pytest.mark.parametrize("max_norm", [2.0, 1.0, -1.0])
@pytest.mark.parametrize("norm_type", [2.0, -0.8, 2.2])
@pytest.mark.parametrize("num_parameters", [32, 128])
def test_moreh_clip_grad_norm(
num_iters_of_each_case,
num_parameters,
@@ -75,7 +52,7 @@ def test_moreh_clip_grad_norm(
torch.manual_seed(2023)
random.seed(2023)

cpu_dtype = torch.bfloat16
cpu_dtype = torch.float32
npu_dtype = ttnn.bfloat16

cpu_inputs = []
@@ -92,21 +69,14 @@
padding_w = random.randrange(*range_of_padding)

input_shape = (
1,
1,
1 * TILE_HEIGHT,
1 * TILE_WIDTH,
n,
c,
ht * TILE_HEIGHT - padding_h,
wt * TILE_WIDTH - padding_w,
)
# input_shape = (
# n,
# c,
# ht * TILE_HEIGHT - padding_h,
# wt * TILE_WIDTH - padding_w,
# )

param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype))
grad = torch.ones(input_shape, dtype=cpu_dtype) / 10
# grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
grad = torch.empty(input_shape, dtype=cpu_dtype).uniform_(0, 2.5)
param.grad = grad

cpu_inputs.append(param)
@@ -115,66 +85,59 @@

cpu_total_norm = torch.nn.utils.clip_grad_norm_(cpu_inputs, max_norm, norm_type)
npu_total_norm = ttnn.operations.moreh.clip_grad_norm(npu_inputs, max_norm, norm_type)
to_cpu_result = ttnn.to_torch(npu_total_norm)
to_cpu_result_reshape = to_cpu_result.reshape(-1, 32, 32)
print("cpu_total_norm", cpu_total_norm)
print("to_cpu_result_reshape.shape", to_cpu_result_reshape.shape)
print("to_cpu_result_reshape", to_cpu_result_reshape)

# expected_total_norm = cpu_total_norm
# actual_total_norm = to_cpu(npu_total_norm, [1, 1, 1, 1])

# rtol = atol = 0.1
# # Check total_norm
# pass_total_norm, out_total_norm = comp_allclose_and_pcc(
# actual_total_norm, expected_total_norm, rtol=rtol, atol=atol
# )
# logger.debug(f"total_norm's {out_total_norm}")
# assert pass_total_norm

# # Check inputs
# for i in range(num_parameters):
# expected_input_i = cpu_inputs[i].grad
# actual_input_i = to_cpu(npu_inputs[i], input_shapes[i])
# pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
# logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
# assert pass_input_i


# @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
# @pytest.mark.parametrize("error_if_nonfinite", [True, False])
# def test_moreh_clip_grad_norm_with_error_if_nonfinite(error_if_nonfinite, device):
# torch.manual_seed(2023)

# cpu_dtype = torch.bfloat16
# npu_dtype = ttnn.bfloat16

# input_shape = [4, 4, 4 * TILE_HEIGHT, 4 * TILE_WIDTH]
# param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype))
# grad = torch.randn(input_shape, dtype=cpu_dtype)
# param.grad = grad

# max_norm = 1.0
# norm_type = float("nan")

# expected_error_msg = (
# f"The total norm of order {norm_type} for gradients from `parameters` is non-finite, so it cannot be clipped."
# )

# # Check vanilla torch behavior
# try:
# torch.nn.utils.clip_grad_norm_((param), max_norm, norm_type, error_if_nonfinite)
# assert not error_if_nonfinite
# except RuntimeError as actual_error_msg:
# assert expected_error_msg in str(actual_error_msg)
# assert error_if_nonfinite

# # Check tt behavior
# try:
# ttnn.operations.moreh.clip_grad_norm(
# [to_npu(param.grad.bfloat16(), device, npu_dtype=npu_dtype)], max_norm, norm_type, error_if_nonfinite
# )
# assert not error_if_nonfinite
# except RuntimeError as actual_error_msg:
# assert expected_error_msg in str(actual_error_msg)
# assert error_if_nonfinite
actual_total_norm = ttnn.to_torch(npu_total_norm).reshape(1)
expected_total_norm = cpu_total_norm

rtol = atol = 0.1
# Check total_norm
pass_total_norm, out_total_norm = comp_allclose_and_pcc(
actual_total_norm, expected_total_norm, rtol=rtol, atol=atol
)
logger.debug(f"total_norm's {out_total_norm}")
assert pass_total_norm

# Check inputs
for i in range(num_parameters):
expected_input_i = cpu_inputs[i].grad
actual_input_i = ttnn.to_torch(npu_inputs[i])
pass_input_i, out_input_i = comp_allclose_and_pcc(expected_input_i, actual_input_i, rtol=rtol, atol=atol)
logger.debug(f"inputs[{i}]-shape[{input_shapes[i]}]'s {out_input_i}")
assert pass_input_i


@pytest.mark.parametrize("error_if_nonfinite", [True, False])
def test_moreh_clip_grad_norm_with_error_if_nonfinite(error_if_nonfinite, device):
torch.manual_seed(2023)

cpu_dtype = torch.bfloat16
npu_dtype = ttnn.bfloat16

input_shape = [4, 4, 4 * TILE_HEIGHT, 4 * TILE_WIDTH]
param = torch.nn.Parameter(torch.empty(input_shape, dtype=cpu_dtype))
grad = torch.randn(input_shape, dtype=cpu_dtype)
param.grad = grad

max_norm = 1.0
norm_type = float("nan")

expected_error_msg = (
f"The total norm of order {norm_type} for gradients from `parameters` is non-finite, so it cannot be clipped."
)

# Check vanilla torch behavior
try:
torch.nn.utils.clip_grad_norm_((param), max_norm, norm_type, error_if_nonfinite)
assert not error_if_nonfinite
except RuntimeError as actual_error_msg:
assert expected_error_msg in str(actual_error_msg)
assert error_if_nonfinite

# Check tt behavior
try:
ttnn.operations.moreh.clip_grad_norm(
[to_npu(param.grad.bfloat16(), device, npu_dtype=npu_dtype)], max_norm, norm_type, error_if_nonfinite
)
assert not error_if_nonfinite
except RuntimeError as actual_error_msg:
assert expected_error_msg in str(actual_error_msg)
assert error_if_nonfinite
66 changes: 28 additions & 38 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp
@@ -1055,59 +1055,49 @@ ALWI void power_tile_to_cb(
log_tile(dst0);
tile_regs_commit();

// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_logx);
// tile_regs_release();
tile_regs_wait();
pack_tile_with_dt(dst0, cb_logx);
tile_regs_release();

cb_pop_front(cb_x, onetile);
cb_push_back(cb_logx, onetile);

// // exp(log(x) * decimal)
// tile_regs_acquire();
// cb_wait_front(cb_logx, onetile);
// cb_reserve_back(cb_exp_lxmd, onetile);

// mul_tiles_init_with_dt(cb_logx, cb_decimal);
// mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);

// exp_tile_init();
// exp_tile(dst0);
// tile_regs_commit();

// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_exp_lxmd);
// tile_regs_release();
// exp(log(x) * decimal)
tile_regs_acquire();
cb_wait_front(cb_logx, onetile);
cb_reserve_back(cb_exp_lxmd, onetile);

// cb_pop_front(cb_logx, onetile);
// cb_push_back(cb_exp_lxmd, onetile);
mul_tiles_init_with_dt(cb_logx, cb_decimal);
mul_tiles(cb_logx, cb_decimal, 0, 0, dst0);

// // x^p * exp(log(x) * decimal)(==(x + decimal)^p)
// tile_regs_acquire();
// cb_wait_front(cb_xpow, onetile);
// cb_wait_front(cb_exp_lxmd, onetile);
// cb_reserve_back(cb_correct_xpow, onetile);
exp_tile_init();
exp_tile(dst0);
tile_regs_commit();

// mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
// mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
// tile_regs_commit();
tile_regs_wait();
pack_tile_with_dt(dst0, cb_exp_lxmd);
tile_regs_release();

// tile_regs_wait();
// pack_tile_with_dt(dst0, cb_correct_xpow);
// tile_regs_release();
cb_pop_front(cb_logx, onetile);
cb_push_back(cb_exp_lxmd, onetile);

// cb_pop_front(cb_xpow, onetile);
// cb_pop_front(cb_exp_lxmd, onetile);
// cb_push_back(cb_correct_xpow, onetile);
// x^p * exp(log(x) * decimal)(==(x + decimal)^p)
tile_regs_acquire();
cb_wait_front(cb_xpow, onetile);
cb_wait_front(cb_exp_lxmd, onetile);
cb_reserve_back(cb_correct_xpow, onetile);

// TESTING START
mul_tiles_init_with_dt(cb_xpow, cb_exp_lxmd);
mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0);
tile_regs_commit();

cb_reserve_back(cb_correct_xpow, onetile);
tile_regs_wait();
pack_tile_with_dt(dst0, cb_correct_xpow);
tile_regs_release();
cb_push_back(cb_correct_xpow, onetile);

// TESTING END
cb_pop_front(cb_xpow, onetile);
cb_pop_front(cb_exp_lxmd, onetile);
cb_push_back(cb_correct_xpow, onetile);
}

ALWI void power_tile_with_abs_x_to_cb(
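The power_tile_to_cb path re-enabled above computes a non-integer power by splitting the exponent into an integer part p and a fractional part decimal (get_p_decimal_p_is_negative in the step1/step2 factories below presumably performs that split). The code comment's "(x + decimal)^p" is best read as x^(p + decimal); the identity being relied on, with a quick illustrative check in Python:

# sanity check of the identity the kernel relies on, assuming x > 0:
#   x^(p + decimal) == x^p * exp(decimal * log(x))
import math

x, p, decimal = 3.0, 2, 0.2
assert abs(x ** (p + decimal) - (x ** p) * math.exp(decimal * math.log(x))) < 1e-9

In circular-buffer terms: cb_xpow holds x^p, cb_logx holds log(x), cb_exp_lxmd holds exp(decimal * log(x)), and their product lands in cb_correct_xpow.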
@@ -43,7 +43,7 @@ Tensor MorehClipGradNorm::invoke(
const auto num_iter = (total_num_inputs + max_num_inputs - 1) / max_num_inputs;

auto tmp_pow_sum = create_device_tensor(
SimpleShape{tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH * static_cast<uint32_t>(inputs.size())},
SimpleShape{static_cast<uint32_t>(inputs.size()), 1, 1},
inputs.at(0).get_dtype(),
Layout::TILE,
device,
@@ -90,31 +90,33 @@ Tensor MorehClipGradNorm::invoke(
norm_type);
}

// // max_norm / (total_norm + 1e-6)
// auto clip_coef = ttnn::multiply(ttnn::add(output_total_norm, 1e-6f), (1 / max_norm));
// // min(clip_coef, 1.0f)
// Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
// auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
// scalar.deallocate();

// // Run Step 3
// // Inplace update inputs(inputs *= clip_coef_clamped)
// uint32_t start_input_idx{0};
// num_inputs = total_num_inputs;
// for (uint32_t i = 0; i < num_iter; ++i) {
// const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);

// auto input_tensors = std::vector<Tensor>(
// inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);

// ttnn::prim::moreh_clip_grad_norm_step3(
// input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);

// if (i < (num_iter - 1)) {
// start_input_idx += num_inputs_at_this_iter;
// num_inputs -= num_inputs_at_this_iter;
// }
// }
// max_norm / (total_norm + 1e-6)
Tensor max_norm_tensor = creation::create_scalar(max_norm, inputs.at(0).get_dtype(), Layout::TILE, device);
auto clip_coef = ttnn::div(max_norm_tensor, ttnn::add(output_total_norm, 1e-6f));
// min(clip_coef, 1.0f)
Tensor scalar = creation::create_scalar(1.0f, inputs.at(0).get_dtype(), Layout::TILE, device);
auto clip_coef_clamped = ttnn::minimum(clip_coef, scalar);
scalar.deallocate();
max_norm_tensor.deallocate();

// Run Step 3
// Inplace update inputs(inputs *= clip_coef_clamped)
uint32_t start_input_idx{0};
num_inputs = total_num_inputs;
for (uint32_t i = 0; i < num_iter; ++i) {
const auto num_inputs_at_this_iter = std::min(num_inputs, max_num_inputs);

auto input_tensors = std::vector<Tensor>(
inputs.begin() + start_input_idx, inputs.begin() + start_input_idx + num_inputs_at_this_iter);

ttnn::prim::moreh_clip_grad_norm_step3(
input_tensors, clip_coef_clamped, memory_config, compute_kernel_config_val);

if (i < (num_iter - 1)) {
start_input_idx += num_inputs_at_this_iter;
num_inputs -= num_inputs_at_this_iter;
}
}

return output_total_norm;
}
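A quick worked instance of the coefficient math re-enabled above, with illustrative numbers rather than values from the test:

# assuming max_norm = 1.0:
#   total_norm = 4.0 -> clip_coef = 1.0 / (4.0 + 1e-6) ≈ 0.25 -> clamped to 0.25 (gradients shrink 4x)
#   total_norm = 0.5 -> clip_coef = 1.0 / (0.5 + 1e-6) ≈ 2.0  -> clamped to 1.0  (gradients left as-is)
max_norm, total_norm = 1.0, 4.0
clip_coef_clamped = min(max_norm / (total_norm + 1e-6), 1.0)  # ≈ 0.25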
@@ -51,7 +51,6 @@ MorehClipGradNormStep1Operation::ProgramFactory::create(
}

auto [p, decimal, p_is_negative] = get_p_decimal_p_is_negative(norm_type);
std::cout << "step1: p: " << p << " ;decimal: " << decimal << " ;p_is_negative: " << p_is_negative << std::endl;

////////////////////////////////////////////////////////////////////////////
// Core Setup
@@ -70,8 +70,6 @@ void MAIN {
tile_regs_release();
}
}

// Maybe bug here. Produce very incorrect result
// x^p
power_tile_to_cb(cb_x, cb_xpow, cb_logx, cb_decimal, cb_exp_lxmd, cb_y, p, p_is_negative);
} // void MAIN
@@ -37,7 +37,7 @@ void MorehClipGradNormStep2Operation::validate_on_program_cache_hit(

MorehClipGradNormStep2Operation::shape_return_value_t MorehClipGradNormStep2Operation::compute_output_shapes(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return SimpleShape{tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH};
return SimpleShape{1, 1};
};

MorehClipGradNormStep2Operation::tensor_return_value_t MorehClipGradNormStep2Operation::create_output_tensors(
@@ -41,7 +41,6 @@ MorehClipGradNormStep2Operation::ProgramFactory::create(
const auto num_tiles = tmp_pow_sum.volume() / tt::constants::TILE_HW;

auto [p, decimal, p_is_negative] = get_p_decimal_p_is_negative(1.0f / norm_type);
std::cout << "step2: p: " << p << " ;decimal: " << decimal << " ;p_is_negative: " << p_is_negative << std::endl;

////////////////////////////////////////////////////////////////////////////
// Core Setup