From e53f39d1e174d14326a0439e5a5b2e81bd8df5ac Mon Sep 17 00:00:00 2001
From: Shaw Nguyen
Date: Fri, 27 Sep 2024 08:27:27 +0000
Subject: [PATCH] #13137: Revise moreh_arange operation

---
 .../operations/test_moreh_arange.py           | 296 ++++++------------
 .../device/kernels/writer_moreh_arange.cpp    |  13 +-
 .../device/kernels/writer_moreh_arange_rm.cpp |   1 -
 .../device/moreh_arange_device_operation.cpp  |  51 +--
 .../device/moreh_arange_device_operation.hpp  |  16 +-
 .../device/moreh_arange_program_factory.cpp   |  51 ++-
 .../moreh/moreh_arange/moreh_arange.cpp       |  10 +-
 .../moreh/moreh_arange/moreh_arange.hpp       |   6 +-
 .../moreh_arange/moreh_arange_pybind.cpp      |  12 +-
 9 files changed, 181 insertions(+), 275 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/test_moreh_arange.py b/tests/ttnn/unit_tests/operations/test_moreh_arange.py
index d86d656ac36..80a70ab9df0 100644
--- a/tests/ttnn/unit_tests/operations/test_moreh_arange.py
+++ b/tests/ttnn/unit_tests/operations/test_moreh_arange.py
@@ -8,226 +8,136 @@
 import ttnn
 from models.utility_functions import comp_allclose_and_pcc
+from tests.ttnn.unit_tests.operations.test_utils import to_npu
 
 
-def get_tt_dtype(torch_dtype):
-    if torch_dtype == torch.int32:
-        return ttnn.int32
-    if torch_dtype == torch.bfloat16:
-        return ttnn.bfloat16
-    if torch_dtype == torch.float32:
-        return ttnn.float32
-    return None
+def get_lib_dtype(lib, dtype):
+    """Map a dtype name to the corresponding dtype of the given library."""
+    dtype_map = {
+        "int32": lib.int32,
+        "bfloat16": lib.bfloat16,
+        "float32": lib.float32,
+    }
+    return dtype_map.get(dtype, None)
 
 
-def create_tt_tensor(tensor: torch.Tensor, dtype, device):
-    return ttnn.from_torch(tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
-
-
-@pytest.mark.parametrize(
-    "start_end_step",
-    (
-        (-5, 27, 1),  # simple
-        (2.3, 15.3, 0.5),  # floating point
-        (10, 0, -0.3),  # minus step
-        (10, 32 * 3, 1),  # multiple cores
-    ),
-)
-def test_arange_row_major_simple(start_end_step, device):
-    # Prepare and compute by torch
+def run_moreh_arange(start_end_step, optional_output, dtype, tilized, device):
+    """Compare ttnn arange results against torch.arange."""
+    # Prepare inputs
     start, end, step = start_end_step
-    any_cpu = torch.ones((1024))
-    any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)
-    untilize_out = True
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
-
-    # Compute by ttnn
-    tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)
-    tt_dev = tt_npu.cpu().to_torch()
+    if (dtype == "int32") and (start != int(start) or end != int(end) or step != int(step)):
+        pytest.skip("start, end, and step must be integers when using the int32 dtype")
+
+    # Compute using torch
+    torch_dtype = get_lib_dtype(torch, dtype)
+    if torch_dtype is None:
+        torch_dtype = torch.bfloat16
+    expected_output = torch.arange(start=start, end=end, step=step).to(torch_dtype)
+
+    # Compute using ttnn
+    ttnn_dtype = get_lib_dtype(ttnn, dtype)
+    if ttnn_dtype is None:
+        ttnn_dtype = ttnn.bfloat16
+    any_cpu = torch.ones(1024)
+    any_npu = ttnn.from_torch(any_cpu, dtype=ttnn_dtype, device=device)
+    if tilized:
+        L = expected_output.shape[0]
+    if optional_output:
+        output_cpu = torch.empty_like(expected_output)
+        if tilized:
+            output_npu = (
+                ttnn.from_torch(output_cpu, ttnn_dtype)
+                .reshape([1, L])
+                .pad_to_tile(float("nan"))
+                .to(ttnn.TILE_LAYOUT)
+                .to(device)
+            )
+        else:
+            output_npu = ttnn.from_torch(output_cpu, dtype=ttnn_dtype, device=device)
+    else:
+        output_npu = None
+    output_npu = ttnn.operations.moreh.arange(
+        start,
+        end,
+        step,
+        any_npu,
+        output=output_npu,
+        untilize_out=not tilized,
+        dtype=get_lib_dtype(ttnn, dtype),
+    )
+    if tilized:
+        actual_output = output_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L))
+    else:
+        actual_output = output_npu.cpu().to_torch()
 
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
+    # Assert shape and value comparison
+    assert actual_output.shape == expected_output.shape
+    passing, out = comp_allclose_and_pcc(expected_output, actual_output, rtol=0.1, atol=0.1)
     logger.info(out)
     assert passing
 
 
 @pytest.mark.parametrize(
     "start_end_step",
-    ((0, 32 * 10, 1),),  # simple
+    [
+        [0, 32, 1],
+        [2.3, 15.3, 0.5],
+        [10.9, -13, -0.3],
+        [-100, 32 * 10, 1],
+        [0, 32000, 1],
+        [2300.3, 15300.3, 0.5392],
+        [10900.9, -13000, -0.3111],
+        [-10000, 32 * 10000, 1],
+    ],
 )
 @pytest.mark.parametrize(
     "optional_output",
-    (
-        True,
-        False,
-    ),
-)
-def test_arange_row_major_optional_output(start_end_step, optional_output, device):
-    # Prepare and compute by torch
-    start, end, step = start_end_step
-    any_cpu = torch.ones((1024))
-    any = ttnn.Tensor(any_cpu, ttnn.bfloat16).to(device)
-    untilize_out = True
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
-
-    # Compute by ttnn
-    if optional_output:
-        output_cpu = torch.empty_like(tt_cpu)
-        output = ttnn.from_torch(output_cpu, dtype=ttnn.bfloat16, device=device)
-        tt_npu = ttnn.operations.moreh.arange(start, end, step, any, output_tensor=output, untilize_out=untilize_out)
-    else:
-        tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)
-
-    tt_dev = tt_npu.cpu().to_torch()
-
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
-    logger.info(out)
-    assert passing
-
-
-@pytest.mark.parametrize(
-    "start_end_step",
-    ((-10, 22, 1),),  # simple
+    [True, False],
 )
 @pytest.mark.parametrize(
-    "output_dtype",
-    (
-        torch.bfloat16,
-        torch.int32,
-        torch.float32,
-    ),
-    ids=["bfloat16", "int32", "float32"],
+    "dtype",
+    [None, "bfloat16", "int32", "float32"],
 )
-def test_arange_row_major_dtype(start_end_step, output_dtype, device):
-    # Prepare and compute by torch
-    start, end, step = start_end_step
-    tt_dtype = get_tt_dtype(output_dtype)
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(output_dtype)
-    any_cpu = torch.ones((1024))
-    any = create_tt_tensor(any_cpu, tt_dtype, device)
-    untilize_out = True
-
-    # Compute by ttnn
-    tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out, output_dtype=tt_dtype)
-    tt_dev = tt_npu.cpu().to_torch()
-
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
-    logger.info(out)
-    assert passing
-
-
 @pytest.mark.parametrize(
-    "start_end_step",
-    (
-        (0, 32, 1),  # simple
-        (2.3, 15.7, 0.5),  # floating point
-        (10, 0, -0.3),  # minus step
-        (10, 32 * 3, 1),  # multiple cores
-    ),
+    "tilized",
+    [True, False],
 )
-def test_arange_tilized_simple(start_end_step, device):
-    # Prepare and compute by torch
-    start, end, step = start_end_step
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
-    any_cpu = torch.ones((1024))
-    any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)
-
-    # Compute by ttnn
-    tt_npu = ttnn.operations.moreh.arange(start, end, step, any)
-    L = tt_cpu.shape[0]
-    tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(torch.bfloat16)
-
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
-    logger.info(out)
-    assert passing
+def test_arange(start_end_step, optional_output, dtype, tilized, device):
+    """Test arange functionality with different parameters."""
+    run_moreh_arange(start_end_step, optional_output, dtype, tilized, device)
 
 
 @pytest.mark.parametrize(
     "start_end_step",
-    ((0, 32 * 10, 1),),  # simple
+    [
+        [0, 32, 1],
+        [2.3, 15.3, 0.5],
+        [10.9, -13, -0.3],
+        [-100, 32 * 10, 1],
+        [0, 32000, 1],
+        [2300.3, 15300.3, 0.5392],
+        [10900.9, -13000, -0.3111],
+        [-10000, 32 * 10000, 1],
+    ],
 )
 @pytest.mark.parametrize(
     "optional_output",
-    (
-        True,
-        False,
-    ),
-)
-def test_arange_tilized_major_optional_output(start_end_step, optional_output, device):
-    # Prepare and compute by torch
-    start, end, step = start_end_step
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(torch.bfloat16)
-    L = tt_cpu.shape[0]
-    any_cpu = torch.ones((1024))
-    any = create_tt_tensor(any_cpu, ttnn.bfloat16, device)
-    untilize_out = False
-
-    # Compute by ttnn
-    if optional_output:
-        output_cpu = torch.empty_like(tt_cpu)
-        output = (
-            ttnn.from_torch(output_cpu, ttnn.bfloat16)
-            .reshape([1, L])
-            .pad_to_tile(float("nan"))
-            .to(ttnn.TILE_LAYOUT)
-            .to(device)
-        )
-        tt_npu = ttnn.operations.moreh.arange(start, end, step, any, output_tensor=output, untilize_out=untilize_out)
-    else:
-        tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out)
-    tt_dev = tt_npu.cpu().to_torch()
-    tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(torch.bfloat16)
-
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
-    logger.info(out)
-    assert passing
-
-
-@pytest.mark.parametrize(
-    "start_end_step",
-    ((-10, 57, 1),),  # simple
+    [True, False],
 )
 @pytest.mark.parametrize(
-    "output_dtype",
-    (
-        torch.bfloat16,
-        torch.int32,
-        torch.float32,
-    ),
-    ids=["bfloat16", "int32", "float32"],
+    "dtype",
+    [None, "bfloat16", "int32", "float32"],
 )
-def test_arange_tilized_dtype(start_end_step, output_dtype, device):
-    # Prepare and compute by torch
-    start, end, step = start_end_step
-    tt_dtype = get_tt_dtype(output_dtype)
-    tt_cpu = torch.arange(start=start, end=end, step=step).to(output_dtype)
-    any_cpu = torch.ones((1024))
-    any = ttnn.Tensor(any_cpu, tt_dtype).to(device)
-    untilize_out = False
-
-    # Compute by ttnn
-    tt_npu = ttnn.operations.moreh.arange(start, end, step, any, untilize_out=untilize_out, output_dtype=tt_dtype)
-    tt_dev = tt_npu.cpu().to_torch()
-    L = tt_cpu.shape[0]
-    tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile((1, L)).to_torch().reshape((L)).to(output_dtype)
-
-    # Compare
-    assert tt_dev.shape == tt_cpu.shape
-    rtol = atol = 0.1
-    passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol)
-    logger.info(out)
-    assert passing
+def test_arange_callback(start_end_step, optional_output, dtype, device, use_program_cache):
+    """Test arange functionality with callback and program cache validation."""
+    num_program_cache_entries_list = []
+    for i in range(4):
+        if i % 2 == 0:
+            run_moreh_arange(start_end_step, optional_output, dtype, True, device)
+        else:
+            run_moreh_arange(start_end_step, optional_output, dtype, False, device)
+        torch_dummy = torch.randn([32, 32])
+        tt_dummy = to_npu(torch_dummy, device)
+        num_program_cache_entries_list.append(device.num_program_cache_entries())
+    logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
+    assert num_program_cache_entries_list == [1, 2, 2, 2]
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange.cpp
index 640845c7605..3b4118ccccb 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange.cpp
@@ -9,7 +9,6 @@
 #define TILE_WIDTH 32
 
 void kernel_main() {
-    // Kernel args
     uint32_t dst_addr = get_arg_val<uint32_t>(0);
     uint32_t tile_offset = get_arg_val<uint32_t>(1);
     uint32_t num_tiles = get_arg_val<uint32_t>(2);
@@ -39,7 +38,7 @@ void kernel_main() {
 
         uint32_t w_addr = get_write_ptr(cb_out);
 
-        #ifdef OUTPUT_DTYPE_BFLOAT16
+#ifdef OUTPUT_DTYPE_BFLOAT16
         auto ptr = reinterpret_cast<uint16_t*>(w_addr);
         for (uint32_t w = 0; w < 16; w++) {
             int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -53,8 +52,8 @@ void kernel_main() {
             val.f = start_u.f + step_u.f * idx;
             ptr[w + 256] = uint16_t(val.u >> 16);
         }
-        #endif
-        #ifdef OUTPUT_DTYPE_INT32
+#endif
+#ifdef OUTPUT_DTYPE_INT32
         auto ptr = reinterpret_cast<int32_t*>(w_addr);
         for (uint32_t w = 0; w < 16; w++) {
             int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -68,8 +67,8 @@ void kernel_main() {
             val = start_u.f + step_u.f * idx;
             ptr[w + 256] = val;
         }
-        #endif
-        #ifdef OUTPUT_DTYPE_FLOAT32
+#endif
+#ifdef OUTPUT_DTYPE_FLOAT32
         auto ptr = reinterpret_cast<uint32_t*>(w_addr);
         for (uint32_t w = 0; w < 16; w++) {
             int32_t idx = w + tile_idx * TILE_WIDTH;
@@ -83,7 +82,7 @@ void kernel_main() {
             val.f = start_u.f + step_u.f * idx;
             ptr[w + 256] = val.u;
         }
-        #endif
+#endif
 
         uint64_t dst_noc_addr = get_noc_addr(tile_idx, s0);
         noc_async_write(w_addr, dst_noc_addr, num_bytes_per_tile);
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange_rm.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange_rm.cpp
index 1857aaf75d7..16756f34102 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange_rm.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/kernels/writer_moreh_arange_rm.cpp
@@ -9,7 +9,6 @@
 #define TILE_WIDTH 32
 
 void kernel_main() {
-    // Kernel args
     uint32_t dst_addr = get_arg_val<uint32_t>(0);
     uint32_t tile_offset = get_arg_val<uint32_t>(1);
     uint32_t num_tiles = get_arg_val<uint32_t>(2);
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp
index e1a72b0d986..fd29647e5b1 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.cpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "moreh_arange_device_operation.hpp"
+
 #include "tt_dnn/op_library/moreh_helper_functions.hpp"
 #include "ttnn/tensor/tensor.hpp"
 
@@ -17,24 +18,24 @@ void MorehArangeOperation::validate_inputs(
         ((step > 0) && (end >= start)) || ((step < 0) && (end <= start)),
         "Upper bound and larger bound inconsistent with step sign.");
 
-    const auto& output_dtype = operation_attributes.output_dtype;
-    TT_FATAL(output_dtype != DataType::BFLOAT8_B, "moreh arange not support bfloat8_b dtype.");
-    TT_FATAL(output_dtype != DataType::UINT32, "moreh arange not support uint32 dtype.");
+    const auto& dtype = operation_attributes.dtype;
+    TT_FATAL(dtype != DataType::BFLOAT8_B, "moreh_arange does not support the bfloat8_b dtype.");
+    TT_FATAL(dtype != DataType::UINT32, "moreh_arange does not support the uint32 dtype.");
 
-    const auto& output_tensor = tensor_args.output_tensor;
-    if (!output_tensor.has_value())
+    const auto& output = tensor_args.output;
+    if (!output.has_value())
         return;
-    TT_FATAL(output_tensor->buffer() != nullptr, "Must have 1 output tensor.");
+    TT_FATAL(output->buffer() != nullptr, "Must have 1 output tensor.");
     TT_FATAL(
-        output_dtype == output_tensor->get_dtype(),
-        "If output_tensor is provided as input, its dtype should match the output_dtype parameter.");
-    TT_FATAL(output_tensor->memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
+        dtype == output->get_dtype(), "If output is provided as input, its dtype should match the dtype parameter.");
+    TT_FATAL(output->memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error: output must be interleaved");
 
-    auto output_layout = output_tensor->get_layout();
+    auto output_layout = output->get_layout();
     if (operation_attributes.untilize_out) {
-        TT_FATAL(output_layout == Layout::ROW_MAJOR, "Error");
+        TT_FATAL(
+            output_layout == Layout::ROW_MAJOR, "Error: output_layout must be Layout::ROW_MAJOR when untilize_out");
     } else {
-        TT_FATAL(output_layout == Layout::TILE, "Error");
+        TT_FATAL(output_layout == Layout::TILE, "Error: output_layout must be Layout::TILE when !untilize_out");
     }
 }
 
@@ -61,7 +62,7 @@ MorehArangeOperation::shape_return_value_t MorehArangeOperation::compute_output_
     if (operation_attributes.untilize_out)
         return ttnn::Shape(tt::tt_metal::LegacyShape({num_elems}));
 
-    std::vector<uint32_t> output_size_vec = {
+    std::vector<uint32_t> output_size = {
         tt::constants::TILE_HEIGHT, tt::round_up(num_elems, tt::constants::TILE_WIDTH)};
 
     auto dimensions_pads = std::vector<Padding::PadDimension>();
@@ -70,20 +71,20 @@ MorehArangeOperation::shape_return_value_t MorehArangeOperation::compute_output_
         Padding::PadDimension{.front = 0, .back = tt::round_up(num_elems, tt::constants::TILE_WIDTH) - num_elems});
 
     const auto padding = Padding(dimensions_pads, Padding::PadValue::Any);
-    return ttnn::Shape{tt::tt_metal::LegacyShape(output_size_vec, padding)};
+    return ttnn::Shape{tt::tt_metal::LegacyShape(output_size, padding)};
 };
 
 MorehArangeOperation::tensor_return_value_t MorehArangeOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    const auto& output_tensor = tensor_args.output_tensor;
-    if (output_tensor.has_value())
-        return {output_tensor.value()};
+    const auto& output = tensor_args.output;
+    if (output.has_value())
+        return output.value();
 
     return create_device_tensor(
         compute_output_shapes(operation_attributes, tensor_args),
-        operation_attributes.output_dtype,
+        operation_attributes.dtype,
         operation_attributes.untilize_out ? Layout::ROW_MAJOR : Layout::TILE,
         tensor_args.any.device(),
-        operation_attributes.output_memory_config);
+        operation_attributes.memory_config);
 }
 
 std::tuple<MorehArangeOperation::operation_attributes_t, MorehArangeOperation::tensor_args_t>
@@ -92,22 +93,22 @@ MorehArangeOperation::invoke(
     float end,
     float step,
     const Tensor& any,
-    const std::optional<Tensor>& output_tensor,
+    const std::optional<Tensor>& output,
     bool untilize_out,
-    const std::optional<DataType>& output_dtype,
-    const std::optional<MemoryConfig>& output_memory_config) {
+    const std::optional<DataType>& dtype,
+    const std::optional<MemoryConfig>& memory_config) {
     return {
         operation_attributes_t{
             start,
             end,
             step,
             untilize_out,
-            output_dtype.value_or(any.get_dtype()),
-            output_memory_config.value_or(any.memory_config()),
+            dtype.value_or(any.get_dtype()),
+            memory_config.value_or(any.memory_config()),
         },
         tensor_args_t{
             any,
-            output_tensor,
+            output,
         },
     };
 }
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.hpp
index 5ba687d6f53..23540113725 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_device_operation.hpp
@@ -12,13 +12,13 @@ struct MorehArangeOperation {
         float end;
         float step;
         bool untilize_out;
-        const DataType output_dtype;
-        const MemoryConfig output_memory_config;
+        const DataType dtype;
+        const MemoryConfig memory_config;
     };
 
     struct tensor_args_t {
         const Tensor& any;
-        const std::optional<Tensor>& output_tensor;
+        const std::optional<Tensor>& output;
     };
 
     using shape_return_value_t = Shape;
@@ -36,13 +36,13 @@ struct MorehArangeOperation {
         static cached_program_t create(
             const operation_attributes_t& operation_attributes,
             const tensor_args_t& tensor_args,
-            tensor_return_value_t& output_tensor);
+            tensor_return_value_t& output);
 
         static void override_runtime_arguments(
             cached_program_t& cached_program,
             const operation_attributes_t& operation_attributes,
             const tensor_args_t& tensor_args,
-            tensor_return_value_t& output_tensor);
+            tensor_return_value_t& output);
     };
 
     using program_factory_t = std::variant<ProgramFactory>;
@@ -59,10 +59,10 @@ struct MorehArangeOperation {
         float end,
         float step,
         const Tensor& any,
-        const std::optional<Tensor>& output_tensor,
+        const std::optional<Tensor>& output,
         bool untilize_out,
-        const std::optional<DataType>& output_dtype,
-        const std::optional<MemoryConfig>& output_memory_config);
+        const std::optional<DataType>& dtype,
+        const std::optional<MemoryConfig>& memory_config);
 };
 
 } // namespace ttnn::operations::moreh::moreh_arange
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp
index c233379c9de..71a9347fd8a 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/device/moreh_arange_program_factory.cpp
@@ -10,8 +10,9 @@ namespace ttnn::operations::moreh::moreh_arange {
 MorehArangeOperation::ProgramFactory::cached_program_t MorehArangeOperation::ProgramFactory::create(
     const operation_attributes_t& operation_attributes,
     const tensor_args_t& tensor_args,
-    tensor_return_value_t& output_tensor) {
-    auto W = output_tensor.get_legacy_shape()[-1];
+    tensor_return_value_t& output) {
+    auto dtype = output.get_dtype();
+    auto W = output.get_legacy_shape()[-1];
     auto Wt = tt::div_up(W, tt::constants::TILE_WIDTH);
 
     auto start = operation_attributes.start;
@@ -25,27 +26,24 @@ MorehArangeOperation::ProgramFactory::cached_program_t MorehArangeOperation::Pro
     Program program = Program();
 
     // Create circular buffer
-    auto data_format = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype());
     tt::operations::primary::CreateCircularBuffer(
         program,
         all_cores,
-        data_format,
+        tt::tt_metal::datatype_to_dataformat_converter(dtype),
         {
             {tt::CB::c_out0, 1},
         });
 
     // Create write kernel
     std::map<std::string, std::string> writer_defines;
-    if (output_tensor.get_dtype() == DataType::BFLOAT16) {
-        writer_defines["OUTPUT_DTYPE_BFLOAT16"] = 1;
+    switch (dtype) {
+        case DataType::BFLOAT16: writer_defines["OUTPUT_DTYPE_BFLOAT16"] = "1"; break;
+        case DataType::INT32: writer_defines["OUTPUT_DTYPE_INT32"] = "1"; break;
+        case DataType::FLOAT32: writer_defines["OUTPUT_DTYPE_FLOAT32"] = "1"; break;
+        default: break;
     }
-    if (output_tensor.get_dtype() == DataType::INT32) {
-        writer_defines["OUTPUT_DTYPE_INT32"] = 1;
-    }
-    if (output_tensor.get_dtype() == DataType::FLOAT32) {
-        writer_defines["OUTPUT_DTYPE_FLOAT32"] = 1;
-    }
-    bool dst_is_dram = output_tensor.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
+
+    uint32_t dst_is_dram = output.buffer()->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0;
     auto kernel_id = tt::operations::primary::CreateWriteKernel(
         program,
         operation_attributes.untilize_out
@@ -55,25 +53,24 @@ MorehArangeOperation::ProgramFactory::cached_program_t MorehArangeOperation::Pro
         {dst_is_dram},
         writer_defines);
 
-    // Set RuntimeArgs
+    // Set runtime arguments
     uint32_t core_h = grid.y;
     for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) {
         CoreCoord core = {i / core_h, i % core_h};
         uint32_t num_tiles_per_core;
-        if (core_group_1.core_coord_in_core_ranges(core)) {
+        if (core_group_1.core_coord_in_core_ranges(core))
            num_tiles_per_core = num_tiles_per_core_group_1;
-        } else if (core_group_2.core_coord_in_core_ranges(core)) {
+        else if (core_group_2.core_coord_in_core_ranges(core))
             num_tiles_per_core = num_tiles_per_core_group_2;
-        } else {
+        else
             TT_FATAL(false, "Core not in specified core ranges");
-        }
-        vector<uint32_t> writer_args = {
-            output_tensor.buffer()->address(),
+        std::vector<uint32_t> writer_args = {
+            output.buffer()->address(),
             tile_offset,
             num_tiles_per_core,
             *reinterpret_cast<uint32_t*>(&start),
             *reinterpret_cast<uint32_t*>(&step),
-            output_tensor.element_size()};
+            output.element_size()};
         SetRuntimeArgs(program, kernel_id, core, writer_args);
         tile_offset += num_tiles_per_core;
     }
@@ -84,17 +81,17 @@ void MorehArangeOperation::ProgramFactory::override_runtime_arguments(
     cached_program_t& cached_program,
     const operation_attributes_t& operation_attributes,
     const tensor_args_t& tensor_args,
-    tensor_return_value_t& output_tensor) {
-    auto& program = cached_program.program;
-    auto& kernel_id = cached_program.shared_variables.kernel_id;
+    tensor_return_value_t& output) {
+    const auto& program = cached_program.program;
+    const auto& kernel_id = cached_program.shared_variables.kernel_id;
     auto num_cores = cached_program.shared_variables.num_cores;
     auto core_h = cached_program.shared_variables.core_h;
-    auto src_dram_buffer = output_tensor.buffer();
+    auto src_dram_buffer_address = output.buffer()->address();
 
-    for (uint32_t icore = 0; icore < num_cores; icore++) {
+    for (uint32_t icore = 0; icore < num_cores; ++icore) {
         CoreCoord core = {icore / core_h, icore % core_h};
         auto& runtime_args = GetRuntimeArgs(program, kernel_id, core);
-        runtime_args[0] = src_dram_buffer->address();
+        runtime_args[0] = src_dram_buffer_address;
     }
 }
 } // namespace ttnn::operations::moreh::moreh_arange
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.cpp
index d0e71e35680..0d0d24944b6 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.cpp
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0
 
 #include "moreh_arange.hpp"
+
 #include "device/moreh_arange_device_operation.hpp"
 
 namespace ttnn::operations::moreh::moreh_arange {
@@ -11,11 +12,10 @@ Tensor MorehArange::invoke(
     float end,
     float step,
     const Tensor& any,
-    const std::optional<Tensor>& output_tensor,
+    const std::optional<Tensor>& output,
     bool untilize_out,
-    const std::optional<DataType>& output_dtype,
-    const std::optional<MemoryConfig>& output_memory_config) {
-    return ttnn::prim::moreh_arange(
-        start, end, step, any, output_tensor, untilize_out, output_dtype, output_memory_config);
+    const std::optional<DataType>& dtype,
+    const std::optional<MemoryConfig>& memory_config) {
+    return ttnn::prim::moreh_arange(start, end, step, any, output, untilize_out, dtype, memory_config);
 }
 } // namespace ttnn::operations::moreh::moreh_arange
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.hpp
index 2586378814a..6f1c37b0a83 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange.hpp
@@ -13,10 +13,10 @@ struct MorehArange {
         float end,
         float step,
         const Tensor& any,
-        const std::optional<Tensor>& output_tensor,
+        const std::optional<Tensor>& output,
         bool untilize_out,
-        const std::optional<DataType>& output_dtype,
-        const std::optional<MemoryConfig>& output_memory_config);
+        const std::optional<DataType>& dtype,
+        const std::optional<MemoryConfig>& memory_config);
 };
 
 } // namespace ttnn::operations::moreh::moreh_arange
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp
index 72d9c71ce20..66bd9f201f7 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_arange/moreh_arange_pybind.cpp
@@ -4,8 +4,8 @@
 
 #include "moreh_arange_pybind.hpp"
 
-#include "pybind11/decorators.hpp"
 #include "moreh_arange.hpp"
+#include "pybind11/decorators.hpp"
 
 namespace ttnn::operations::moreh::moreh_arange {
 void bind_moreh_arange_operation(py::module& module) {
@@ -14,15 +14,15 @@ void bind_moreh_arange_operation(py::module& module) {
         ttnn::moreh_arange,
         "Moreh Arange Operation",
         ttnn::pybind_arguments_t{
-            py::arg("start"),
+            py::arg("start") = 0,
             py::arg("end"),
-            py::arg("step"),
+            py::arg("step") = 1,
             py::arg("any"),
             py::kw_only(),
-            py::arg("output_tensor") = std::nullopt,
+            py::arg("output") = std::nullopt,
             py::arg("untilize_out") = false,
-            py::arg("output_dtype") = std::nullopt,
-            py::arg("output_memory_config") = std::nullopt,
+            py::arg("dtype") = std::nullopt,
+            py::arg("memory_config") = std::nullopt,
        });
 }
 } // namespace ttnn::operations::moreh::moreh_arange
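---

Usage sketch (illustration only, not part of the patch): the snippet below shows how the revised Python binding reads after the renames, mirroring the calls in the updated test. It assumes a working ttnn install and that a device with id 0 is available; error handling is omitted.

    # Minimal sketch of the revised moreh_arange API (assumption: device id 0 exists).
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # "any" only anchors the op to a device and supplies the default dtype/memory config.
    any_npu = ttnn.from_torch(torch.ones(1024), dtype=ttnn.bfloat16, device=device)

    # untilize_out=True yields a row-major result, so to_torch() needs no tile unpadding.
    out_npu = ttnn.operations.moreh.arange(0, 32, 1, any_npu, untilize_out=True)
    print(out_npu.cpu().to_torch())  # expect 0..31 in bfloat16

    ttnn.close_device(device)

Per the pybind block above, start and step now carry defaults of 0 and 1, and dtype/memory_config fall back to those of the "any" tensor when left unset.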