#13137: Revise moreh_arange operation
mrshaw01 committed Sep 27, 2024
1 parent e649cda commit e53f39d
Showing 9 changed files with 181 additions and 275 deletions.
296 changes: 103 additions & 193 deletions tests/ttnn/unit_tests/operations/test_moreh_arange.py
@@ -8,226 +8,136 @@

import ttnn
from models.utility_functions import comp_allclose_and_pcc
from tests.ttnn.unit_tests.operations.test_utils import to_npu


def get_tt_dtype(torch_dtype):
if torch_dtype == torch.int32:
return ttnn.int32
if torch_dtype == torch.bfloat16:
return ttnn.bfloat16
if torch_dtype == torch.float32:
return ttnn.float32
return None
def get_lib_dtype(lib, dtype):
"""Maps dtype to corresponding library dtype."""
dtype_map = {
"int32": lib.int32,
"bfloat16": lib.bfloat16,
"float32": lib.float32,
}
return dtype_map.get(dtype, None)


def create_tt_tensor(tensor: torch.Tensor, dtype, device):
return ttnn.from_torch(tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)


@pytest.mark.parametrize(
"start_end_step",
(
(-5, 27, 1), # simple
(2.3, 15.3, 0.5), # floating point
(10, 0, -0.3), # minus step
(10, 32 * 3, 1), # multiple cores
),
)
def test_arange_row_major_simple(start_end_step, device):
# Prepare and compute by torch
def run_moreh_arange(start_end_step, optional_output, dtype, tilized, device):
"""Run a comparison of arange results between torch and ttnn."""
# Prepare inputs
start, end, step = start_end_step
any_cpu = torch.ones((1024))
any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)
untilize_out = True
tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)

# Compute by ttnn
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)
tt_dev = tt_npu.cpu().to_torch()
if (dtype == "int32") and (start != int(start) or end != int(end) or step != int(step)):
pytest.skip("start/end/step must be integers when using int32 dtype")

# Compute using torch
torch_dtype = get_lib_dtype(torch, dtype)
if torch_dtype is None:
torch_dtype = torch.bfloat16
expected_output = torch.arange(start=start, end=end, step=step).to(torch_dtype)

# Compute using ttnn
ttnn_dtype = get_lib_dtype(ttnn, dtype)
if ttnn_dtype is None:
ttnn_dtype = ttnn.bfloat16
any_cpu = torch.ones(1024)
any_npu = ttnn.from_torch(any_cpu, dtype=ttnn_dtype, device=device)
if tilized:
L = expected_output.shape[0]
if optional_output:
output_cpu = torch.empty_like(expected_output)
if tilized:
output_npu = (
ttnn.from_torch(output_cpu, ttnn_dtype)
.reshape([1, L])
.pad_to_tile(float("nan"))
.to(ttnn.TILE_LAYOUT)
.to(device)
)
else:
output_npu = ttnn.from_torch(output_cpu, dtype=ttnn_dtype, device=device)
else:
output_npu = None
output_npu = ttnn.operations.moreh.arange(
start,
end,
step,
any_npu,
output=output_npu,
untilize_out=not tilized,
dtype=get_lib_dtype(ttnn, dtype),
)
if tilized:
actual_output = output_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L))
else:
actual_output = output_npu.cpu().to_torch()

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
# Assert shape and value comparison
assert actual_output.shape == expected_output.shape
passing, out = comp_allclose_and_pcc(expected_output, actual_output, rtol=0.1, atol=0.1)
logger.info(out)
assert passing


@pytest.mark.parametrize(
"start_end_step",
((0, 32 * 10, 1),), # simple
[
[0, 32, 1],
[2.3, 15.3, 0.5],
[10.9, -13, -0.3],
[-100, 32 * 10, 1],
[0, 32000, 1],
[2300.3, 15300.3, 0.5392],
[10900.9, -13000, -0.3111],
[-10000, 32 * 10000, 1],
],
)
@pytest.mark.parametrize(
"optional_output",
(
True,
False,
),
)
def test_arange_row_major_optional_output(start_end_step, optional_output, device):
# Prepare and compute by torch
start, end, step = start_end_step
any_cpu = torch.ones((1024))
any = ttnn.Tensor(any_cpu, ttnn.bfloat16).to(device)
untilize_out = True
tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)

# Compute by ttnn
if optional_output:
output_cpu = torch.empty_like(tt_cpu)
output = ttnn.from_torch(output_cpu, dtype=ttnn.bfloat16, device=device)
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, output_tensor=output, untilize_out=untilize_out)
else:
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)

tt_dev = tt_npu.cpu().to_torch()

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
logger.info(out)
assert passing


@pytest.mark.parametrize(
"start_end_step",
((-10, 22, 1),), # simple
[True, False],
)
@pytest.mark.parametrize(
"output_dtype",
(
torch.bfloat16,
torch.int32,
torch.float32,
),
ids=["bfloat16", "int32", "float32"],
"dtype",
[None, "bfloat16", "int32", "float32"],
)
def test_arange_row_major_dtype(start_end_step, output_dtype, device):
# Prepare and compute by torch
start, end, step = start_end_step
tt_dtype = get_tt_dtype(output_dtype)
tt_cpu = torch.arange(start=start, end=end, step=step).to(output_dtype)
any_cpu = torch.ones((1024))
any = create_tt_tensor(any_cpu, tt_dtype, device)
untilize_out = True

# Compute by ttnn
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out, output_dtype=tt_dtype)
tt_dev = tt_npu.cpu().to_torch()

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
logger.info(out)
assert passing


@pytest.mark.parametrize(
"start_end_step",
(
(0, 32, 1), # simple
(2.3, 15.7, 0.5), # floating point
(10, 0, -0.3), # minus step
(10, 32 * 3, 1), # multiple cores
),
"tilized",
[True, False],
)
def test_arange_tilized_simple(start_end_step, device):
# Prepare and compute by torch
start, end, step = start_end_step
tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
any_cpu = torch.ones((1024))
any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)

# Compute by ttnn
tt_npu = ttnn.operations.moreh.arange(start, end, step, any)
L = tt_cpu.shape[0]
tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(torch.bfloat16)

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
logger.info(out)
assert passing
def test_arange(start_end_step, optional_output, dtype, tilized, device):
"""Test arange functionality with different parameters."""
run_moreh_arange(start_end_step, optional_output, dtype, tilized, device)


@pytest.mark.parametrize(
"start_end_step",
((0, 32 * 10, 1),), # simple
[
[0, 32, 1],
[2.3, 15.3, 0.5],
[10.9, -13, -0.3],
[-100, 32 * 10, 1],
[0, 32000, 1],
[2300.3, 15300.3, 0.5392],
[10900.9, -13000, -0.3111],
[-10000, 32 * 10000, 1],
],
)
@pytest.mark.parametrize(
"optional_output",
(
True,
False,
),
)
def test_arange_tilized_major_optional_output(start_end_step, optional_output, device):
# Prepare and compute by torch
start, end, step = start_end_step
tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
L = tt_cpu.shape[0]
any_cpu = torch.ones((1024))
any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)
untilize_out = False

# Compute by ttnn
if optional_output:
output_cpu = torch.empty_like(tt_cpu)
output = (
ttnn.from_torch(output_cpu, ttnn.bfloat16)
.reshape([1, L])
.pad_to_tile(float("nan"))
.to(ttnn.TILE_LAYOUT)
.to(device)
)
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, output_tensor=output, untilize_out=untilize_out)
else:
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)
tt_dev = tt_npu.cpu().to_torch()
tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(torch.bfloat16)

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
logger.info(out)
assert passing


@pytest.mark.parametrize(
"start_end_step",
((-10, 57, 1),), # simple
[True, False],
)
@pytest.mark.parametrize(
"output_dtype",
(
torch.bfloat16,
torch.int32,
torch.float32,
),
ids=["bfloat16", "int32", "float32"],
"dtype",
[None, "bfloat16", "int32", "float32"],
)
def test_arange_tilized_dtype(start_end_step, output_dtype, device):
# Prepare and compute by torch
start, end, step = start_end_step
tt_dtype = get_tt_dtype(output_dtype)
tt_cpu = torch.arange(start=start, end=end, step=step).to(output_dtype)
any_cpu = torch.ones((1024))
any = ttnn.Tensor(any_cpu, tt_dtype).to(device)
untilize_out = False

# Compute by ttnn
tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out, output_dtype=tt_dtype)
tt_dev = tt_npu.cpu().to_torch()
L = tt_cpu.shape[0]
tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(output_dtype)

# Compare
assert tt_dev.shape == tt_cpu.shape
rtol = atol = 0.1
passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
logger.info(out)
assert passing
def test_arange_callback(start_end_step, optional_output, dtype, device, use_program_cache):
"""Test arange functionality with callback and program cache validation."""
num_program_cache_entries_list = []
for i in range(4):
if i % 2 == 0:
run_moreh_arange(start_end_step, optional_output, dtype, True, device)
else:
run_moreh_arange(start_end_step, optional_output, dtype, False, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_npu(torch_dummy, device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list == [1, 2, 2, 2]
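Every case above funnels through run_moreh_arange, which builds the inputs and calls ttnn.operations.moreh.arange. For quick reference, a condensed sketch of a direct call is given below; it reuses only the argument names and keywords that appear in the tests (the reference tensor passed as "any", untilize_out, and dtype), and anything beyond that is an assumption rather than documented behavior.

import torch
import ttnn

def arange_on_device(device, start=0.0, end=32.0, step=1.0):
    # A reference tensor on the device (called "any" in the tests) accompanies every call.
    any_npu = ttnn.from_torch(torch.ones(1024), dtype=ttnn.bfloat16, device=device)
    # untilize_out=True requests row-major output, mirroring the non-tilized test path.
    out_npu = ttnn.operations.moreh.arange(
        start, end, step, any_npu, untilize_out=True, dtype=ttnn.bfloat16
    )
    return out_npu.cpu().to_torch()

# Hypothetical usage with an already-opened device handle:
# values = arange_on_device(device, -5, 27, 1)

The tilized path in run_moreh_arange additionally reshapes the optional output to [1, L], pads it to a full tile with NaN, and unpads it after the call; the sketch above covers only the row-major path.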
@@ -9,7 +9,6 @@
#define TILE_WIDTH 32

void kernel_main() {
// Kernel args
uint32_t dst_addr = get_arg_val<uint32_t>(0);
uint32_t tile_offset = get_arg_val<uint32_t>(1);
uint32_t num_tiles = get_arg_val<uint32_t>(2);
@@ -39,7 +38,7 @@ void kernel_main() {

uint32_t w_addr = get_write_ptr(cb_out);

#ifdef OUTPUT_DTYPE_BFLOAT16
#ifdef OUTPUT_DTYPE_BFLOAT16
auto ptr = reinterpret_cast<uint16_t *>(w_addr);
for (uint32_t w = 0; w < 16; w++) {
int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -53,8 +52,8 @@ void kernel_main() {
val.f = start_u.f + step_u.f * idx;
ptr[w + 256] = uint16_t(val.u >> 16);
}
#endif
#ifdef OUTPUT_DTYPE_INT32
#endif
#ifdef OUTPUT_DTYPE_INT32
auto ptr = reinterpret_cast<uint32_t *>(w_addr);
for (uint32_t w = 0; w < 16; w++) {
int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -68,8 +67,8 @@ void kernel_main() {
val = start_u.f + step_u.f * idx;
ptr[w + 256] = val;
}
#endif
#ifdef OUTPUT_DTYPE_FLOAT32
#endif
#ifdef OUTPUT_DTYPE_FLOAT32
auto ptr = reinterpret_cast<uint32_t *>(w_addr);
for (uint32_t w = 0; w < 16; w++) {
int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -83,7 +82,7 @@ void kernel_main() {
val.f = start_u.f + step_u.f * idx;
ptr[w + 256] = val.u;
}
#endif
#endif

uint64_t dst_noc_addr = get_noc_addr(tile_idx, s0);
noc_async_write(w_addr, dst_noc_addr, num_bytes_per_tile);
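In the bfloat16 branch above, each generated value is written out as uint16_t(val.u >> 16), i.e. the upper 16 bits of the float32 bit pattern, with no rounding. A minimal host-side Python sketch of that same truncation, shown purely for illustration (the kernel performs it in C++ on device):

import struct

def float_to_bfloat16_bits(value: float) -> int:
    # Reinterpret the float32 bit pattern and keep its upper 16 bits,
    # mirroring the kernel's truncating (non-rounding) bfloat16 store.
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    return bits >> 16

assert float_to_bfloat16_bits(1.0) == 0x3F80  # bfloat16 encoding of 1.0

Truncation can differ from round-to-nearest by up to one bfloat16 unit in the last place, which is consistent with the loose rtol = atol = 0.1 tolerance used in the tests.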
@@ -9,7 +9,6 @@
#define TILE_WIDTH 32

void kernel_main() {
// Kernel args
uint32_t dst_addr = get_arg_val<uint32_t>(0);
uint32_t tile_offset = get_arg_val<uint32_t>(1);
uint32_t num_tiles = get_arg_val<uint32_t>(2);