#9663: support moreh_nll_loss_unreduced (#9804)
hschoi4448 authored Jul 9, 2024
1 parent 42d3dd0 commit e9e3ad3
Showing 15 changed files with 1,525 additions and 0 deletions.
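For context, the new op mirrors torch.nn.NLLLoss with reduction="none": it returns one loss value per target element instead of a reduced scalar. A minimal PyTorch-only sketch of the reference behaviour the new unit tests compare against (all names below are local to the sketch):

import torch

# NLLLoss with reduction="none" keeps a per-element loss -- the "unreduced" semantics.
N, C = 4, 10
log_probs = torch.log_softmax(torch.randn(N, C), dim=1)
target = torch.randint(0, C, (N,))

loss = torch.nn.NLLLoss(reduction="none")(log_probs, target)
print(loss.shape)  # torch.Size([4]): one loss per sample, not a scalar
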
4 changes: 4 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -277,6 +277,10 @@ Primary Operations

.. autofunction:: tt_lib.operations.primary.moreh_norm_backward

.. autofunction:: tt_lib.operations.primary.moreh_nll_loss_unreduced

.. autofunction:: tt_lib.operations.primary.moreh_nll_loss_unreduced_backward

Enums
=====

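For quick reference, the call shape of the two newly documented primary ops, as exercised by the unit tests added in this commit. This is a sketch only: device is assumed to be an already-open TT device, and the tt_* tensors are assumed to have been created with the to_npu helper shown further down (BFLOAT16 inputs, INT32 target); see the tests below for the full setup.

import tt_lib as ttl

# Forward: per-element (unreduced) NLL loss, mirroring torch.nn.NLLLoss(reduction="none").
tt_loss = ttl.operations.primary.moreh_nll_loss_unreduced(
    tt_input,    # (N, C, ...) input, as for torch.nn.NLLLoss
    tt_target,   # (N, ...) class indices, INT32
    tt_weight,   # (C,) per-class weights, or None
    tt_output,   # preallocated output with the target's shape
    ignore_index,
    compute_kernel_config=compute_kernel_config,
)

# Backward: gradient of the input given the gradient of the unreduced loss.
tt_input_grad = ttl.operations.primary.moreh_nll_loss_unreduced_backward(
    tt_target,
    tt_weight,
    tt_output_grad,  # same shape as the unreduced loss
    tt_input_grad,   # preallocated, same shape as the input
    ignore_index,
    compute_kernel_config=compute_kernel_config,
)
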
@@ -0,0 +1,204 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import tt_lib as ttl
import pytest
from models.utility_functions import comp_allclose_and_pcc, is_wormhole_b0
from loguru import logger

from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import (
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
to_cpu,
to_npu,
)


def get_torch_tensors(shape):
C = shape[1]
target_shape = shape[:1] + shape[2:]

cpu_dtype = torch.float32
cpu_index_dtype = torch.long

torch_input = torch.rand(shape, dtype=cpu_dtype).requires_grad_()
torch_target = torch.randint(0, C, target_shape, dtype=cpu_index_dtype)
torch_weight = torch.rand(C, dtype=cpu_dtype)
torch_output = torch.empty(target_shape, dtype=cpu_dtype)

return torch_input, torch_target, torch_weight, torch_output


def get_tt_tensors(torch_input, torch_target, torch_weight, torch_output, device):
npu_index_dtype = ttl.tensor.DataType.INT32

tt_input = to_npu(torch_input, device)
tt_target = to_npu(torch_target, device, npu_dtype=npu_index_dtype)
tt_weight = to_npu(torch_weight, device)
tt_output = to_npu(torch_output, device)

return tt_input, tt_target, tt_weight, tt_output


def get_tt_backward_tensors(torch_target, torch_weight, torch_output_grad, torch_input_grad, device):
npu_index_dtype = ttl.tensor.DataType.INT32

tt_target = to_npu(torch_target, device, npu_dtype=npu_index_dtype)
tt_weight = to_npu(torch_weight, device)
tt_output_grad = to_npu(torch_output_grad, device)
tt_input_grad = to_npu(torch_input_grad, device)

return tt_target, tt_weight, tt_output_grad, tt_input_grad


def run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device, compute_kernel_options=None):
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

(torch_input, torch_target, torch_weight, torch_output) = get_torch_tensors(shape)

if none_weight:
torch_weight = None

nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction="none")
torch_loss = nll_loss(torch_input, torch_target)

(tt_input, tt_target, tt_weight, tt_output) = get_tt_tensors(
torch_input, torch_target, torch_weight, torch_output, device
)

tt_loss = ttl.operations.primary.moreh_nll_loss_unreduced(
tt_input,
tt_target,
tt_weight,
tt_output,
ignore_index,
compute_kernel_config=compute_kernel_config,
)

tt_loss_to_cpu = to_cpu(tt_loss, torch_target.shape)

rtol = atol = 0.05
passing, out = comp_allclose_and_pcc(torch_loss, tt_loss_to_cpu, pcc=0.999, rtol=rtol, atol=atol)
logger.debug(f"Out passing (param)={passing}")
logger.debug(f"Output pcc={out}")

assert passing


def run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, device, compute_kernel_options=None):
compute_kernel_config = get_compute_kernel_options(compute_kernel_options)

# run torch
(torch_input, torch_target, torch_weight, _) = get_torch_tensors(shape)
if none_weight:
torch_weight = None

nll_loss = torch.nn.NLLLoss(weight=torch_weight, ignore_index=ignore_index, reduction="none")
torch_loss = nll_loss(torch_input, torch_target)

output_grad = torch.randn_like(torch_loss)
torch_loss.backward(output_grad)

# run tt
(tt_target, tt_weight, tt_output_grad, tt_input_grad) = get_tt_backward_tensors(
torch_target, torch_weight, output_grad, torch_input.grad, device
)

tt_input_grad = ttl.operations.primary.moreh_nll_loss_unreduced_backward(
tt_target,
tt_weight,
tt_output_grad,
tt_input_grad,
ignore_index,
compute_kernel_config=compute_kernel_config,
)
tt_input_grad_to_cpu = to_cpu(tt_input_grad, torch_input.grad.shape)

rtol = atol = 0.05
passing, out = comp_allclose_and_pcc(torch_input.grad, tt_input_grad_to_cpu, pcc=0.999, rtol=rtol, atol=atol)

logger.debug(f"Out passing (param)={passing}")
logger.debug(f"Output pcc={out}")

assert passing


@pytest.mark.parametrize(
"shape",
[
(5, 10),
(500, 100),
(4, 3, 2, 4, 50, 70),
],
)
@pytest.mark.parametrize("ignore_index", [1])
@pytest.mark.parametrize("none_weight", [True, False])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, compute_kernel_options, device, use_program_cache):
torch.manual_seed(0)

run_moreh_nll_loss_unreduced(
shape, ignore_index, none_weight, device, compute_kernel_options=compute_kernel_options
)


@pytest.mark.parametrize(
"shape",
[
(32, 32),
(400, 300),
(20, 300, 320),
(5, 2, 5, 40, 70),
],
)
@pytest.mark.parametrize("ignore_index", [1])
@pytest.mark.parametrize("none_weight", [True, False])
@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids)
def test_moreh_nll_loss_unreduced_backward(
shape, ignore_index, none_weight, compute_kernel_options, device, use_program_cache
):
torch.manual_seed(0)

run_moreh_nll_loss_unreduced_backward(
shape, ignore_index, none_weight, device, compute_kernel_options=compute_kernel_options
)


@pytest.mark.parametrize(
"shape",
[
(5, 10),
(5, 10, 10),
(5, 10, 10, 20),
],
)
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_unreduced_callback(shape, none_weight, device, use_program_cache):
torch.manual_seed(0)

ignore_index = 1

for _ in range(2):
run_moreh_nll_loss_unreduced(shape, ignore_index, none_weight, device)


@pytest.mark.parametrize(
"shape",
[
(2, 3),
(2, 3, 4),
(2, 3, 5, 4),
],
)
@pytest.mark.parametrize("none_weight", [True, False])
def test_moreh_nll_loss_unreduced_backward_test_callback(shape, none_weight, device, use_program_cache):
torch.manual_seed(0)

ignore_index = 0

for _ in range(2):
run_moreh_nll_loss_unreduced_backward(shape, ignore_index, none_weight, device)
44 changes: 44 additions & 0 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_utils.py
@@ -4,6 +4,7 @@

import tt_lib as ttl
from models.utility_functions import is_wormhole_b0
import copy

compute_kernel_options = [
False, # for grayskull
@@ -33,3 +34,46 @@ def get_compute_kernel_options(compute_kernel_options):
math_approx_mode=True,
)
return compute_kernel_config


def to_cpu(npu_tensor, shape, *, cpu_layout=ttl.tensor.Layout.ROW_MAJOR):
if npu_tensor is None:
return None

shape = list(shape)

unpad_shape = copy.copy(shape)

if shape == []:
unpad_shape = [1, 1]

if len(shape) == 1:
unpad_shape = [1] + shape

cpu_tensor = npu_tensor.cpu().to(cpu_layout).unpad_from_tile(unpad_shape).to_torch().reshape(shape)

return cpu_tensor


def to_npu(
cpu_tensor,
device,
*,
npu_layout=ttl.tensor.Layout.TILE,
npu_dtype=ttl.tensor.DataType.BFLOAT16,
shape=None,
):
if cpu_tensor is None:
return None

if shape is not None:
cpu_tensor = cpu_tensor.view(shape)

if len(cpu_tensor.shape) == 1:
cpu_tensor = cpu_tensor.reshape([1, len(cpu_tensor)])

if len(cpu_tensor.shape) == 0:
cpu_tensor = cpu_tensor.reshape([1, 1])

npu_tensor = ttl.tensor.Tensor(cpu_tensor, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device)
return npu_tensor
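
The two helpers above are the CPU/NPU bridge used throughout the new tests: to_npu pads the tensor to 32x32 tiles (NaN fill), converts it to TILE layout and moves it to the device; to_cpu converts back to ROW_MAJOR, unpads and reshapes to the requested shape. A minimal round-trip sketch, assuming device is an already-open TT device:

import torch
import tt_lib as ttl
from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import to_cpu, to_npu

cpu_tensor = torch.rand(5, 10)                      # deliberately not tile-aligned
npu_tensor = to_npu(cpu_tensor, device)             # BFLOAT16, padded to tiles, TILE layout, on device
round_trip = to_cpu(npu_tensor, cpu_tensor.shape)   # ROW_MAJOR, unpadded back to (5, 10)
assert round_trip.shape == cpu_tensor.shape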
75 changes: 75 additions & 0 deletions tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp
@@ -21,6 +21,15 @@ static inline float bfloat16_to_float(uint16_t bfloat_val) {
return f;
}

static inline uint16_t float_to_bfloat16(float val) {
union {
float f;
uint32_t u;
} ret;
ret.f = val;
return uint16_t(ret.u >> 16);
}

#if defined(FP32_DEST_ACC_EN)
using FP32_DEST_ACC_FTYPE = float;
FORCE_INLINE FP32_DEST_ACC_FTYPE fp32_dest_acc_cast(uint16_t val) { return bfloat16_to_float(val); }
@@ -838,3 +847,69 @@ void get_noc_offset(uint32_t h, uint32_t w, uint32_t element_size, uint32_t &noc

noc_offset = noc_offset_alilgn_32;
}

template<typename T>
volatile tt_l1_ptr T* get_read_ptr(uint32_t cb_id) {
    // Wraps the non-template get_read_ptr() and casts the address to a typed L1 pointer.
    auto l1_read_addr = get_read_ptr(cb_id);
    auto l1_ptr = reinterpret_cast<volatile tt_l1_ptr T*>(l1_read_addr);
    return l1_ptr;
}

template<typename T>
volatile tt_l1_ptr T* get_write_ptr(uint32_t cb_id) {

auto l1_write_addr = get_write_ptr(cb_id);
auto l1_ptr = reinterpret_cast<volatile tt_l1_ptr T*>(l1_write_addr);
return l1_ptr;
}

// Reads one tile (or `size` bytes of it) from the NoC into circular buffer `cb_id`.
template<typename T>
void read_tile(uint32_t cb_id, T addrgen, uint32_t noc_id, uint32_t size = 0, uint32_t offset = 0, bool do_reserve = true, bool do_push_back = true) {

constexpr uint32_t onetile = 1;

if (do_reserve) cb_reserve_back(cb_id, onetile);

    // A size of 0 means read a full tile.
if (size == 0){
size = get_tile_size(cb_id);
}

auto l1_write_addr = get_write_ptr(cb_id);
auto noc_addr = get_noc_addr(noc_id, addrgen, offset);
noc_async_read(noc_addr, l1_write_addr, size);
noc_async_read_barrier();

if (do_push_back) cb_push_back(cb_id, onetile);
}


// Reads a tilized tensor with shape (1, W): for each tile, the first row of face 0 and of face 1, so the W values land contiguously in the CB.
template<typename T>
void read_line(uint32_t cb_id, T addrgen, uint32_t num_tiles, bool do_reserve = true, bool do_push_back = true) {

if (do_reserve) cb_reserve_back(cb_id, num_tiles);

auto tile_bytes = get_tile_size(cb_id);
    auto element_size = tile_bytes / 1024;  // a 32x32 tile holds 1024 elements
auto noc_read_size = FACE_WIDTH * element_size;

uint32_t l1_write_addr = get_write_ptr(cb_id);

for (uint32_t i = 0; i < num_tiles * 2; ++i) {
uint32_t noc_id = i / 2;
        uint32_t noc_offset = 0;
        if (noc_id * 2 != i) {
            // Odd chunks read from face 1, which starts 256 elements after face 0 within the tile.
            noc_offset += 256 * element_size;
        }
auto src_noc_addr = get_noc_addr(noc_id, addrgen, noc_offset);
noc_async_read(src_noc_addr, l1_write_addr, noc_read_size);
noc_async_read_barrier();

l1_write_addr += noc_read_size;
}

if (do_push_back) cb_push_back(cb_id, num_tiles);
}
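
For reference, the addressing in read_line can be checked off-device: a 32x32 tile is stored as four 16x16 faces, so the first row of a (1, W) tensor is split between face 0 (tile offset 0) and face 1 (256 elements into the tile). A small Python model of the (tile, byte offset, bytes) triples the loop above issues, assuming that standard face layout:

FACE_WIDTH = 16
TILE_ELEMENTS = 1024  # 32 * 32


def read_line_chunks(num_tiles: int, tile_bytes: int):
    element_size = tile_bytes // TILE_ELEMENTS
    chunk_bytes = FACE_WIDTH * element_size
    chunks = []
    for i in range(num_tiles * 2):
        tile_id = i // 2
        # Even chunks read the first row of face 0, odd chunks the first row of face 1.
        byte_offset = 0 if i % 2 == 0 else 256 * element_size
        chunks.append((tile_id, byte_offset, chunk_bytes))
    return chunks


# Two bfloat16 tiles (2048 bytes each): four 32-byte reads covering a 64-wide row.
print(read_line_chunks(num_tiles=2, tile_bytes=2048))
# [(0, 0, 32), (0, 512, 32), (1, 0, 32), (1, 512, 32)]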
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
@@ -86,6 +86,9 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced/moreh_nll_loss_unreduced_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_nll_loss_unreduced_backward/moreh_nll_loss_unreduced_backward.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/moreh_softmax_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_w_small/softmax_w_small.cpp
${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_h_small/softmax_h_small.cpp