diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
index 41bb90bc1e7..343bfdc6574 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -255,3 +255,252 @@ def test_01_volume_tensors(device, a, b, c_golden, memory_config):
     assert c.tolist() == c_golden
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        # (torch.Size([5, 3, 32, 32]), torch.Size([5, 3, 32, 32])),
+        # (torch.Size([1, 3, 64, 64]), torch.Size([5, 3, 64, 64])),  # batch bcast
+        # (torch.Size([2, 1, 1, 64]), torch.Size([2, 1, 32, 1])),  # rowA colB bcast
+        # (torch.Size([2, 1, 1, 64]), torch.Size([2, 1, 128, 1])),  # rowA colB bcast
+        (torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 2, 2])),  # no bcast
+        # (torch.Size([5, 3, 32, 64]), torch.Size([5, 3, 32, 64])),
+        # (torch.Size([5, 3, 64, 32]), torch.Size([5, 3, 64, 32])),
+        # (torch.Size([5, 3, 1, 1]), torch.Size([5, 3, 1, 1])),  # (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
+        # (torch.Size([5, 3, 64, 32]), torch.Size([5, 3, 1, 32])),  # (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
+    ),
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.sub,
+    ],
+)
+def test_binary_ng(input_shapes, ttnn_fn, device):
+    a_shape, b_shape = input_shapes
+    # a_pt = torch.rand(a_shape).bfloat16()
+    # b_pt = torch.rand(b_shape).bfloat16()
+    a_pt = torch.ones(a_shape, dtype=torch.bfloat16) * 1
+    # b_pt = torch.ones(b_shape, dtype=torch.bfloat16) * 7
+    b_pt = 0.1111111
+
+    a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    # b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
+    b_tt = 0.1111111
+    cq_id = 0
+    out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    out_pt = golden_fn(a_pt, b_pt)
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print(ttnn.to_torch(out_tt))
+    print(out_pt)
+    print(a_pt - b_pt)
+    # comp_pass = compare_pcc([out_tt], [out_pt])
+    comp_pass = ttnn.pearson_correlation_coefficient(out_pt, out_tt)
+    assert comp_pass >= 0.99988
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [[2, 1, 2, 2], [2, 1, 2, 2]],
+        [[2, 1, 1, 1], [2, 1, 2, 2]],
+        [[1, 1, 4, 1], [1, 1, 1, 4]],
+        [[1, 1, 4, 4], [1, 1, 1, 4]],
+        [[1, 1, 4, 4], [1, 1, 1, 1]],
+        # [[5, 1, 64, 1], [1, 3, 1, 128]],
+    ],
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.sub,
+    ],
+)
+def test_binary_ng_fp32(input_shapes, ttnn_fn, device):
+    [a_shape, b_shape] = input_shapes
+    x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
+    y_torch = torch.ones(b_shape, dtype=torch.float32) * 0.00030171126
+    # y_torch = 0.00030171126
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    z_torch = torch.sub(torch.square(x_torch), -y_torch)
+    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    # y_tt = 0.00030171126
+    z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
+    # ttnn.set_printoptions(profile="full")
+    # print("tt ", z_tt_sub)
+    tt_out = ttnn.to_torch(z_tt_sub)
+
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print("torch", z_torch)
+    print("tt ", tt_out)
+
+    status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
+    assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [[2, 1, 2, 2], [2, 1, 2, 2]],
+        [[2, 1, 1, 1], [2, 1, 2, 2]],
+        [[1, 1, 4, 1], [1, 1, 1, 4]],
+        [[1, 1, 4, 4], [1, 1, 1, 4]],
+        [[1, 1, 4, 4], [1, 1, 1, 1]],
+        # [[5, 1, 64, 1], [1, 3, 1, 128]],
+    ],
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.sub,
+    ],
+)
+def test_binary_ng_fp32sc(input_shapes, ttnn_fn, device):
+    [a_shape, b_shape] = input_shapes
+    x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
+    y_torch = 0.00030171126
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    z_torch = torch.sub(torch.square(x_torch), -y_torch)
+    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    y_tt = 0.00030171126
+    z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
+    # ttnn.set_printoptions(profile="full")
+    # print("tt ", z_tt_sub)
+    tt_out = ttnn.to_torch(z_tt_sub)
+
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print("torch", z_torch)
+    print("tt ", tt_out)
+
+    status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
+    assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    ((torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 2, 2])),),
+    # ((torch.Size([2, 1, 1, 2]), torch.Size([2, 1, 2, 2])),),
+    # ((torch.Size([2, 1, 1, 32]), torch.Size([2, 1, 32, 1])),),
+    # ((torch.Size([2, 1, 2, 2]), torch.Size([2, 1, 1, 1])),),
+    # (torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.add,
+    ],
+)
+def test_binary_ng_int32(input_shapes, ttnn_fn, device):
+    a_shape, b_shape = input_shapes
+    x_torch = torch.ones(a_shape, dtype=torch.int32) * 2
+    y_torch = torch.ones(b_shape, dtype=torch.int32) * -10
+    # y_torch = -10
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    z_torch = golden_fn(x_torch, y_torch)
+    # z_torch = torch.add(torch.square(x_torch), y_torch)
+    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
+    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
+    # y_tt = -10
+    z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
+    z_tt_sub = ttnn_fn(
+        x_tt,
+        y_tt,
+    )  # lhs_activations=[ttnn.UnaryOpType.SQUARE]
+    tt_out = ttnn.to_torch(z_tt_sub)
+
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print("torch", z_torch)
+    print("tt ", tt_out)
+
+    status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
+    assert status
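A note on the tolerances used above: the fp32/int32 tests can assert `torch.allclose` with `atol=1e-10` only because the SFPU path keeps 32-bit operands intact, whereas the legacy FPU path rounds everything through bfloat16. The following standalone sketch (assuming C++20 for `std::bit_cast`; not part of the patch) shows how much of a scalar like 0.00030171126 survives a bf16 round-trip:

```cpp
// Minimal sketch: bfloat16 keeps only the top 16 bits of a float32, so a
// small scalar is visibly rounded on the bf16 FPU path, while the SFPU
// fp32/int32 paths are bit-exact (2 + (-10) == -8 with no rounding at all).
#include <bit>
#include <cstdint>
#include <cstdio>

float to_bfloat16(float v) {
    // round-to-nearest-even, then drop the low 16 bits
    uint32_t bits = std::bit_cast<uint32_t>(v);
    uint32_t rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000u;
    return std::bit_cast<float>(rounded);
}

int main() {
    float scalar = 0.00030171126f;
    std::printf("fp32: %.12f\n", scalar);
    std::printf("bf16: %.12f\n", to_bfloat16(scalar));
}
```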
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [[2, 1, 2, 2], [2, 1, 2, 2]],
+        # [[2, 1, 1, 1], [2, 1, 2, 2]],
+        # [[1, 1, 4, 1], [1, 1, 1, 4]],
+        # [[1, 1, 4, 4], [1, 1, 1, 4]],
+        # [[1, 1, 1, 1], [1, 1, 4, 4]],
+        # [[5, 1, 64, 1], [1, 3, 1, 128]],
+    ],
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.sub,
+    ],
+)
+def test_binary_ng_fp32_activ(input_shapes, ttnn_fn, device):
+    a_shape, b_shape = input_shapes
+    x_torch = torch.ones(a_shape, dtype=torch.float32) * 2
+    # y_torch = torch.ones(b_shape, dtype=torch.float32) * 0.00030171126
+    y_torch = 0.00030171122
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    # z_torch = golden_fn(x_torch, y_torch)
+    z_torch = torch.sub(torch.square(x_torch), -y_torch)
+    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    # y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    y_tt = 0.00030171122
+    z_tt = ttnn.from_torch(z_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    # z_tt_sub = ttnn_fn(x_tt, y_tt,)
+    z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
+    # ttnn.set_printoptions(profile="full")
+    # print("tt ", z_tt_sub)
+    tt_out = ttnn.to_torch(z_tt_sub)
+
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print("torch", z_torch)
+    print("tt ", tt_out)
+
+    status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
+    assert status
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [[2, 1, 2, 2], [2, 1, 2, 2]],
+        # [[2, 1, 1, 1], [2, 1, 2, 2]],
+        # [[1, 1, 4, 1], [1, 1, 1, 4]],
+        # [[1, 1, 4, 4], [1, 1, 1, 4]],
+        # [[1, 1, 1, 1], [1, 1, 4, 4]],
+        # [[5, 1, 64, 1], [1, 3, 1, 128]],
+    ],
+)
+@pytest.mark.parametrize(
+    "ttnn_fn",
+    [
+        ttnn.experimental.sub,
+    ],
+)
+def test_binary_ng_bf16_activ(input_shapes, ttnn_fn, device):
+    a_shape, b_shape = input_shapes
+    x_torch = torch.ones(a_shape, dtype=torch.bfloat16) * 2
+    # y_torch = torch.ones(b_shape, dtype=torch.bfloat16) * 0.00030171126
+    y_torch = 0.00030171122
+    golden_fn = ttnn.get_golden_function(ttnn_fn)
+    # z_torch = golden_fn(x_torch, y_torch)
+    z_torch = torch.sub(torch.square(x_torch), -y_torch)
+    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+    # y_tt = ttnn.from_torch(y_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+    y_tt = 0.00030171122
+    z_tt = ttnn.from_torch(z_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+    # z_tt_sub = ttnn_fn(x_tt, y_tt,)
+    z_tt_sub = ttnn_fn(x_tt, y_tt, lhs_activations=[ttnn.UnaryOpType.SQUARE], rhs_activations=[ttnn.UnaryOpType.NEG])
+    # ttnn.set_printoptions(profile="full")
+    # print("tt ", z_tt_sub)
+    tt_out = ttnn.to_torch(z_tt_sub)
+
+    torch.set_printoptions(linewidth=200, threshold=10000, precision=15, sci_mode=False, edgeitems=17)
+    print("torch", z_torch)
+    print("tt ", tt_out)
+
+    status = torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False)
+    assert status
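The bf16 test `test_binary_ng` above asserts a Pearson correlation coefficient of at least 0.99988 rather than elementwise closeness. Assuming `ttnn.pearson_correlation_coefficient` reduces to the classic PCC formula, this plain-C++ sketch shows the quantity being asserted:

```cpp
// Standalone sketch of the PCC comparison: correlation of golden vs device
// outputs, tolerant of uniform rounding but sensitive to structural errors.
#include <cmath>
#include <cstdio>
#include <vector>

double pcc(const std::vector<double>& a, const std::vector<double>& b) {
    double n = static_cast<double>(a.size());
    double sa = 0, sb = 0, saa = 0, sbb = 0, sab = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        sa += a[i];
        sb += b[i];
        saa += a[i] * a[i];
        sbb += b[i] * b[i];
        sab += a[i] * b[i];
    }
    double cov = sab - sa * sb / n;  // unnormalized covariance
    double var = std::sqrt((saa - sa * sa / n) * (sbb - sb * sb / n));
    return cov / var;
}

int main() {
    std::vector<double> golden{0.5, 1.0, 1.5, 2.0};
    std::vector<double> device{0.5004, 0.9996, 1.5003, 1.9998};
    std::printf("pcc = %.6f\n", pcc(golden, device));  // close to 1.0 when outputs match
}
```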
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
index 4d844cab987..99dadbfdfc7 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.cpp
@@ -19,8 +19,13 @@ Tensor BinaryNg<binary_op_type>::invoke(
     tt::stl::Span<const unary::UnaryOpType> lhs_activations,
     tt::stl::Span<const unary::UnaryOpType> rhs_activations,
     tt::stl::Span<const unary::UnaryOpType> post_activations) {
-    auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
-    auto input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
+    Tensor input_a = input_tensor_a;
+    Tensor input_b = input_tensor_b;
+    if (input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_b.get_dtype() == DataType::BFLOAT8_B ||
+        input_tensor_a.get_dtype() == DataType::BFLOAT4_B || input_tensor_b.get_dtype() == DataType::BFLOAT4_B) {
+        input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
+        input_b = typecast_to(DataType::BFLOAT16, input_tensor_b);
+    }
 
     return ttnn::prim::binary_ng(
         queue_id,
@@ -68,7 +73,10 @@ Tensor BinaryNg<binary_op_type>::invoke(
     tt::stl::Span<const unary::UnaryOpType> lhs_activations,
     tt::stl::Span<const unary::UnaryOpType> rhs_activations,
     tt::stl::Span<const unary::UnaryOpType> post_activations) {
-    auto input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
+    Tensor input_a = input_tensor_a;
+    if (input_tensor_a.get_dtype() == DataType::BFLOAT8_B || input_tensor_a.get_dtype() == DataType::BFLOAT4_B) {
+        input_a = typecast_to(DataType::BFLOAT16, input_tensor_a);
+    }
 
     return ttnn::prim::binary_ng(
         queue_id,
@@ -109,6 +117,7 @@ template struct BinaryNg;
 template struct BinaryNg;
 template struct BinaryNg;
 template struct BinaryNg;
+template struct BinaryNg<operations::binary_ng::BinaryOpType::RSUB>;
 template struct BinaryNg;
 template struct BinaryNg;
 template struct BinaryNg;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
index 34cf4165e61..79ddce96e44 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng.hpp
@@ -72,6 +72,10 @@ constexpr auto sub = ttnn::register_operation_with_auto_launch_op<
     "ttnn::experimental::sub",
     ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::SUB>>();
 
+constexpr auto rsub = ttnn::register_operation_with_auto_launch_op<
+    "ttnn::experimental::rsub",
+    ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::RSUB>>();
+
 constexpr auto mul = ttnn::register_operation_with_auto_launch_op<
     "ttnn::experimental::mul",
     ttnn::operations::binary_ng::BinaryNg<operations::binary_ng::BinaryOpType::MUL>>();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
index a392c476a8a..90fa92538e1 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/binary_ng_pybind.cpp
@@ -91,6 +91,7 @@ void py_module(py::module& module) {
     detail::bind_binary_ng_operation(module, ttnn::experimental::sub, "Binary Sub Operation");
     detail::bind_binary_ng_operation(module, ttnn::experimental::mul, "Binary Mul Operation");
     detail::bind_binary_ng_operation(module, ttnn::experimental::div, "Binary Div Operation");
+    detail::bind_binary_ng_operation(module, ttnn::experimental::rsub, "Binary Rsub Operation");
     detail::bind_binary_ng_operation(module, ttnn::experimental::gt, "Binary Greater Than Operation");
     detail::bind_binary_ng_operation(module, ttnn::experimental::lt, "Binary Less Than Operation");
     detail::bind_binary_ng_operation(module, ttnn::experimental::lte, "Binary Less Than or Equal To Operation");
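The `invoke` changes above replace an unconditional downcast-to-bfloat16 with a guard: only block-float formats (BFLOAT8_B/BFLOAT4_B, which pack tiles with shared exponents) are upcast, while bf16/fp32/int32 pass through untouched. A minimal sketch of how the repeated dtype checks could be factored; the helper name `needs_bf16_upcast` is hypothetical, not part of the patch:

```cpp
// Sketch with a stub enum: block-float tiles cannot be fed to the kernels
// directly, so they are converted to BFLOAT16 first; everything else keeps
// its dtype so the SFPU path can run at full 32-bit precision.
enum class DataType { BFLOAT16, FLOAT32, INT32, BFLOAT8_B, BFLOAT4_B };

constexpr bool needs_bf16_upcast(DataType dt) {  // hypothetical helper
    return dt == DataType::BFLOAT8_B || dt == DataType::BFLOAT4_B;
}

static_assert(needs_bf16_upcast(DataType::BFLOAT4_B));
static_assert(!needs_bf16_upcast(DataType::FLOAT32));  // fp32 stays fp32

int main() {}
```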
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
index 15a53a734b8..0de973e1912 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.cpp
@@ -8,6 +8,42 @@ using namespace tt::tt_metal;
 
 namespace ttnn::operations::binary_ng {
 
+namespace utils {
+bool is_binary_sfpu_op(BinaryOpType val, DataType a, DataType b) {
+    switch (val) {
+        case BinaryOpType::ADD:
+            return (
+                (a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32));
+        case BinaryOpType::SUB:
+        case BinaryOpType::MUL:
+        case BinaryOpType::DIV:
+        case BinaryOpType::RSUB:
+        case BinaryOpType::LOGADDEXP:
+        case BinaryOpType::LOGADDEXP2:
+        case BinaryOpType::LDEXP:
+        case BinaryOpType::SQUARED_DIFFERENCE:
+        case BinaryOpType::LOGICAL_OR:
+        case BinaryOpType::LOGICAL_XOR:
+        case BinaryOpType::LOGICAL_AND:
+        case BinaryOpType::BIAS_GELU:
+        case BinaryOpType::GT:
+        case BinaryOpType::LT:
+        case BinaryOpType::GTE:
+        case BinaryOpType::LTE:
+        case BinaryOpType::EQ:
+        case BinaryOpType::NE: return (a == DataType::FLOAT32 && b == DataType::FLOAT32);
+        case BinaryOpType::LEFT_SHIFT:
+        case BinaryOpType::RIGHT_SHIFT:
+        case BinaryOpType::BITWISE_XOR:
+        case BinaryOpType::BITWISE_AND:
+        case BinaryOpType::BITWISE_OR: return (a == DataType::INT32 && b == DataType::INT32);
+        case BinaryOpType::POWER: return true;
+        default: return false;
+    }
+    return false;
+}
+}  // namespace utils
+
 SubtileBroadcastType get_subtile_broadcast_type(uint32_t a_h, uint32_t a_w, uint32_t b_h, uint32_t b_w) {
     if (a_h == b_h && a_w == b_w) {
         return SubtileBroadcastType::NONE;
@@ -49,7 +85,9 @@ tt::stl::hash::hash_t BinaryNgDeviceOperation::operation_attributes_t::to_hash() const {
         memory_config,
         get_dtype(),
         compute_kernel_config,
-        subtile_broadcast_type);
+        subtile_broadcast_type,
+        is_sfpu);
+    // TODO: should the is_sfpu attribute be part of this hash fn?
 }
 
 DataType BinaryNgDeviceOperation::operation_attributes_t::get_dtype() const {
@@ -216,6 +254,12 @@ BinaryNgDeviceOperation::invoke(
         input_tensor_b_arg.get_logical_shape()[-2],
         input_tensor_b_arg.get_logical_shape()[-1]);
 
+    DataType dtype1 = input_tensor_a_arg.get_dtype();
+    DataType dtype2 = input_tensor_b_arg.get_dtype();
+    bool device_check = input_tensor_a_arg.device()->arch() != tt::ARCH::GRAYSKULL;
+    bool is_sfpu_op = (utils::is_binary_sfpu_op(binary_op_type, dtype1, dtype2) && device_check);
+    std::cout << "is sfpu device op? " << is_sfpu_op << std::endl;
+
     return {
         operation_attributes_t{
             binary_op_type,
@@ -227,7 +271,8 @@ BinaryNgDeviceOperation::invoke(
             input_tensor_a_arg.get_dtype(),
             output_dtype,
             std::nullopt,
-            subtile_broadcast_type},
+            subtile_broadcast_type,
+            is_sfpu_op},
         tensor_args_t{input_tensor_a_arg, input_tensor_b_arg, std::move(optional_output_tensor)}};
 }
 
@@ -242,6 +287,9 @@ BinaryNgDeviceOperation::invoke(
     tt::stl::Span<const unary::UnaryOpType> lhs_activations,
     tt::stl::Span<const unary::UnaryOpType> rhs_activations,
     tt::stl::Span<const unary::UnaryOpType> post_activations) {
+    DataType dtype1 = input_tensor_a_arg.get_dtype();
+    bool device_check = input_tensor_a_arg.device()->arch() != tt::ARCH::GRAYSKULL;
+    bool is_sfpu_op = (utils::is_binary_sfpu_op(binary_op_type, dtype1, dtype1) && device_check);
     return {
         operation_attributes_t{
             binary_op_type,
@@ -252,7 +300,9 @@ BinaryNgDeviceOperation::invoke(
             memory_config.value_or(input_tensor_a_arg.memory_config()),
             input_tensor_a_arg.get_dtype(),
             output_dtype,
-            std::nullopt},
+            std::nullopt,
+            SubtileBroadcastType::NONE,
+            is_sfpu_op},
         tensor_args_t{input_tensor_a_arg, std::nullopt, std::move(optional_output_tensor)}};
 }
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
index 0603ecbee2c..410b053ef18 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_device_operation.hpp
@@ -40,6 +40,7 @@ struct BinaryNgDeviceOperation {
         std::optional<DataType> dtype;
         std::optional<DeviceComputeKernelConfig> compute_kernel_config;
         SubtileBroadcastType subtile_broadcast_type = SubtileBroadcastType::NONE;
+        bool is_sfpu = false;
 
         tt::stl::hash::hash_t to_hash() const;
         DataType get_dtype() const;
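The routing decision added above sends fp32/fp32 and int32/int32 operand pairs to the SFPU (the vector engine, which can operate on 32-bit values in the destination registers) while everything else stays on the default bf16 FPU path. A simplified, runnable sketch with stand-in enums; the real table covers many more ops, and ADD additionally accepts int32:

```cpp
// Standalone model of the FPU-vs-SFPU dispatch. Bit ops are only meaningful
// on integer tiles; fp32 arithmetic must stay on the SFPU to keep precision.
#include <cstdio>

enum class DataType { BFLOAT16, FLOAT32, INT32 };
enum class BinaryOpType { ADD, SUB, LEFT_SHIFT };

bool is_binary_sfpu_op(BinaryOpType op, DataType a, DataType b) {
    switch (op) {
        case BinaryOpType::ADD:
        case BinaryOpType::SUB:
            return a == DataType::FLOAT32 && b == DataType::FLOAT32;
        case BinaryOpType::LEFT_SHIFT:
            return a == DataType::INT32 && b == DataType::INT32;
    }
    return false;
}

int main() {
    std::printf("%d\n", is_binary_sfpu_op(BinaryOpType::SUB, DataType::FLOAT32, DataType::FLOAT32));    // 1
    std::printf("%d\n", is_binary_sfpu_op(BinaryOpType::SUB, DataType::BFLOAT16, DataType::BFLOAT16));  // 0
}
```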
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
index 3121cc2e7ae..d8b697fe5e5 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_program_factory.cpp
@@ -121,8 +121,12 @@ void set_or_update_runtime_arguments(
             std::array compute_runtime_args = {num_tiles_per_core, freq, counter};
             handle_args(program, compute_kernel_id, core, compute_runtime_args);
         } else {
-            class bfloat16 bfloat_scalar(*operation_attributes.scalar);
-            uint32_t packed_scalar = pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar});
+            const auto scalar = *operation_attributes.scalar;
+            class bfloat16 bfloat_scalar(scalar);
+            const auto packed_scalar = a.get_dtype() == DataType::FLOAT32 ? std::bit_cast<uint32_t>(scalar)
+                                       : a.get_dtype() == DataType::INT32
+                                           ? std::bit_cast<uint32_t>(static_cast<int32_t>(scalar))
+                                           : pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar});
             std::array writer_runtime_args = {
                 packed_scalar,
                 c.buffer()->address(),
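The writer kernel receives the scalar as a single `uint32_t` runtime argument, so the hunk above packs it per dtype: raw float32 bits, raw int32 bits, or the same bfloat16 value replicated into both 16-bit halves. A runnable C++20 sketch of the three encodings (the real `pack_two_bfloat16_into_uint32` may round rather than truncate; truncation is an assumption here):

```cpp
// Demonstrates the three scalar encodings chosen by the ternary above.
#include <bit>
#include <cstdint>
#include <cstdio>

uint32_t pack_two_bfloat16(float v) {
    uint16_t half = static_cast<uint16_t>(std::bit_cast<uint32_t>(v) >> 16);  // truncate to bf16
    return (static_cast<uint32_t>(half) << 16) | half;
}

int main() {
    float scalar = 0.1111111f;
    std::printf("f32 bits : 0x%08x\n", std::bit_cast<uint32_t>(scalar));
    std::printf("i32 bits : 0x%08x\n", std::bit_cast<uint32_t>(static_cast<int32_t>(scalar)));
    std::printf("bf16 pair: 0x%08x\n", pack_two_bfloat16(scalar));
}
```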
@@ -158,15 +162,20 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
     const auto& a = tensor_args.input_tensor_a;
     const auto& b = tensor_args.input_tensor_b;
+    auto is_sfpu_op = operation_attributes.is_sfpu;
 
     auto program = CreateProgram();
 
     auto* device = a.device();
 
     auto a_data_format = datatype_to_dataformat_converter(a.get_dtype());
-    auto b_data_format = b.has_value() ? datatype_to_dataformat_converter(b->get_dtype()) : DataFormat::Float16_b;
+    auto b_data_format = b.has_value() ? datatype_to_dataformat_converter(b->get_dtype())
+                         : is_sfpu_op  ? datatype_to_dataformat_converter(a.get_dtype())
+                                       : DataFormat::Float16_b;
     auto c_data_format = datatype_to_dataformat_converter(c.get_dtype());
 
+    tt::log_info(tt::LogOp, "******** c_data_format : {}", c_data_format);
+
     uint32_t a_single_tile_size = tt_metal::detail::TileSize(a_data_format);
     uint32_t b_single_tile_size = tt_metal::detail::TileSize(b_data_format);
     uint32_t c_single_tile_size = tt_metal::detail::TileSize(c_data_format);
@@ -185,8 +194,12 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
     Buffer* c_buffer = c.buffer();
 
     auto op_type = operation_attributes.binary_op_type;
-    OpConfig op_config(op_type);
-    auto compute_kernel_defines = op_config.as_defines();
+
+    std::cout << "is sfpu op ? " << is_sfpu_op << std::endl;
+    OpConfig op_config(op_type, is_sfpu_op);
+    tt::log_info(tt::LogOp, "******** sfpu_binary_op : {}", op_config.sfpu_binary_op);
+
+    auto compute_kernel_defines = is_sfpu_op ? op_config.as_sfpu_defines(a.get_dtype()) : op_config.as_defines();
 
     {
         ttnn::SmallVector<unary::UnaryOpType> lhs_activations = operation_attributes.lhs_activations;
@@ -227,7 +240,9 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
         create_cb(tt::CBIndex::c_0, program, all_device_cores, a_single_tile_size, num_tiles_per_cb, a_data_format);
 
     if (not compute_kernel_defines["PROCESS_LHS_ACTIVATIONS(i)"].empty()) {
-        auto a_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : a_data_format;
+        auto a_intermediate_format = is_sfpu_op   ? a_data_format
+                                     : op_has_exp ? tt::DataFormat::Float16_b
+                                                  : a_data_format;
         uint32_t a_intermediate_single_tile_size = tt_metal::detail::TileSize(a_intermediate_format);
         auto [a_cb_interim, a_cb_interim_handle] = create_cb(
             tt::CBIndex::c_3, program, all_device_cores, a_intermediate_single_tile_size, 1, a_intermediate_format);
@@ -242,7 +257,9 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
         create_cb(tt::CBIndex::c_1, program, all_device_cores, b_single_tile_size, b_num_tiles_per_cb, b_data_format);
 
     if (not compute_kernel_defines["PROCESS_RHS_ACTIVATIONS(i)"].empty()) {
-        auto b_intermediate_format = op_has_exp ? tt::DataFormat::Float16_b : b_data_format;
+        auto b_intermediate_format = is_sfpu_op   ? b_data_format
+                                     : op_has_exp ? tt::DataFormat::Float16_b
+                                                  : b_data_format;
         uint32_t b_intermediate_single_tile_size = tt_metal::detail::TileSize(b_intermediate_format);
         auto [b_cb_interim, b_cb_interim_handle] = create_cb(
             tt::CBIndex::c_4, program, all_device_cores, b_intermediate_single_tile_size, 1, b_intermediate_format);
@@ -254,13 +271,27 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
 
     auto kernel_config = CMAKE_UNIQUE_NAMESPACE::BinaryNgKernelConfig(operation_attributes.subtile_broadcast_type);
 
+    std::map<std::string, std::string> reader_defines;
+    std::map<std::string, std::string> writer_defines;
+    if (is_sfpu_op && a.get_dtype() == DataType::FLOAT32) {
+        reader_defines["SFPU_F32"] = "1";
+    } else if (is_sfpu_op && a.get_dtype() == DataType::INT32) {
+        reader_defines["SFPU_INT32"] = "1";
+    } else {
+        reader_defines["FPU_BF16"] = "1";
+    }
+    writer_defines = reader_defines;
+
     // READER KERNEL
     auto reader_kernel_id = tt_metal::CreateKernel(
         program,
-        get_kernel_file_path(kernel_config.reader_kernel),
+        get_kernel_file_path(kernel_config.reader_kernel, is_sfpu_op),
         all_device_cores,
-        tt_metal::ReaderDataMovementConfig({a_is_dram}));
-
+        tt_metal::ReaderDataMovementConfig({a_is_dram}, reader_defines));
+    std::cout << "reader kernel " << get_kernel_file_path(kernel_config.reader_kernel, is_sfpu_op) << std::endl;
+    for (const auto& pair : reader_defines) {
+        std::cout << "reader sf " << pair.first << ": " << pair.second << std::endl;
+    }
     // WRITER KERNEL
     auto writer_kernel = CMAKE_UNIQUE_NAMESPACE::KernelName::WriterScalar;
     auto compute_kernel = CMAKE_UNIQUE_NAMESPACE::KernelName::ComputeScalar;
@@ -273,24 +304,45 @@ BinaryNgDeviceOperation::ProgramFactory::cached_program_t BinaryNgDeviceOperation::ProgramFactory::create(
 
     auto writer_kernel_id = tt_metal::CreateKernel(
         program,
-        get_kernel_file_path(writer_kernel),
+        get_kernel_file_path(writer_kernel, is_sfpu_op),
         all_device_cores,
-        tt_metal::WriterDataMovementConfig({b_is_dram, c_is_dram}));
-
+        tt_metal::WriterDataMovementConfig({b_is_dram, c_is_dram}, writer_defines));
+    std::cout << "writer kernel " << get_kernel_file_path(writer_kernel, is_sfpu_op) << std::endl;
+    for (const auto& pair : writer_defines) {
+        std::cout << "writer_defines sf " << pair.first << ": " << pair.second << std::endl;
+    }
     // COMPUTE KERNEL
 
     bool fp32_dest_acc_en = c_data_format == tt::DataFormat::UInt32 || c_data_format == tt::DataFormat::Int32 ||
                             c_data_format == tt::DataFormat::Float32;
-
+    std::cout << "fp32_dest_acc_en " << fp32_dest_acc_en << std::endl;
+    uint32_t src0_cb_index = tt::CBIndex::c_0;
+    uint32_t src1_cb_index = tt::CBIndex::c_1;
+    uint32_t src0interim_cb_index = tt::CBIndex::c_3;
+    uint32_t src1interim_cb_index = tt::CBIndex::c_4;
+
+    std::vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
+    if (is_sfpu_op) {
+        unpack_to_dest_mode[src0_cb_index] = UnpackToDestMode::UnpackToDestFp32;
+        unpack_to_dest_mode[src1_cb_index] = UnpackToDestMode::UnpackToDestFp32;
+        unpack_to_dest_mode[src0interim_cb_index] = UnpackToDestMode::UnpackToDestFp32;
+        unpack_to_dest_mode[src1interim_cb_index] = UnpackToDestMode::UnpackToDestFp32;
+    }
     // Compute kernel needs to know which op it's going to perform
     // This has to be passed as a compile-time argument
     // For now we're just going to do addition
     compute_kernel_defines["BCAST_INPUT"] = kernel_config.bcast_input_str();
+    for (const auto& pair : compute_kernel_defines) {
+        std::cout << "compute_kernel_defines sf " << pair.first << ": " << pair.second << std::endl;
+    }
 
     auto compute_kernel_id = tt_metal::CreateKernel(
         program,
-        get_kernel_file_path(compute_kernel),
+        get_kernel_file_path(compute_kernel, is_sfpu_op),
         all_device_cores,
-        tt_metal::ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .defines = compute_kernel_defines});
-
+        tt_metal::ComputeConfig{
+            .fp32_dest_acc_en = fp32_dest_acc_en,
+            .unpack_to_dest_mode = unpack_to_dest_mode,
+            .defines = compute_kernel_defines});
+    std::cout << "compute_kernel " << get_kernel_file_path(compute_kernel, is_sfpu_op) << std::endl;
 
     auto set_runtime_args = [](Program& program, KernelHandle kernel_id, CoreCoord core, auto&& args) {
         tt_metal::SetRuntimeArgs(program, kernel_id, core, args);
     };
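Two settings above keep the SFPU path at full width: `fp32_dest_acc_en` holds the destination registers in 32-bit mode whenever the output format is 32 bits wide, and `UnpackToDestFp32` makes the unpacker deliver operand CBs straight into dest instead of through the bf16 source registers. A stubbed, runnable sketch of the CB-to-mode mapping (indices mirror the c_0/c_1/c_3/c_4 buffers used above):

```cpp
// Host-side model only; UnpackToDestMode and the CB count are stand-ins for
// the real tt-metal types.
#include <cstddef>
#include <vector>

enum class UnpackToDestMode { Default, UnpackToDestFp32 };
constexpr std::size_t NUM_CIRCULAR_BUFFERS = 32;

std::vector<UnpackToDestMode> make_unpack_modes(bool is_sfpu_op) {
    std::vector<UnpackToDestMode> modes(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
    if (is_sfpu_op) {
        for (std::size_t cb : {0u, 1u, 3u, 4u}) {  // src0, src1 and both interim CBs
            modes[cb] = UnpackToDestMode::UnpackToDestFp32;
        }
    }
    return modes;
}

int main() {
    auto modes = make_unpack_modes(true);
    return modes[0] == UnpackToDestMode::UnpackToDestFp32 ? 0 : 1;
}
```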
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
index 7d4410f7b3f..888aa99bd1e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.cpp
@@ -97,7 +97,7 @@ std::string BinaryNgKernelConfig::bcast_input_str() const {
     return "";
 }
 
-std::string get_kernel_file_path(KernelName kernel_name) {
+std::string get_kernel_file_path(KernelName kernel_name, bool is_sfpu) {
     constexpr std::string_view root = "ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels";
     constexpr std::string_view dataflow = "{}/dataflow/{}";
     constexpr std::string_view compute = "{}/compute/{}";
@@ -112,69 +112,179 @@ std::string get_kernel_file_path(KernelName kernel_name) {
         case KernelName::WriterColBcast: return fmt::format(dataflow, root, "writer_interleaved_col_bcast.cpp");
         case KernelName::WriterScalarBcast: return fmt::format(dataflow, root, "writer_interleaved_scalar_bcast.cpp");
         case KernelName::WriterScalar: return fmt::format(dataflow, root, "writer_interleaved_scalar.cpp");
-        case KernelName::ComputeNoBcast: return fmt::format(compute, root, "eltwise_binary_no_bcast.cpp");
-        case KernelName::ComputeBcast: return fmt::format(compute, root, "eltwise_binary.cpp");
-        case KernelName::ComputeScalar: return fmt::format(compute, root, "eltwise_binary_scalar.cpp");
+        case KernelName::ComputeNoBcast:
+            if (is_sfpu) {
+                return fmt::format(compute, root, "eltwise_binary_sfpu_no_bcast.cpp");
+            } else {
+                return fmt::format(compute, root, "eltwise_binary_no_bcast.cpp");
+            }
+        case KernelName::ComputeBcast:
+            if (is_sfpu) {
+                return fmt::format(compute, root, "eltwise_binary_sfpu.cpp");
+            } else {
+                return fmt::format(compute, root, "eltwise_binary.cpp");
+            }
+        case KernelName::ComputeScalar:
+            if (is_sfpu) {
+                return fmt::format(compute, root, "eltwise_binary_sfpu_scalar.cpp");
+            } else {
+                return fmt::format(compute, root, "eltwise_binary_scalar.cpp");
+            }
         default: __builtin_unreachable();  // GCC 12 doesn't compile even though we exhaustively match
     }
 }
 
-OpConfig::OpConfig(BinaryOpType binary_op_type) {
-    fpu_binary_op = FpuBinaryOp::SUB;
-    switch (binary_op_type) {
-        case BinaryOpType::ADD: fpu_binary_op = FpuBinaryOp::ADD; break;
-        case BinaryOpType::SUB: break;
-        case BinaryOpType::MUL: fpu_binary_op = FpuBinaryOp::MUL; break;
-        case BinaryOpType::DIV:
-            process_rhs = unary::UnaryOpType::RECIP;
-            fpu_binary_op = FpuBinaryOp::MUL;
-            break;
-        case BinaryOpType::GT: postprocess = unary::UnaryOpType::GTZ; break;
-        case BinaryOpType::LT: postprocess = unary::UnaryOpType::LTZ; break;
-        case BinaryOpType::GTE: postprocess = unary::UnaryOpType::GEZ; break;
-        case BinaryOpType::LTE: postprocess = unary::UnaryOpType::LEZ; break;
-        case BinaryOpType::EQ: postprocess = unary::UnaryOpType::EQZ; break;
-        case BinaryOpType::NE: postprocess = unary::UnaryOpType::NEZ; break;
-        case BinaryOpType::SQUARED_DIFFERENCE: postprocess = unary::UnaryOpType::SQUARE; break;
-        case BinaryOpType::BIAS_GELU:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            process_lhs = unary::UnaryOpType::GELU;
-            break;
-        case BinaryOpType::LOGICAL_AND:
-            fpu_binary_op = FpuBinaryOp::MUL;
-            postprocess = unary::UnaryOpType::NEZ;
-            break;
-        case BinaryOpType::LOGICAL_OR:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            process_lhs = unary::UnaryOpType::NEZ;
-            process_rhs = unary::UnaryOpType::NEZ;
-            postprocess = unary::UnaryOpType::GTZ;
-            break;
-        case BinaryOpType::LOGICAL_XOR:
-            process_lhs = unary::UnaryOpType::NEZ;
-            process_rhs = unary::UnaryOpType::NEZ;
-            postprocess = unary::UnaryOpType::NEZ;
-            break;
-        case BinaryOpType::LDEXP:
-            fpu_binary_op = FpuBinaryOp::MUL;
-            process_rhs = unary::UnaryOpType::EXP2;
-            break;
-        case BinaryOpType::LOGADDEXP:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            process_lhs = unary::UnaryOpType::EXP;
-            process_rhs = unary::UnaryOpType::EXP;
-            postprocess = unary::UnaryOpType::LOG;
-            break;
-        case BinaryOpType::LOGADDEXP2:
-            fpu_binary_op = FpuBinaryOp::ADD;
-            process_lhs = unary::UnaryOpType::EXP2;
-            process_rhs = unary::UnaryOpType::EXP2;
-            postprocess = unary::UnaryOpType::LOG2;
-            break;
-        default: TT_THROW("Unsupported binary op {}", binary_op_type);
-    }
-}
+OpConfig::OpConfig(BinaryOpType binary_op_type, bool is_sfpu_op) {
+    if (!is_sfpu_op) {
+        fpu_binary_op = FpuBinaryOp::SUB;
+        switch (binary_op_type) {
+            case BinaryOpType::ADD: fpu_binary_op = FpuBinaryOp::ADD; break;
+            case BinaryOpType::SUB: break;
+            case BinaryOpType::MUL: fpu_binary_op = FpuBinaryOp::MUL; break;
+            case BinaryOpType::DIV:
+                process_rhs = unary::UnaryOpType::RECIP;
+                fpu_binary_op = FpuBinaryOp::MUL;
+                break;
+            case BinaryOpType::RSUB:
+                process_rhs = unary::UnaryOpType::NEG;
+                fpu_binary_op = FpuBinaryOp::ADD;
+                break;
+            case BinaryOpType::GT: postprocess = unary::UnaryOpType::GTZ; break;
+            case BinaryOpType::LT: postprocess = unary::UnaryOpType::LTZ; break;
+            case BinaryOpType::GTE: postprocess = unary::UnaryOpType::GEZ; break;
+            case BinaryOpType::LTE: postprocess = unary::UnaryOpType::LEZ; break;
+            case BinaryOpType::EQ: postprocess = unary::UnaryOpType::EQZ; break;
+            case BinaryOpType::NE: postprocess = unary::UnaryOpType::NEZ; break;
+            case BinaryOpType::SQUARED_DIFFERENCE: postprocess = unary::UnaryOpType::SQUARE; break;
+            case BinaryOpType::BIAS_GELU:
+                fpu_binary_op = FpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::GELU;
+                break;
+            case BinaryOpType::LOGICAL_AND:
+                fpu_binary_op = FpuBinaryOp::MUL;
+                postprocess = unary::UnaryOpType::NEZ;
+                break;
+            case BinaryOpType::LOGICAL_OR:
+                fpu_binary_op = FpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::NEZ;
+                process_rhs = unary::UnaryOpType::NEZ;
+                postprocess = unary::UnaryOpType::GTZ;
+                break;
+            case BinaryOpType::LOGICAL_XOR:
+                process_lhs = unary::UnaryOpType::NEZ;
+                process_rhs = unary::UnaryOpType::NEZ;
+                postprocess = unary::UnaryOpType::NEZ;
+                break;
+            case BinaryOpType::LDEXP:
+                fpu_binary_op = FpuBinaryOp::MUL;
+                process_rhs = unary::UnaryOpType::EXP2;
+                break;
+            case BinaryOpType::LOGADDEXP:
+                fpu_binary_op = FpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::EXP;
+                process_rhs = unary::UnaryOpType::EXP;
+                postprocess = unary::UnaryOpType::LOG;
+                break;
+            case BinaryOpType::LOGADDEXP2:
+                fpu_binary_op = FpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::EXP2;
+                process_rhs = unary::UnaryOpType::EXP2;
+                postprocess = unary::UnaryOpType::LOG2;
+                break;
+            default: TT_THROW("Unsupported binary op {}", binary_op_type);
+        }
+    } else {
+        sfpu_binary_op = SfpuBinaryOp::SUB;
+        switch (binary_op_type) {
+            case BinaryOpType::ADD: sfpu_binary_op = SfpuBinaryOp::ADD; break;
+            case BinaryOpType::SUB: break;
+            case BinaryOpType::MUL: sfpu_binary_op = SfpuBinaryOp::MUL; break;
+            case BinaryOpType::DIV: sfpu_binary_op = SfpuBinaryOp::DIV; break;
+            case BinaryOpType::RSUB: sfpu_binary_op = SfpuBinaryOp::RSUB; break;
+            case BinaryOpType::GT: postprocess = unary::UnaryOpType::GTZ; break;
+            case BinaryOpType::LT: postprocess = unary::UnaryOpType::LTZ; break;
+            case BinaryOpType::GTE: postprocess = unary::UnaryOpType::GEZ; break;
+            case BinaryOpType::LTE: postprocess = unary::UnaryOpType::LEZ; break;
+            case BinaryOpType::EQ: postprocess = unary::UnaryOpType::EQZ; break;
+            case BinaryOpType::NE: postprocess = unary::UnaryOpType::NEZ; break;
+            case BinaryOpType::SQUARED_DIFFERENCE: postprocess = unary::UnaryOpType::SQUARE; break;
+            case BinaryOpType::BIAS_GELU:
+                sfpu_binary_op = SfpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::GELU;
+                break;
+            case BinaryOpType::LOGICAL_AND:
+                sfpu_binary_op = SfpuBinaryOp::MUL;
+                postprocess = unary::UnaryOpType::NEZ;
+                break;
+            case BinaryOpType::LOGICAL_OR:
+                sfpu_binary_op = SfpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::NEZ;
+                process_rhs = unary::UnaryOpType::NEZ;
+                postprocess = unary::UnaryOpType::GTZ;
+                break;
+            case BinaryOpType::LOGICAL_XOR:
+                process_lhs = unary::UnaryOpType::NEZ;
+                process_rhs = unary::UnaryOpType::NEZ;
+                postprocess = unary::UnaryOpType::NEZ;
+                break;
+            case BinaryOpType::LDEXP:
+                sfpu_binary_op = SfpuBinaryOp::MUL;
+                process_rhs = unary::UnaryOpType::EXP2;
+                break;
+            case BinaryOpType::LOGADDEXP:
+                sfpu_binary_op = SfpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::EXP;
+                process_rhs = unary::UnaryOpType::EXP;
+                postprocess = unary::UnaryOpType::LOG;
+                break;
+            case BinaryOpType::LOGADDEXP2:
+                sfpu_binary_op = SfpuBinaryOp::ADD;
+                process_lhs = unary::UnaryOpType::EXP2;
+                process_rhs = unary::UnaryOpType::EXP2;
+                postprocess = unary::UnaryOpType::LOG2;
+                break;
+            case BinaryOpType::BITWISE_AND: sfpu_binary_op = SfpuBinaryOp::BITWISE_AND; break;
+            case BinaryOpType::BITWISE_OR: sfpu_binary_op = SfpuBinaryOp::BITWISE_OR; break;
+            case BinaryOpType::BITWISE_XOR: sfpu_binary_op = SfpuBinaryOp::BITWISE_XOR; break;
+            case BinaryOpType::LEFT_SHIFT: sfpu_binary_op = SfpuBinaryOp::LEFT_SHIFT; break;
+            case BinaryOpType::RIGHT_SHIFT: sfpu_binary_op = SfpuBinaryOp::RIGHT_SHIFT; break;
+            case BinaryOpType::POWER: sfpu_binary_op = SfpuBinaryOp::POWER; break;
+            default: TT_THROW("Unsupported binary op {}", binary_op_type);
+        }
+    }
+}
{"add_int32_tile_init();", "add_int32_tile"}; + } else { + return {"add_binary_tile_init();", "add_binary_tile"}; + } + case SfpuBinaryOp::SUB: return {"sub_binary_tile_init();", "sub_binary_tile"}; + case SfpuBinaryOp::MUL: return {"mul_binary_tile_init();", "mul_binary_tile"}; + case SfpuBinaryOp::DIV: return {"div_binary_tile_init();", "div_binary_tile"}; + case SfpuBinaryOp::POWER: return {"power_binary_tile_init();", "power_binary_tile"}; + case SfpuBinaryOp::RSUB: return {"rsub_binary_tile_init();", "rsub_binary_tile"}; + case SfpuBinaryOp::LEFT_SHIFT: return {"binary_shift_tile_init();", "binary_left_shift_tile"}; + case SfpuBinaryOp::RIGHT_SHIFT: return {"binary_shift_tile_init();", "binary_right_shift_tile"}; + case SfpuBinaryOp::BITWISE_AND: return {"binary_bitwise_tile_init();", "and_binary_tile"}; + case SfpuBinaryOp::BITWISE_OR: return {"binary_bitwise_tile_init();", "or_binary_tile"}; + case SfpuBinaryOp::BITWISE_XOR: return {"binary_bitwise_tile_init();", "xor_binary_tile"}; + default: TT_THROW("Unsupported sfpu binary op {}", sfpu_binary_op); return {"", ""}; + } +} + +std::map OpConfig::as_sfpu_defines(DataType dtype) const { + std::map defines; + auto [tile_init, tile_fn] = get_sfpu_init_fn(dtype); + defines["BINARY_SFPU_INIT"] = fmt::format("{}", tile_init); + defines["BINARY_SFPU_OP"] = fmt::format("{}", tile_fn); + + return defines; +} + std::map OpConfig::as_defines() const { std::map defines; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp index b1ee7f41700..7b1902db0e7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/binary_ng_utils.hpp @@ -6,6 +6,7 @@ #include "binary_ng_device_operation.hpp" #include "ttnn/operations/eltwise/binary_ng/types.hpp" +#include "ttnn/tensor/types.hpp" #include #include @@ -38,19 +39,35 @@ struct BinaryNgKernelConfig { std::optional bcast_input; }; -std::string get_kernel_file_path(KernelName kernel_name); +std::string get_kernel_file_path(KernelName kernel_name, bool is_sfpu); struct OpConfig { enum class FpuBinaryOp { ADD, SUB, MUL }; - - OpConfig(BinaryOpType binary_op_type); + enum class SfpuBinaryOp { + ADD, + SUB, + MUL, + DIV, + POWER, + RSUB, + LEFT_SHIFT, + RIGHT_SHIFT, + BITWISE_AND, + BITWISE_OR, + BITWISE_XOR + }; + + OpConfig(BinaryOpType binary_op_type, bool is_sfpu); std::map as_defines() const; + std::pair get_sfpu_init_fn(DataType dtype) const; + std::map as_sfpu_defines(DataType dtype) const; std::optional process_lhs{}; std::optional process_rhs{}; std::optional postprocess{}; FpuBinaryOp fpu_binary_op; + SfpuBinaryOp sfpu_binary_op; }; void add_activation_defines( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp index 8f6f7990bf1..7d2bde98bf5 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp @@ -6,6 +6,7 @@ #include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h" #include "compute_kernel_api/eltwise_binary.h" +#include "eltwise_utils_common.hpp" #include "eltwise_utils.hpp" namespace NAMESPACE { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
index f2a263bfb1a..30ff24df71d 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp
@@ -7,6 +7,7 @@
 
 #include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
 #include "compute_kernel_api/eltwise_binary.h"
+#include "eltwise_utils_common.hpp"
 #include "eltwise_utils.hpp"
 
 namespace NAMESPACE {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
index b9f2e29903e..db05ebae73a 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp
@@ -6,6 +6,7 @@
 
 #include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
 #include "compute_kernel_api/eltwise_binary.h"
+#include "eltwise_utils_common.hpp"
 #include "eltwise_utils.hpp"
 
 namespace NAMESPACE {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu.cpp
new file mode 100644
index 00000000000..b19c619cd07
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu.cpp
@@ -0,0 +1,115 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
+#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
+
+#include "compute_kernel_api/eltwise_binary_sfpu.h"
+#include "compute_kernel_api/binary_bitwise_sfpu.h"
+#include "compute_kernel_api/binary_shift.h"
+#include "compute_kernel_api/add_int32_sfpu.h"
+
+#include "eltwise_utils_common.hpp"
+#include "eltwise_utils_sfpu.hpp"
+
+namespace NAMESPACE {
+
+ALWI void process_tile(
+    tt::CBIndex cb_pre_lhs,
+    tt::CBIndex cb_post_lhs,
+    tt::CBIndex cb_pre_rhs,
+    tt::CBIndex cb_post_rhs,
+    tt::CBIndex cb_out,
+    uint32_t freq,
+    uint32_t tile_start) {
+    using namespace ckernel;
+    constexpr uint32_t onetile = 1;
+
+#if BCAST_INPUT
+#define CB_PRE_BCAST cb_pre_rhs
+#define CB_POST_BCAST cb_post_rhs
+#define CB_PRE_OTHER cb_pre_lhs
+#define CB_POST_OTHER cb_post_lhs
+#else
+#define CB_PRE_BCAST cb_pre_lhs
+#define CB_POST_BCAST cb_post_lhs
+#define CB_PRE_OTHER cb_pre_rhs
+#define CB_POST_OTHER cb_post_rhs
+#endif
+
+    PREPROCESS(BCAST_OP, CB_PRE_BCAST, CB_POST_BCAST, cb_out, onetile);
+    cb_wait_front(CB_POST_BCAST, onetile);
+
+    for (uint32_t j = tile_start; j < freq; ++j) {
+        PREPROCESS(OTHER_OP, CB_PRE_OTHER, CB_POST_OTHER, cb_out, onetile);
+        cb_wait_front(CB_POST_OTHER, onetile);
+
+        cb_reserve_back(cb_out, onetile);
+
+#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)
+        BINARY_SFPU_INIT
+#endif
+        tile_regs_acquire();
+        copy_tile_to_dst_init_short_with_dt(cb_post_rhs, cb_post_lhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_lhs, i, i * 2);
+        }
+        copy_tile_to_dst_init_short_with_dt(cb_post_lhs, cb_post_rhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_rhs, i, i * 2 + 1);
+
+            BINARY_SFPU_OP(i * 2, i * 2 + 1);
+            PROCESS_POST_ACTIVATIONS(i * 2);
+            tile_regs_commit();
+
+            tile_regs_wait();
+            pack_tile(i * 2, cb_out);
+        }
+        tile_regs_release();
+
+        cb_push_back(cb_out, onetile);
+        cb_pop_front(CB_POST_OTHER, onetile);
+    }
+    cb_pop_front(CB_POST_BCAST, onetile);
+}
+
+void MAIN {
+    uint32_t num_tiles = get_arg_val<uint32_t>(0);
+    uint32_t tile_freq = get_arg_val<uint32_t>(1);
+    uint32_t tile_start = get_arg_val<uint32_t>(2);
+
+    if (num_tiles == 0) {
+        return;
+    }
+
+    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
+    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
+    constexpr auto cb_out = tt::CBIndex::c_2;
+
+    constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs;
+    constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs;
+
+    unary_op_init_common(cb_post_lhs, cb_out);
+#ifdef PACK_RELU
+    PACK((llk_pack_relu_config(ReluType::ZERO_RELU)));
+#endif
+
+#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS))
+    BINARY_SFPU_INIT
+#endif
+
+    uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
+    uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;
+
+    for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) {
+        process_tile(cb_pre_lhs, cb_post_lhs, cb_pre_rhs, cb_post_rhs, cb_out, tile_freq, tile_start);
+    }
+
+    if (remaining_iterations > 0) {
+        process_tile(cb_pre_lhs, cb_post_lhs, cb_pre_rhs, cb_post_rhs, cb_out, remaining_iterations, tile_start);
+    }
+}
+}  // namespace NAMESPACE
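The SFPU kernels above pair destination-register slots: the LHS tile is copied to dest slot `2*i`, the RHS tile to `2*i + 1`, the binary SFPU op combines them in place into `2*i`, and that slot is packed out. A host-side model with plain arrays standing in for tiles:

```cpp
// Conceptual model only; real tiles are 32x32 and live in hardware dest regs.
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    std::array<float, 4> lhs{4.0f, 4.0f, 4.0f, 4.0f};  // e.g. square(2.0) from the LHS activation
    std::array<float, 4> rhs{-0.5f, -0.5f, -0.5f, -0.5f};

    std::array<std::array<float, 4>, 2> dest{};  // two dest slots: 2*i and 2*i + 1
    dest[0] = lhs;                               // copy_tile(cb_post_lhs, i, i * 2)
    dest[1] = rhs;                               // copy_tile(cb_post_rhs, i, i * 2 + 1)
    for (std::size_t e = 0; e < 4; ++e) {        // sub_binary_tile(i * 2, i * 2 + 1)
        dest[0][e] -= dest[1][e];
    }
    std::printf("packed: %.2f\n", dest[0][0]);   // pack_tile(i * 2, cb_out) -> 4.50
}
```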
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_no_bcast.cpp
new file mode 100644
index 00000000000..25eec4e1eeb
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_no_bcast.cpp
@@ -0,0 +1,75 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
+#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
+
+#include "compute_kernel_api/eltwise_binary_sfpu.h"
+#include "compute_kernel_api/binary_bitwise_sfpu.h"
+#include "compute_kernel_api/binary_shift.h"
+#include "compute_kernel_api/add_int32_sfpu.h"
+
+#include "eltwise_utils_common.hpp"
+#include "eltwise_utils_sfpu.hpp"
+
+namespace NAMESPACE {
+void MAIN {
+    uint32_t num_tiles = get_arg_val<uint32_t>(0);
+
+    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
+    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
+    constexpr auto cb_out = tt::CBIndex::c_2;
+
+    constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs;
+    constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs;
+
+    unary_op_init_common(cb_post_lhs, cb_out);
+#ifdef PACK_RELU
+    PACK((llk_pack_relu_config(ReluType::ZERO_RELU)));
+#endif
+
+#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS))
+    BINARY_SFPU_INIT
+#endif
+
+    constexpr uint32_t onetile = 1;
+
+    for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
+        PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
+        cb_wait_front(cb_post_lhs, onetile);
+
+        PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
+        cb_wait_front(cb_post_rhs, onetile);
+
+        cb_reserve_back(cb_out, onetile);
+
+#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)
+        BINARY_SFPU_INIT
+#endif
+        tile_regs_acquire();
+        copy_tile_to_dst_init_short_with_dt(cb_post_rhs, cb_post_lhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_lhs, i, i * 2);
+        }
+        copy_tile_to_dst_init_short_with_dt(cb_post_lhs, cb_post_rhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_rhs, i, i * 2 + 1);
+
+            BINARY_SFPU_OP(i * 2, i * 2 + 1);
+            PROCESS_POST_ACTIVATIONS(i * 2);
+            tile_regs_commit();
+
+            tile_regs_wait();
+            pack_tile(i * 2, cb_out);
+        }
+        tile_regs_release();
+
+        cb_push_back(cb_out, onetile);
+        cb_pop_front(cb_post_lhs, onetile);
+        cb_pop_front(cb_post_rhs, onetile);
+    }
+}
+}  // namespace NAMESPACE
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_scalar.cpp
new file mode 100644
index 00000000000..655bc99aab5
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_sfpu_scalar.cpp
@@ -0,0 +1,76 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <cstdint>
+
+#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"
+#include "compute_kernel_api/eltwise_unary/eltwise_unary.h"
+
+#include "compute_kernel_api/eltwise_binary_sfpu.h"
+#include "compute_kernel_api/binary_bitwise_sfpu.h"
+#include "compute_kernel_api/binary_shift.h"
+#include "compute_kernel_api/add_int32_sfpu.h"
+
+#include "eltwise_utils_common.hpp"
+#include "eltwise_utils_sfpu.hpp"
+
+namespace NAMESPACE {
+void MAIN {
+    uint32_t num_tiles = get_arg_val<uint32_t>(0);
+
+    constexpr auto cb_pre_lhs = tt::CBIndex::c_0;
+    constexpr auto cb_pre_rhs = tt::CBIndex::c_1;
+    constexpr auto cb_out = tt::CBIndex::c_2;
+
+    constexpr auto cb_post_lhs = HAS_ACTIVATIONS(LHS) ? tt::CBIndex::c_3 : cb_pre_lhs;
+    constexpr auto cb_post_rhs = HAS_ACTIVATIONS(RHS) ? tt::CBIndex::c_4 : cb_pre_rhs;
+
+    unary_op_init_common(cb_post_lhs, cb_out);
+#ifdef PACK_RELU
+    PACK((llk_pack_relu_config(ReluType::ZERO_RELU)));
+#endif
+
+#if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS))
+    BINARY_SFPU_INIT
+#endif
+
+    constexpr uint32_t onetile = 1;
+
+    PREPROCESS(RHS, cb_pre_rhs, cb_post_rhs, cb_out, onetile);
+    cb_wait_front(cb_post_rhs, onetile);
+    // copy_tile_to_dst_init_short_with_dt(cb_post_lhs, cb_post_rhs);
+    // copy_tile(cb_post_rhs, 0, 1);
+
+    for (uint32_t tile_id = 0; tile_id < num_tiles; ++tile_id) {
+        PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, onetile);
+        cb_wait_front(cb_post_lhs, onetile);
+
+        cb_reserve_back(cb_out, onetile);
+
+#if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)
+        BINARY_SFPU_INIT
+#endif
+        tile_regs_acquire();
+        copy_tile_to_dst_init_short_with_dt(cb_post_rhs, cb_post_lhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_lhs, i, i * 2);
+        }
+        copy_tile_to_dst_init_short_with_dt(cb_post_lhs, cb_post_rhs);
+        for (uint32_t i = 0; i < onetile; ++i) {
+            copy_tile(cb_post_rhs, i, i * 2 + 1);
+            BINARY_SFPU_OP(i * 2, i * 2 + 1);
+            PROCESS_POST_ACTIVATIONS(i * 2);
+            tile_regs_commit();
+
+            tile_regs_wait();
+            pack_tile(i * 2, cb_out);
+        }
+        tile_regs_release();
+
+        cb_pop_front(cb_post_lhs, onetile);
+        cb_push_back(cb_out, onetile);
+    }
+    // cb_pop_front(cb_post_rhs, onetile);
+}
+}  // namespace NAMESPACE
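Note the circular-buffer discipline in the scalar kernel above: the RHS (scalar) tile is produced once before the loop and deliberately never popped, so every LHS tile pairs against the same tile. A minimal host-side model of that produce/consume handshake (a `std::deque` stands in for the hardware FIFO, and the calls never actually block here):

```cpp
#include <cstdio>
#include <deque>

struct Cb {
    std::deque<int> tiles;
    void push_back(int t) { tiles.push_back(t); }      // producer: cb_push_back
    int wait_front() const { return tiles.front(); }   // consumer: cb_wait_front (blocking on HW)
    void pop_front() { tiles.pop_front(); }            // consumer: cb_pop_front
};

int main() {
    Cb rhs;
    rhs.push_back(42);                    // scalar tile produced once by the writer kernel
    int scalar_tile = rhs.wait_front();   // no pop: stays valid for every iteration
    for (int lhs_tile = 0; lhs_tile < 3; ++lhs_tile) {
        std::printf("combine lhs %d with scalar %d\n", lhs_tile, scalar_tile);
    }
}
```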
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp
index 4451eb0e8f3..91f1a238d35 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils.hpp
@@ -7,181 +7,6 @@
 
 #include "compute_kernel_api/common.h"
 #include "compute_kernel_api/tile_move_copy.h"
 
-#define ACTIVATION_INIT_RELU relu_tile_init
-#define ACTIVATION_APPLY_RELU relu_tile
-
-#define ACTIVATION_INIT_SQUARE square_tile_init
-#define ACTIVATION_APPLY_SQUARE square_tile
-
-#define ACTIVATION_INIT_GTZ gtz_tile_init
-#define ACTIVATION_APPLY_GTZ gtz_tile
-
-#define ACTIVATION_INIT_LTZ ltz_tile_init
-#define ACTIVATION_APPLY_LTZ ltz_tile
-
-#define ACTIVATION_INIT_GEZ gez_tile_init
-#define ACTIVATION_APPLY_GEZ gez_tile
-
-#define ACTIVATION_INIT_LEZ lez_tile_init
-#define ACTIVATION_APPLY_LEZ lez_tile
-
-#define ACTIVATION_INIT_EQZ eqz_tile_init
-#define ACTIVATION_APPLY_EQZ eqz_tile
-
-#define ACTIVATION_INIT_NEZ nez_tile_init
-#define ACTIVATION_APPLY_NEZ nez_tile
-
-#define ACTIVATION_INIT_LOG log_tile_init
-#define ACTIVATION_APPLY_LOG log_tile
-
-#define ACTIVATION_INIT_TANH tanh_tile_init
-#define ACTIVATION_APPLY_TANH tanh_tile
-
-#define ACTIVATION_INIT_LOG2 log_with_base_tile_init
-#define ACTIVATION_APPLY_LOG2(i) log_with_base_tile(i, 0x3dc5u)
-
-#define ACTIVATION_INIT_LOG10 log_with_base_tile_init
-#define ACTIVATION_APPLY_LOG10(i) log_with_base_tile(i, 0x36f3u)
-
-#define ACTIVATION_INIT_EXP exp_tile_init
-#define ACTIVATION_APPLY_EXP exp_tile
-
-#define ACTIVATION_INIT_EXP2 exp2_tile_init
-#define ACTIVATION_APPLY_EXP2 exp2_tile
-
-#define ACTIVATION_INIT_EXPM1 expm1_tile_init
-#define ACTIVATION_APPLY_EXPM1 expm1_tile
-
-#define ACTIVATION_INIT_RECIP recip_tile_init
-#define ACTIVATION_APPLY_RECIP recip_tile
-
-#define ACTIVATION_INIT_GELU gelu_tile_init
-#define ACTIVATION_APPLY_GELU gelu_tile
-
-#define ACTIVATION_INIT_SQRT sqrt_tile_init
-#define ACTIVATION_APPLY_SQRT sqrt_tile
-
-#define ACTIVATION_INIT_SIGMOID sigmoid_tile_init
-#define ACTIVATION_APPLY_SIGMOID sigmoid_tile
-
-#define ACTIVATION_INIT_SIN sin_tile_init
-#define ACTIVATION_APPLY_SIN sin_tile
-
-#define ACTIVATION_INIT_COS cos_tile_init
-#define ACTIVATION_APPLY_COS cos_tile
-
-#define ACTIVATION_INIT_TAN tan_tile_init
-#define ACTIVATION_APPLY_TAN tan_tile
-
-#define ACTIVATION_INIT_ASIN asin_tile_init
-#define ACTIVATION_APPLY_ASIN asin_tile
-
-#define ACTIVATION_INIT_ACOS acos_tile_init
-#define ACTIVATION_APPLY_ACOS acos_tile
-
-#define ACTIVATION_INIT_ATAN atan_tile_init
-#define ACTIVATION_APPLY_ATAN atan_tile
-
-#define ACTIVATION_INIT_ABS abs_tile_init
-#define ACTIVATION_APPLY_ABS abs_tile
-
-#define ACTIVATION_INIT_SIGN sign_tile_init
-#define ACTIVATION_APPLY_SIGN sign_tile
-
-#define ACTIVATION_INIT_SIGNBIT signbit_tile_init
-#define ACTIVATION_APPLY_SIGNBIT signbit_tile
-
-#define ACTIVATION_INIT_RSQRT rsqrt_tile_init
-#define ACTIVATION_APPLY_RSQRT rsqrt_tile
-
-#define ACTIVATION_INIT_RELU6 relu_max_tile_init
-#define ACTIVATION_APPLY_RELU6(i) relu_max_tile(i, 0x40c00000u)
-
-#define ACTIVATION_INIT_ERF erf_tile_init
-#define ACTIVATION_APPLY_ERF erf_tile
-
-#define ACTIVATION_INIT_ERFC erfc_tile_init
-#define ACTIVATION_APPLY_ERFC erfc_tile
-
-#define ACTIVATION_INIT_ISINF isinf_tile_init
-#define ACTIVATION_APPLY_ISINF isinf_tile
-
-#define ACTIVATION_INIT_ISPOSINF isposinf_tile_init
-#define ACTIVATION_APPLY_ISPOSINF isposinf_tile
-
-#define ACTIVATION_INIT_ISNEGINF isneginf_tile_init
-#define ACTIVATION_APPLY_ISNEGINF isneginf_tile
-
-#define ACTIVATION_INIT_ISNAN isnan_tile_init
-#define ACTIVATION_APPLY_ISNAN isnan_tile
-
-#define ACTIVATION_INIT_ISFINITE isfinite_tile_init
-#define ACTIVATION_APPLY_ISFINITE isfinite_tile
-
-#define ACTIVATION_INIT_LOGICAL_NOT_UNARY logical_not_unary_tile_init
-#define ACTIVATION_APPLY_LOGICAL_NOT_UNARY logical_not_unary_tile
-
-#define ACTIVATION_INIT_ERFINV erfinv_tile_init
-#define ACTIVATION_APPLY_ERFINV erfinv_tile
-
-#define ACTIVATION_INIT_I0 i0_tile_init
-#define ACTIVATION_APPLY_I0 i0_tile
-
-#define ACTIVATION_INIT_I1 i1_tile_init
-#define ACTIVATION_APPLY_I1 i1_tile
-
-#define ACTIVATION_INIT_SILU silu_tile_init
-#define ACTIVATION_APPLY_SILU silu_tile
-
-#define ACTIVATION_INIT_NEG negative_tile_init
-#define ACTIVATION_APPLY_NEG negative_tile
-
-#define ACTIVATION_INIT_BITWISE_NOT bitwise_not_tile_init
-#define ACTIVATION_APPLY_BITWISE_NOT bitwise_not_tile
-
-#define ACTIVATION_INIT_FLOOR floor_tile_init
-#define ACTIVATION_APPLY_FLOOR floor_tile
-
-#define ACTIVATION_INIT_CEIL ceil_tile_init
-#define ACTIVATION_APPLY_CEIL ceil_tile
-
-#define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__)
-#define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__())
-#define IS_EMPTY_1(...) 0
-#define IS_EMPTY_NON_FUNCTION_C(...) ()
-
-#define IS_BEGIN_PARENS(...) P_FIRST(P_CAT(P_IS_VARIADIC_R_, P_IS_VARIADIC_C __VA_ARGS__))
-
-#define P_IS_VARIADIC_R_1 1,
-#define P_IS_VARIADIC_R_P_IS_VARIADIC_C 0,
-#define P_IS_VARIADIC_C(...) 1
-
-#define P_FIRST(...) P_FIRST_(__VA_ARGS__, )
-#define P_FIRST_(a, ...) a
-
-#define P_CAT(a, ...) P_CAT_(a, __VA_ARGS__)
-#define P_CAT_(a, ...) a##__VA_ARGS__
-
-#define P_COMPL(b) P_CAT(P_COMPL_, b)
-#define P_COMPL_0 1
-#define P_COMPL_1 0
-
-#define ACTIVATION_INIT(elem) ACTIVATION_INIT_##elem()
-#define ACTIVATION_APPLY(elem, i) ACTIVATION_APPLY_##elem(i)
-
-#define PROCESS_ACTIVATION(elem, i) \
-    ACTIVATION_INIT(elem);          \
-    ACTIVATION_APPLY(elem, i)
-
-#define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i)
-#define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS
-#define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0)))
-
-#define BCAST_OP P_CAT(BCAST_OP_, BCAST_INPUT)
-#define OTHER_OP P_CAT(BCAST_OP_, P_COMPL(BCAST_INPUT))
-#define BCAST_OP_0 LHS
-#define BCAST_OP_1 RHS
-
 #define PREPROCESS(op, ...) P_CAT(PREPROCESS_, HAS_ACTIVATIONS(op))(op, __VA_ARGS__)
 #define PREPROCESS_0(...)
 #define PREPROCESS_1(op, cb_pre, cb_post, cb_out, per_core_block_size) \
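The block of macros being deleted above (and re-added in `eltwise_utils_common.hpp` below) includes a preprocessor emptiness-detection idiom: `HAS_ACTIVATIONS(op)` turns "does `PROCESS_<op>_ACTIVATIONS(i)` expand to nothing?" into a compile-time 0/1. A self-contained demo using the same definitions, with two example activation defines at the bottom:

```cpp
// Definitions copied from the header; the PROCESS_*_ACTIVATIONS defines are
// example inputs, mirroring what the host injects per kernel build.
#define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__)
#define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__())
#define IS_EMPTY_1(...) 0
#define IS_EMPTY_NON_FUNCTION_C(...) ()
#define IS_BEGIN_PARENS(...) P_FIRST(P_CAT(P_IS_VARIADIC_R_, P_IS_VARIADIC_C __VA_ARGS__))
#define P_IS_VARIADIC_R_1 1,
#define P_IS_VARIADIC_R_P_IS_VARIADIC_C 0,
#define P_IS_VARIADIC_C(...) 1
#define P_FIRST(...) P_FIRST_(__VA_ARGS__, )
#define P_FIRST_(a, ...) a
#define P_CAT(a, ...) P_CAT_(a, __VA_ARGS__)
#define P_CAT_(a, ...) a##__VA_ARGS__
#define P_COMPL(b) P_CAT(P_COMPL_, b)
#define P_COMPL_0 1
#define P_COMPL_1 0

#define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i)
#define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS
#define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0)))

#define PROCESS_LHS_ACTIVATIONS(i) square_tile(i)  // non-empty -> HAS_ACTIVATIONS(LHS) == 1
#define PROCESS_RHS_ACTIVATIONS(i)                 // empty     -> HAS_ACTIVATIONS(RHS) == 0

static_assert(HAS_ACTIVATIONS(LHS) == 1, "lhs has an activation");
static_assert(HAS_ACTIVATIONS(RHS) == 0, "rhs has none");
int main() {}
```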
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp
new file mode 100644
index 00000000000..bae10426875
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp
@@ -0,0 +1,180 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#define ACTIVATION_INIT_RELU relu_tile_init
+#define ACTIVATION_APPLY_RELU relu_tile
+
+#define ACTIVATION_INIT_SQUARE square_tile_init
+#define ACTIVATION_APPLY_SQUARE square_tile
+
+#define ACTIVATION_INIT_GTZ gtz_tile_init
+#define ACTIVATION_APPLY_GTZ gtz_tile
+
+#define ACTIVATION_INIT_LTZ ltz_tile_init
+#define ACTIVATION_APPLY_LTZ ltz_tile
+
+#define ACTIVATION_INIT_GEZ gez_tile_init
+#define ACTIVATION_APPLY_GEZ gez_tile
+
+#define ACTIVATION_INIT_LEZ lez_tile_init
+#define ACTIVATION_APPLY_LEZ lez_tile
+
+#define ACTIVATION_INIT_EQZ eqz_tile_init
+#define ACTIVATION_APPLY_EQZ eqz_tile
+
+#define ACTIVATION_INIT_NEZ nez_tile_init
+#define ACTIVATION_APPLY_NEZ nez_tile
+
+#define ACTIVATION_INIT_LOG log_tile_init
+#define ACTIVATION_APPLY_LOG log_tile
+
+#define ACTIVATION_INIT_TANH tanh_tile_init
+#define ACTIVATION_APPLY_TANH tanh_tile
+
+#define ACTIVATION_INIT_LOG2 log_with_base_tile_init
+#define ACTIVATION_APPLY_LOG2(i) log_with_base_tile(i, 0x3dc5u)
+
+#define ACTIVATION_INIT_LOG10 log_with_base_tile_init
+#define ACTIVATION_APPLY_LOG10(i) log_with_base_tile(i, 0x36f3u)
+
+#define ACTIVATION_INIT_EXP exp_tile_init
+#define ACTIVATION_APPLY_EXP exp_tile
+
+#define ACTIVATION_INIT_EXP2 exp2_tile_init
+#define ACTIVATION_APPLY_EXP2 exp2_tile
+
+#define ACTIVATION_INIT_EXPM1 expm1_tile_init
+#define ACTIVATION_APPLY_EXPM1 expm1_tile
+
+#define ACTIVATION_INIT_RECIP recip_tile_init
+#define ACTIVATION_APPLY_RECIP recip_tile
+
+#define ACTIVATION_INIT_GELU gelu_tile_init
+#define ACTIVATION_APPLY_GELU gelu_tile
+
+#define ACTIVATION_INIT_SQRT sqrt_tile_init
+#define ACTIVATION_APPLY_SQRT sqrt_tile
+
+#define ACTIVATION_INIT_SIGMOID sigmoid_tile_init
+#define ACTIVATION_APPLY_SIGMOID sigmoid_tile
+
+#define ACTIVATION_INIT_SIN sin_tile_init
+#define ACTIVATION_APPLY_SIN sin_tile
+
+#define ACTIVATION_INIT_COS cos_tile_init
+#define ACTIVATION_APPLY_COS cos_tile
+
+#define ACTIVATION_INIT_TAN tan_tile_init
+#define ACTIVATION_APPLY_TAN tan_tile
+
+#define ACTIVATION_INIT_ASIN asin_tile_init
+#define ACTIVATION_APPLY_ASIN asin_tile
+
+#define ACTIVATION_INIT_ACOS acos_tile_init
+#define ACTIVATION_APPLY_ACOS acos_tile
+
+#define ACTIVATION_INIT_ATAN atan_tile_init
+#define ACTIVATION_APPLY_ATAN atan_tile
+
+#define ACTIVATION_INIT_ABS abs_tile_init
+#define ACTIVATION_APPLY_ABS abs_tile
+
+#define ACTIVATION_INIT_SIGN sign_tile_init
+#define ACTIVATION_APPLY_SIGN sign_tile
+
+#define ACTIVATION_INIT_SIGNBIT signbit_tile_init
+#define ACTIVATION_APPLY_SIGNBIT signbit_tile
+
+#define ACTIVATION_INIT_RSQRT rsqrt_tile_init
+#define ACTIVATION_APPLY_RSQRT rsqrt_tile
+
+#define ACTIVATION_INIT_RELU6 relu_max_tile_init
+#define ACTIVATION_APPLY_RELU6(i) relu_max_tile(i, 0x40c00000u)
+
+#define ACTIVATION_INIT_ERF erf_tile_init
+#define ACTIVATION_APPLY_ERF erf_tile
+
+#define ACTIVATION_INIT_ERFC erfc_tile_init
+#define ACTIVATION_APPLY_ERFC erfc_tile
+
+#define ACTIVATION_INIT_ISINF isinf_tile_init
+#define ACTIVATION_APPLY_ISINF isinf_tile
+
+#define ACTIVATION_INIT_ISPOSINF isposinf_tile_init
+#define ACTIVATION_APPLY_ISPOSINF isposinf_tile
+
+#define ACTIVATION_INIT_ISNEGINF isneginf_tile_init
+#define ACTIVATION_APPLY_ISNEGINF isneginf_tile
+
+#define ACTIVATION_INIT_ISNAN isnan_tile_init
+#define ACTIVATION_APPLY_ISNAN isnan_tile
+
+#define ACTIVATION_INIT_ISFINITE isfinite_tile_init
+#define ACTIVATION_APPLY_ISFINITE isfinite_tile
+
+#define ACTIVATION_INIT_LOGICAL_NOT_UNARY logical_not_unary_tile_init
+#define ACTIVATION_APPLY_LOGICAL_NOT_UNARY logical_not_unary_tile
+
+#define ACTIVATION_INIT_ERFINV erfinv_tile_init
+#define ACTIVATION_APPLY_ERFINV erfinv_tile
+
+#define ACTIVATION_INIT_I0 i0_tile_init
+#define ACTIVATION_APPLY_I0 i0_tile
+
+#define ACTIVATION_INIT_I1 i1_tile_init
+#define ACTIVATION_APPLY_I1 i1_tile
+
+#define ACTIVATION_INIT_SILU silu_tile_init
+#define ACTIVATION_APPLY_SILU silu_tile
+
+#define ACTIVATION_INIT_NEG negative_tile_init
+#define ACTIVATION_APPLY_NEG negative_tile
+
+#define ACTIVATION_INIT_BITWISE_NOT bitwise_not_tile_init
+#define ACTIVATION_APPLY_BITWISE_NOT bitwise_not_tile
+
+#define ACTIVATION_INIT_FLOOR floor_tile_init
+#define ACTIVATION_APPLY_FLOOR floor_tile
+
+#define ACTIVATION_INIT_CEIL ceil_tile_init
+#define ACTIVATION_APPLY_CEIL ceil_tile
+
+#define IS_EMPTY(...) P_CAT(IS_EMPTY_, IS_BEGIN_PARENS(__VA_ARGS__))(__VA_ARGS__)
+#define IS_EMPTY_0(...) IS_BEGIN_PARENS(IS_EMPTY_NON_FUNCTION_C __VA_ARGS__())
+#define IS_EMPTY_1(...) 0
+#define IS_EMPTY_NON_FUNCTION_C(...) ()
+
+#define IS_BEGIN_PARENS(...) P_FIRST(P_CAT(P_IS_VARIADIC_R_, P_IS_VARIADIC_C __VA_ARGS__))
+
+#define P_IS_VARIADIC_R_1 1,
+#define P_IS_VARIADIC_R_P_IS_VARIADIC_C 0,
+#define P_IS_VARIADIC_C(...) 1
+
+#define P_FIRST(...) P_FIRST_(__VA_ARGS__, )
+#define P_FIRST_(a, ...) a
+
+#define P_CAT(a, ...) P_CAT_(a, __VA_ARGS__)
+#define P_CAT_(a, ...) a##__VA_ARGS__
+
+#define P_COMPL(b) P_CAT(P_COMPL_, b)
+#define P_COMPL_0 1
+#define P_COMPL_1 0
+
+#define ACTIVATION_INIT(elem) ACTIVATION_INIT_##elem()
+#define ACTIVATION_APPLY(elem, i) ACTIVATION_APPLY_##elem(i)
+
+#define PROCESS_ACTIVATION(elem, i) \
+    ACTIVATION_INIT(elem);          \
+    ACTIVATION_APPLY(elem, i)
+
+#define PROCESS_ACTIVATIONS(op, i) PROCESS_ACTIVATIONS_(op)(i)
+#define PROCESS_ACTIVATIONS_(op) PROCESS_##op##_ACTIVATIONS
+#define HAS_ACTIVATIONS(op) P_COMPL(IS_EMPTY(PROCESS_ACTIVATIONS(op, 0)))
+
+#define BCAST_OP P_CAT(BCAST_OP_, BCAST_INPUT)
+#define OTHER_OP P_CAT(BCAST_OP_, P_COMPL(BCAST_INPUT))
+#define BCAST_OP_0 LHS
+#define BCAST_OP_1 RHS
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp
new file mode 100644
index 00000000000..f65757f1e4c
--- /dev/null
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp
@@ -0,0 +1,35 @@
+// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "compute_kernel_api/common.h"
+#include "compute_kernel_api/tile_move_copy.h"
+
+#define PREPROCESS(op, ...) P_CAT(PREPROCESS_, HAS_ACTIVATIONS(op))(op, __VA_ARGS__)
+#define PREPROCESS_0(...)
+#define PREPROCESS_1(op, cb_pre, cb_post, cb_out, per_core_block_size) \
+    do {                                                               \
+        using namespace ckernel;                                       \
+        cb_wait_front(cb_pre, per_core_block_size);                    \
+        cb_reserve_back(cb_post, per_core_block_size);                 \
+                                                                       \
+        tile_regs_acquire();                                           \
+        for (uint32_t i = 0; i < per_core_block_size; ++i) {           \
+            copy_tile_to_dst_init_short();                             \
+            copy_tile(cb_pre, i, i);                                   \
+            PROCESS_ACTIVATIONS(op, i);                                \
+        }                                                              \
+        tile_regs_commit();                                            \
+                                                                       \
+        tile_regs_wait();                                              \
+        for (uint32_t i = 0; i < per_core_block_size; ++i) {           \
+            pack_tile(i, cb_post);                                     \
+        }                                                              \
+        tile_regs_release();                                           \
+                                                                       \
+        cb_pop_front(cb_pre, per_core_block_size);                     \
+        cb_push_back(cb_post, per_core_block_size);                    \
+                                                                       \
+    } while (0)
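PREPROCESS_1 implements a copy-activate-repack round trip through the destination registers: wait on the pre-activation CB, copy each tile to DST, run the activation chain via PROCESS_ACTIVATIONS, pack the results into the post-activation CB, and advance both CBs. A minimal sketch of the intended call site in a compute kernel; this is illustrative only, since the actual binary_ng compute kernel and its CB assignments are not part of this diff:

// Illustrative call site (not in this patch); CB indices are assumptions.
#include "eltwise_utils_common.hpp"
#include "eltwise_utils_sfpu.hpp"

namespace NAMESPACE {
void MAIN {
    constexpr auto cb_pre_lhs = tt::CB::c_in0;         // LHS tiles as read in
    constexpr auto cb_post_lhs = tt::CB::c_intermed0;  // LHS after activations
    constexpr auto cb_out = tt::CB::c_out0;
    constexpr uint32_t per_core_block_size = 1;

    // Expands to the do/while above when HAS_ACTIVATIONS(LHS) is 1,
    // and to nothing when the LHS activation list is empty:
    PREPROCESS(LHS, cb_pre_lhs, cb_post_lhs, cb_out, per_core_block_size);
}
}  // namespace NAMESPACE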
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp
index 17f8a669c99..1b80fc62106 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/fill_tile_utils.hpp
@@ -17,6 +17,14 @@ FORCE_INLINE void fill_with_val_bfloat16(uint32_t cb_id, uint32_t packed_scalar)
     }
 }
 
+template <uint32_t ElementsV, typename ScalarT>
+FORCE_INLINE void fill_with_val(uint32_t cb_id, ScalarT scalar) {
+    auto* ptr = reinterpret_cast<ScalarT*>(get_write_ptr(cb_id));
+    for (uint32_t i = 0; i < ElementsV; ++i) {
+        ptr[i] = scalar;
+    }
+}
+
 // Reads the very first element of the CB and fills the entire tile with that value.
 // Tile is assumed to have 16-bit elements
 FORCE_INLINE void fill_tile_with_first_element_bfloat16(uint32_t cb_id) {
@@ -33,6 +41,19 @@ FORCE_INLINE void fill_tile_with_first_element_bfloat16(uint32_t cb_id) {
     }
 }
 
+// Reads the very first element of the CB and fills the entire tile with that value.
+// Tile is assumed to have 32-bit elements (float32 or int32).
+template <typename T>
+FORCE_INLINE void fill_tile_with_first_element(uint32_t cb_id) {
+    auto* read_ptr = reinterpret_cast<volatile T*>(get_write_ptr(cb_id));
+    const T first_elem = read_ptr[0];
+
+    auto* write_ptr = reinterpret_cast<T*>(get_write_ptr(cb_id));
+    for (uint32_t i = 0; i < 1024; ++i) {
+        write_ptr[i] = first_elem;
+    }
+}
+
 // Reads the very first row of the CB and fills the entire tile with the same row.
 // Tile is assumed to have 16-bit elements.
 FORCE_INLINE void fill_tile_with_first_row_bfloat16(uint32_t cb_id) {
@@ -59,6 +80,30 @@ FORCE_INLINE void fill_tile_with_first_row_bfloat16(uint32_t cb_id) {
     }
 }
 
+// Reads the very first row of the CB and fills the entire tile with the same row.
+// Tile is assumed to have 32-bit elements (float32/int32).
+FORCE_INLINE void fill_tile_with_first_row(uint32_t cb_id) {
+    // Tile with 4 faces (16x16) and 32-bit elements
+    auto* ptr = reinterpret_cast<uint32_t*>(get_write_ptr(cb_id));
+
+    uint32_t row_offset = 16;  // Start at the second row (offset by 16 elements)
+    uint32_t num_rows = 15;    // 15 rows to fill per face
+
+    // Iterate over face pairs (0,1) and (2,3)
+    for (uint32_t k = 0, face_offset = 0; k < 2; ++k, face_offset += 512) {  // Offset 512 = 256 elements x 2 faces
+        for (uint32_t row = 0; row < num_rows; ++row) {
+            uint32_t dst_offset = face_offset + row_offset;
+            for (uint32_t col = 0; col < 16; ++col) {
+                ptr[dst_offset + col] = ptr[col];              // left face
+                ptr[dst_offset + col + 256] = ptr[col + 256];  // right face
+            }
+            row_offset += 16;  // Move to the next row (16 elements per row)
+        }
+        row_offset = 0;  // Reset for the next face pair
+        num_rows = 16;   // Process all rows for the next face pair
+    }
+}
+
 // Reads the very first column of the CB and fills the entire tile with the same column.
 // Tile is assumed to have 16-bit elements.
 FORCE_INLINE void fill_tile_with_first_column_bfloat16(uint32_t cb_id) {
@@ -83,3 +128,31 @@
     }
 }
 }
+
+// Reads the very first column of the CB and fills the entire tile with the same column.
+// Tile is assumed to have 32-bit elements (float32/int32).
+FORCE_INLINE void fill_tile_with_first_column(uint32_t cb_id) {
+    // Tile with 4 faces (16x16) and 32-bit elements
+    auto* ptr = reinterpret_cast<uint32_t*>(get_write_ptr(cb_id));
+
+    constexpr uint32_t num_rows = 16;             // Number of rows per face
+    constexpr uint32_t face_row_stride = 16;      // Elements per row
+    constexpr uint32_t face_size = 256;           // Total elements per face (16x16)
+    constexpr uint32_t face_offset_stride = 512;  // Total elements per pair of faces (2x16x16)
+
+    // Iterate over face pairs (0,1) and (2,3)
+    for (uint32_t k = 0, face_offset = 0; k < 2; ++k, face_offset += face_offset_stride) {
+        for (uint32_t row = 0, row_offset = 0; row < num_rows; ++row, row_offset += face_row_stride) {
+            uint32_t left_dst_offset = face_offset + row_offset;      // Left face (0 or 2)
+            uint32_t right_dst_offset = left_dst_offset + face_size;  // Right face (1 or 3)
+
+            // Read the first column value for the current row from the left face
+            auto src_val = ptr[left_dst_offset];
+
+            for (uint32_t col = 0; col < face_row_stride; ++col) {
+                ptr[left_dst_offset + col] = src_val;   // left face
+                ptr[right_dst_offset + col] = src_val;  // right face
+            }
+        }
+    }
+}
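The 512/256/16 constants in the two fill functions above follow from the tile layout their comments describe: a 32x32 tile stored as four 16x16 faces, ordered top-left, top-right, bottom-left, bottom-right, 256 elements per face. A small host-side sketch that mirrors the same indexing on a plain array, which can be handy for sanity-checking the offsets off-device (function names here are mine, not from the patch):

// Host-side mirror of the assumed 4-face tile layout; illustrative only.
#include <array>
#include <cstdint>

inline uint32_t tile_index(uint32_t r, uint32_t c) {
    const uint32_t face_pair = r / 16;  // 0 for rows 0-15 (faces 0/1), 1 for rows 16-31 (faces 2/3)
    const uint32_t face_col = c / 16;   // 0 = left face, 1 = right face
    return face_pair * 512 + face_col * 256 + (r % 16) * 16 + (c % 16);
}

// Mirrors fill_tile_with_first_row: copy row 0 into every other row.
inline void fill_rows_from_first_row(std::array<uint32_t, 1024>& tile) {
    for (uint32_t r = 1; r < 32; ++r) {
        for (uint32_t c = 0; c < 32; ++c) {
            tile[tile_index(r, c)] = tile[tile_index(0, c)];
        }
    }
}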
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp
index cf0a1e4e13d..7d22fb8f7d5 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_col_bcast.cpp
@@ -49,7 +49,15 @@ void kernel_main() {
         uint32_t l1_write_addr = get_write_ptr(cb_id_src);
         noc_async_read_tile(tile_offset + th, src, l1_write_addr);
         noc_async_read_barrier();
+#ifdef SFPU_F32
+        fill_tile_with_first_column(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+        fill_tile_with_first_column(cb_id_src);
+#endif
+#ifdef FPU_BF16
         fill_tile_with_first_column_bfloat16(cb_id_src);
+#endif
         cb_push_back(cb_id_src, onetile);
 
         num_tiles_read += Wt - start_tw;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp
index 7278e6b3510..39db4b2aa8e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_row_bcast.cpp
@@ -50,7 +50,15 @@ void kernel_main() {
             uint32_t l1_write_addr_src = get_write_ptr(cb_id_src);
             noc_async_read_tile(tile_offset + tw, src, l1_write_addr_src);
             noc_async_read_barrier();
+#ifdef SFPU_F32
+            fill_tile_with_first_row(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+            fill_tile_with_first_row(cb_id_src);
+#endif
+#ifdef FPU_BF16
             fill_tile_with_first_row_bfloat16(cb_id_src);
+#endif
             cb_push_back(cb_id_src, onetile);
         }
     }
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp
index d1f0922d648..e39c9d99638 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/reader_interleaved_scalar_bcast.cpp
@@ -45,7 +45,15 @@ void kernel_main() {
         uint32_t l1_write_addr_src = get_write_ptr(cb_id_src);
         noc_async_read_tile(tile_offset, src, l1_write_addr_src);
         noc_async_read_barrier();
+#ifdef SFPU_F32
+        fill_tile_with_first_element<float>(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+        fill_tile_with_first_element<int32_t>(cb_id_src);
+#endif
+#ifdef FPU_BF16
         fill_tile_with_first_element_bfloat16(cb_id_src);
+#endif
         cb_push_back(cb_id_src, onetile);
 
         num_tiles_read += HtWt - start_t;
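Each reader now branches at compile time on SFPU_F32, SFPU_INT32, or FPU_BF16; the three branches are mutually exclusive, and if none of them is defined the tile is pushed without the broadcast fill, so the host must always set exactly one per kernel build. The program factory that emits these defines is not part of this diff; a hedged sketch of one plausible host-side shape, using the defines map that tt-metal kernel configs accept:

// Illustrative host-side helper (not in this patch); the real binary_ng
// program factory presumably makes an equivalent dtype-based choice.
#include <map>
#include <string>

std::map<std::string, std::string> make_dtype_defines(bool is_fp32, bool is_int32) {
    std::map<std::string, std::string> defines;
    if (is_fp32) {
        defines["SFPU_F32"] = "1";    // 32-bit float path
    } else if (is_int32) {
        defines["SFPU_INT32"] = "1";  // 32-bit integer path
    } else {
        defines["FPU_BF16"] = "1";    // default bfloat16 path
    }
    return defines;
}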
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp
index f17b23684ff..07ccfaebb66 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_col_bcast.cpp
@@ -59,7 +59,15 @@ void kernel_main() {
         uint32_t l1_write_addr = get_write_ptr(cb_id_src);
         noc_async_read_tile(tile_offset + th, src, l1_write_addr);
         noc_async_read_barrier();
+#ifdef SFPU_F32
+        fill_tile_with_first_column(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+        fill_tile_with_first_column(cb_id_src);
+#endif
+#ifdef FPU_BF16
         fill_tile_with_first_column_bfloat16(cb_id_src);
+#endif
         cb_push_back(cb_id_src, onetile);
 
         for (uint32_t tw = start_tw; tw < Wt && num_tiles_written < num_tiles; ++tw, ++num_tiles_written) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp
index 65ff8e60f69..21eb1abd027 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_row_bcast.cpp
@@ -60,7 +60,15 @@ void kernel_main() {
         uint32_t l1_write_addr = get_write_ptr(cb_id_src);
         noc_async_read_tile(tile_offset + tw, src, l1_write_addr);
         noc_async_read_barrier();
+#ifdef SFPU_F32
+        fill_tile_with_first_row(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+        fill_tile_with_first_row(cb_id_src);
+#endif
+#ifdef FPU_BF16
         fill_tile_with_first_row_bfloat16(cb_id_src);
+#endif
         cb_push_back(cb_id_src, onetile);
 
         // write a tile to dst, since the dst shape is full, the tile offset simply grows linearly
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp
index 452d1dafaa7..bda991ddd7f 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar.cpp
@@ -39,7 +39,16 @@ void kernel_main() {
 
     // we only need to fill a tile with the scalar value once
    cb_reserve_back(cb_id_src, onetile);
+#ifdef SFPU_F32
+    float* float_ptr = reinterpret_cast<float*>(&packed_scalar);
+    fill_with_val<1024, float>(cb_id_src, *float_ptr);
+#endif
+#ifdef SFPU_INT32
+    fill_with_val<1024, int32_t>(cb_id_src, packed_scalar);
+#endif
+#ifdef FPU_BF16
     fill_with_val_bfloat16(cb_id_src, packed_scalar);
+#endif
     cb_push_back(cb_id_src, onetile);
 
     uint32_t num_tiles_written = 0;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp
index 18915373d25..81b32282846 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/dataflow/writer_interleaved_scalar_bcast.cpp
@@ -54,7 +54,15 @@ void kernel_main() {
         uint32_t l1_write_addr = get_write_ptr(cb_id_src);
         noc_async_read_tile(tile_offset, src, l1_write_addr);
         noc_async_read_barrier();
+#ifdef SFPU_F32
+        fill_tile_with_first_element<float>(cb_id_src);
+#endif
+#ifdef SFPU_INT32
+        fill_tile_with_first_element<int32_t>(cb_id_src);
+#endif
+#ifdef FPU_BF16
         fill_tile_with_first_element_bfloat16(cb_id_src);
+#endif
         cb_push_back(cb_id_src, onetile);
 
         for (uint32_t t = start_t; t < HtWt && num_tiles_written < num_tiles; ++t, ++num_tiles_written) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
index 06c53bfe7e6..76a602040b0 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/types.hpp
@@ -10,7 +10,6 @@ enum class BinaryOpType {
     ADD,
     SUB,
     MUL,
-    DIV,
     GT,
     LT,
     LTE,
@@ -25,5 +24,13 @@ enum class BinaryOpType {
     LDEXP,
     LOGADDEXP,
     LOGADDEXP2,
+    DIV,
+    RSUB,
+    POWER,
+    BITWISE_XOR,
+    BITWISE_AND,
+    BITWISE_OR,
+    LEFT_SHIFT,
+    RIGHT_SHIFT
 };
 }
diff --git a/ttnn/ttnn/operations/binary_ng.py b/ttnn/ttnn/operations/binary_ng.py
index d012271bda5..c1d0533dd6f 100644
--- a/ttnn/ttnn/operations/binary_ng.py
+++ b/ttnn/ttnn/operations/binary_ng.py
@@ -8,6 +8,7 @@
 
 ttnn.attach_golden_function(ttnn.experimental.add, golden_function=lambda a, b: a + b)
 ttnn.attach_golden_function(ttnn.experimental.sub, golden_function=lambda a, b: a - b)
+ttnn.attach_golden_function(ttnn.experimental.rsub, golden_function=lambda a, b: b - a)
 ttnn.attach_golden_function(ttnn.experimental.mul, golden_function=lambda a, b: a * b)
 ttnn.attach_golden_function(ttnn.experimental.div, golden_function=lambda a, b: torch.divide(a, b))
 ttnn.attach_golden_function(ttnn.experimental.eq, golden_function=lambda a, b: torch.eq(a, b))
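Two notes on the last two files. Moving DIV to the tail of BinaryOpType and appending the new entries changes the implicit numeric values of every enumerator from GT onward, which matters for anything that serialized or cached the old values. And the golden function attached for rsub encodes its defining property: RSUB(a, b) == b - a, i.e. SUB with the operands exchanged. A hedged sketch of how a dispatcher could lower RSUB by reusing an existing SUB path; the helper and its names are hypothetical, not from this patch:

// Hypothetical lowering helper -- illustrates the RSUB semantics only.
#include <utility>

template <typename Tensor, typename RunBinary>
auto lower_binary(BinaryOpType op, Tensor a, Tensor b, RunBinary&& run) {
    if (op == BinaryOpType::RSUB) {
        std::swap(a, b);           // RSUB(a, b) == SUB(b, a)
        op = BinaryOpType::SUB;
    }
    return run(op, a, b);          // dispatch to the regular binary path
}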