Skip to content

Commit

Permalink
#13095: Refactor moreh_nll_loss operations (#13097)
Browse files Browse the repository at this point in the history
#13095: bind moreh_nll_loss to moreh.nll_loss

#13095: rollback .clangd changes

#13095: change testutils path

#13095: remove redundant comments

#13095: assert device cache entries

#13095: uncomment testcases

#13095: update callback testcases

#13095: change param from optional<const Tensor> to optional<Tensor>&
  • Loading branch information
BuiChiTrung authored Sep 27, 2024
1 parent e53f39d commit a9e9b51
Show file tree
Hide file tree
Showing 21 changed files with 165 additions and 153 deletions.
49 changes: 37 additions & 12 deletions tests/ttnn/unit_tests/operations/test_moreh_nll_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from models.utility_functions import comp_allclose_and_pcc
from loguru import logger

from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
Expand Down Expand Up @@ -63,7 +63,7 @@ def run_moreh_nll_loss(shape, ignore_index, reduction, none_weight, device, comp

assert reduction in ["sum", "mean"]

tt_loss = ttnn.moreh_nll_loss(
tt_loss = ttnn.operations.moreh.nll_loss(
tt_input,
tt_target,
reduction, # reduction_mean,
Expand Down Expand Up @@ -101,7 +101,7 @@ def run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight
if reduction_mean == False:
tt_divisor = None
reduction = "sum"
tt_loss = ttnn.moreh_nll_loss(
tt_loss = ttnn.operations.moreh.nll_loss(
tt_input,
tt_target,
reduction,
Expand All @@ -119,7 +119,7 @@ def run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight
tt_output_grad = to_npu(output_grad, device)
tt_input_grad = to_npu(torch_input, device)

tt_input_grad = ttnn.moreh_nll_loss_backward(
tt_input_grad = ttnn.operations.moreh.nll_loss_backward(
target_tensor=tt_target,
weight_tensor=tt_weight,
divisor_tensor=tt_divisor,
Expand Down Expand Up @@ -167,17 +167,29 @@ def test_moreh_nll_loss(shape, ignore_index, reduction, none_weight, device):
],
)
@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_program_cache):
def test_moreh_nll_loss_callback(shape, reduction, device, use_program_cache):
torch.manual_seed(0)
ignore_index = 0

ignore_idx = 0
num_program_cache_entries_list = []
for i in range(4):
if i < 2:
none_weight = True
else:
none_weight = False

for _ in range(2):
run_moreh_nll_loss(shape, ignore_idx, reduction, none_weight, device)
run_moreh_nll_loss(shape, ignore_index, reduction, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)

num_program_cache_entries_list.append(device.num_program_cache_entries())

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert (
num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
and num_program_cache_entries_list[2] == num_program_cache_entries_list[3]
)


@pytest.mark.parametrize(
"shape",
Expand Down Expand Up @@ -228,17 +240,30 @@ def test_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weigh
],
)
@pytest.mark.parametrize("reduction_mean", [True, False])
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, none_weight, device, use_program_cache):
def test_moreh_nll_loss_backward_test_callback(shape, reduction_mean, device, use_program_cache):
torch.manual_seed(0)

ignore_index = 0

for _ in range(2):
num_program_cache_entries_list = []
for i in range(4):
if i < 2:
none_weight = True
else:
none_weight = False

run_moreh_nll_loss_backward(shape, ignore_index, reduction_mean, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)

num_program_cache_entries_list.append(device.num_program_cache_entries())

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert (
num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
and num_program_cache_entries_list[2] == num_program_cache_entries_list[3]
)


@pytest.mark.parametrize(
"shape",
Expand Down
41 changes: 34 additions & 7 deletions tests/ttnn/unit_tests/operations/test_moreh_nll_loss_unreduced.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from models.utility_functions import comp_allclose_and_pcc, is_wormhole_b0
from loguru import logger

from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
from tests.ttnn.unit_tests.operations.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
Expand Down Expand Up @@ -74,7 +74,7 @@ def run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, devi
torch_target, torch_weight, output_grad, torch_input.grad, device
)

tt_input_grad = ttnn.moreh_nll_loss_unreduced_backward(
tt_input_grad = ttnn.operations.moreh.nll_loss_unreduced_backward(
tt_target,
tt_output_grad,
weight_tensor=tt_weight,
Expand Down Expand Up @@ -110,7 +110,7 @@ def run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device, compu

reduction_mode = "none"

tt_loss = ttnn.moreh_nll_loss(
tt_loss = ttnn.operations.moreh.nll_loss(
tt_input,
tt_target,
reduction_mode,
Expand Down Expand Up @@ -158,17 +158,30 @@ def test_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, compute_kern
(5, 10, 10, 20),
],
)
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_unreduced_callback(shape, none_weight, device, use_program_cache):
def test_moreh_nll_loss_unreduced_callback(shape, device, use_program_cache):
torch.manual_seed(0)

ignore_index = 1
num_program_cache_entries_list = []

for i in range(4):
if i < 2:
none_weight = True
else:
none_weight = False

for _ in range(2):
run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)

num_program_cache_entries_list.append(device.num_program_cache_entries())

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert (
num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
and num_program_cache_entries_list[2] == num_program_cache_entries_list[3]
)


@pytest.mark.parametrize(
"shape",
Expand Down Expand Up @@ -205,7 +218,21 @@ def test_moreh_nll_loss_unreduced_backward(
def test_moreh_nll_loss_unreduced_backward_test_callback(shape, none_weight, device, ignore_index, use_program_cache):
torch.manual_seed(0)

for _ in range(2):
num_program_cache_entries_list = []
for i in range(4):
if i < 2:
none_weight = True
else:
none_weight = False

run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)

num_program_cache_entries_list.append(device.num_program_cache_entries())

logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert (
num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
and num_program_cache_entries_list[2] == num_program_cache_entries_list[3]
)
17 changes: 10 additions & 7 deletions ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ Tensor MorehNllLoss::invoke(
const Tensor &input_tensor,
const Tensor &target_tensor,
const std::string reduction,
const std::optional<const Tensor> weight_tensor,
const std::optional<const Tensor> divisor_tensor,
const std::optional<const Tensor> output_tensor,
const std::optional<Tensor> &weight_tensor,
const std::optional<Tensor> &divisor_tensor,
const std::optional<Tensor> &output_tensor,
const int32_t ignore_index,
const std::optional<MemoryConfig> &memory_config,
const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
const auto compute_kernel_config_val = init_device_compute_kernel_config(target_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);
const std::optional<DeviceComputeKernelConfig> &compute_kernel_config) {
const auto compute_kernel_config_val =
init_device_compute_kernel_config(target_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4);
if (reduction == MEAN) {
TT_FATAL(divisor_tensor.has_value(), "Divisor tensor must not be empty");

Expand All @@ -49,7 +50,8 @@ Tensor MorehNllLoss::invoke(
ignore_index,
memory_config,
compute_kernel_config_val);
return ttnn::moreh_sum(step2_result, std::nullopt, false, output_tensor, memory_config, compute_kernel_config_val);
return ttnn::moreh_sum(
step2_result, std::nullopt, false, output_tensor, memory_config, compute_kernel_config_val);
} else if (reduction == SUM) {
const Tensor &step2_result = prim::moreh_nll_loss_step2(
input_tensor,
Expand All @@ -61,7 +63,8 @@ Tensor MorehNllLoss::invoke(
ignore_index,
memory_config,
compute_kernel_config_val);
return ttnn::moreh_sum(step2_result, std::nullopt, false, output_tensor, memory_config, compute_kernel_config_val);
return ttnn::moreh_sum(
step2_result, std::nullopt, false, output_tensor, memory_config, compute_kernel_config_val);
}

return prim::moreh_nll_loss_step2(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ struct MorehNllLoss {
const Tensor &input_tensor,
const Tensor &target_tensor,
const std::string reduction,
const std::optional<const Tensor> weight_tensor,
const std::optional<const Tensor> divisor_tensor,
const std::optional<const Tensor> output_tensor,
const std::optional<Tensor> &weight_tensor,
const std::optional<Tensor> &divisor_tensor,
const std::optional<Tensor> &output_tensor,
const int32_t ignore_index,
const std::optional<MemoryConfig> &memory_config,
const std::optional<const DeviceComputeKernelConfig> compute_kernel_config);
const std::optional<DeviceComputeKernelConfig> &compute_kernel_config);
};

} // namespace ttnn::operations::moreh::moreh_nll_loss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ MorehNllLossStep1DeviceOperation::tensor_return_value_t MorehNllLossStep1DeviceO
std::tuple<MorehNllLossStep1DeviceOperation::operation_attributes_t, MorehNllLossStep1DeviceOperation::tensor_args_t>
MorehNllLossStep1DeviceOperation::invoke(
const Tensor& target_tensor,
const std::optional<const Tensor> weight_tensor,
const std::optional<Tensor>& weight_tensor,
const int32_t ignore_index,
const std::string reduction,
const DataType output_dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct MorehNllLossStep1DeviceOperation {

struct tensor_args_t {
const Tensor& target_tensor;
const std::optional<const Tensor> weight_tensor;
const std::optional<Tensor>& weight_tensor;
};

using shape_return_value_t = Shape;
Expand Down Expand Up @@ -53,28 +53,21 @@ struct MorehNllLossStep1DeviceOperation {

using program_factory_t = std::variant<Factory>;

// Mandatory methods

// Select the program factory based on the operation attributes and tensor args
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);

static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);

// Validate the operation when it creates a program. Usually will have more checks
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

// Validate the operation when it reuses a program. Usually will have less checks
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);

// Compute the output shapes based on the operation attributes and tensor args
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);

// Create the output tensors based on the operation attributes and tensor args
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& target_tensor,
const std::optional<const Tensor> weight_tensor,
const std::optional<Tensor>& weight_tensor,
const int32_t ignore_index,
const std::string reduction,
const DataType output_dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ MorehNllLossStep1DeviceOperation::Factory::cached_program_t MorehNllLossStep1Dev
using namespace tt::tt_metal;

const Tensor& target = tensor_args.target_tensor;
const std::optional<const Tensor> weight = tensor_args.weight_tensor;
const std::optional<Tensor>& weight = tensor_args.weight_tensor;
const Tensor& output = tensor_return_value;
const std::string reduction = operation_attributes.reduction;
const uint32_t ignore_index = operation_attributes.ignore_index;
Expand Down Expand Up @@ -101,7 +101,7 @@ MorehNllLossStep1DeviceOperation::Factory::cached_program_t MorehNllLossStep1Dev
// create read/write kernel
const std::vector<uint32_t> reader_compile_time_args{
static_cast<uint32_t>(tt::operations::primary::is_dram(target)),
static_cast<uint32_t>(tt::operations::primary::is_dram(weight)),
static_cast<uint32_t>(weight.has_value() ? tt::operations::primary::is_dram(weight.value()) : false),
static_cast<uint32_t>(weight_has_value)};

const std::vector<uint32_t> writer_compile_time_args{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ void MorehNllLossStep2DeviceOperation::validate_inputs(
const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
const Tensor& input_tensor = tensor_args.input_tensor;
const Tensor& target_tensor = tensor_args.target_tensor;
const std::optional<const Tensor> weight_tensor = tensor_args.weight_tensor;
const std::optional<const Tensor> divisor_tensor = tensor_args.divisor_tensor;
const std::optional<Tensor>& weight_tensor = tensor_args.weight_tensor;
const std::optional<Tensor>& divisor_tensor = tensor_args.divisor_tensor;

TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "intput_tensor to nll_loss need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr, "intput_tensor to nll_loss need to be allocated in buffers on device!");
Expand Down Expand Up @@ -130,9 +130,9 @@ MorehNllLossStep2DeviceOperation::invoke(
const Tensor& input_tensor,
const Tensor& target_tensor,
const std::string reduction,
const std::optional<const Tensor> weight_tensor,
const std::optional<const Tensor> divisor_tensor,
const std::optional<const Tensor> output_tensor,
const std::optional<Tensor>& weight_tensor,
const std::optional<Tensor>& divisor_tensor,
const std::optional<Tensor>& output_tensor,
const int32_t ignore_index,
const std::optional<ttnn::MemoryConfig>& memory_config,
const DeviceComputeKernelConfig& compute_kernel_config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ struct MorehNllLossStep2DeviceOperation {
struct tensor_args_t {
const Tensor& input_tensor;
const Tensor& target_tensor;
const std::optional<const Tensor> weight_tensor;
const std::optional<const Tensor> divisor_tensor;
const std::optional<const Tensor> output_tensor;
const std::optional<Tensor>& weight_tensor;
const std::optional<Tensor>& divisor_tensor;
const std::optional<Tensor>& output_tensor;
};

using shape_return_value_t = ttnn::Shape;
Expand Down Expand Up @@ -54,32 +54,25 @@ struct MorehNllLossStep2DeviceOperation {

using program_factory_t = std::variant<Factory>;

// Mandatory methods

// Select the program factory based on the operation attributes and tensor args
static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&);

static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);

// Validate the operation when it creates a program. Usually will have more checks
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

// Validate the operation when it reuses a program. Usually will have less checks
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);

// Compute the output shapes based on the operation attributes and tensor args
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);

// Create the output tensors based on the operation attributes and tensor args
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
const Tensor& input_tensor,
const Tensor& target_tensor,
const std::string reduction,
const std::optional<const Tensor> weight_tensor,
const std::optional<const Tensor> divisor_tensor,
const std::optional<const Tensor> output_tensor,
const std::optional<Tensor>& weight_tensor,
const std::optional<Tensor>& divisor_tensor,
const std::optional<Tensor>& output_tensor,
const int32_t ignore_index,
const std::optional<ttnn::MemoryConfig>& memory_config,
const DeviceComputeKernelConfig& compute_kernel_config);
Expand Down
Loading

0 comments on commit a9e9b51

Please sign in to comment.