diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
index 2bb028ee6cf..da305202e9c 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py
@@ -428,3 +428,23 @@ def run_unary_test_bitwise_not(device, h, w, fill_value, ttnn_function, pcc=0.99
 @pytest.mark.parametrize("fill_value", [-2147483647, 2147483648, 7534, 225, 97, 3])
 def test_bitwise_not(device, h, w, fill_value):
     run_unary_test_bitwise_not(device, h, w, fill_value, ttnn.bitwise_not)
+
+
+@skip_for_grayskull()
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_unary_floor(input_shapes, device):
+    """Compare ttnn.floor on float32 tile-layout tensors with its golden function."""
+    in_data1 = torch.empty(input_shapes, dtype=torch.float32).uniform_(-43566, 43565)
+    input_tensor1 = ttnn.from_torch(in_data1, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.floor(input_tensor1)
+    golden_function = ttnn.get_golden_function(ttnn.floor)
+    golden_tensor = golden_function(in_data1)
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(golden_tensor, output_tensor, 0.999)
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
index 30d18596ef4..ad167758a24 100644
--- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
@@ -15,12 +15,45 @@ using namespace sfpi;
 namespace ckernel {
 namespace sfpu {
 
+// Converts each lane of `in` to a 32-bit integer, dropping the fractional bits:
+//   - exponent < 0: result is 0
+//   - exponent > 30 (overflow): result saturates to the int32 max magnitude
+//   - otherwise: the mantissa from exman8 is shifted by (exponent - 23)
+// For negative inputs the sign bit is then set on the result with setsgn.
+inline vInt float_to_int32(vFloat in)
+{
+    vInt result;
+    vInt exp = exexp(in); // extract exponent
+    v_if (exp < 0) {
+        result = 0;
+    } v_elseif (exp > 30) {
+        // set to int32 max value in case of overflow
+        result = std::numeric_limits::max();
+        // check sign
+        v_if (in < 0) {
+            result = reinterpret(setsgn(reinterpret(result), 1));
+        } v_endif
+    } v_else {
+        // extract mantissa
+        vInt man = exman8(in);
+        // shift the mantissa by (23-exponent) to the right
+        vInt shift = exp - 23;
+        man = shft(reinterpret(man), shift);
+        // check sign
+        v_if (in < 0) {
+            man = reinterpret(setsgn(reinterpret(man), 1));
+        } v_endif
+        result = man;
+    } v_endif
+    return result;
+}
+
 template
 inline void calculate_floor() {
     for (int d = 0; d < ITERATIONS; d++) {
         vFloat result = dst_reg[0];
         vFloat v = result;
-        vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available
+        vInt tmp = float_to_int16(result, 0);
         result = int32_to_float(tmp, 0);
         v_if(result > v) { result = result - 1; }
         v_endif;
@@ -31,5 +64,22 @@ inline void calculate_floor() {
     }
 }
 
+// Floor for float32 data: converts each value with float_to_int32, converts
+// it back with int32_to_float, and subtracts 1 when the round trip produced a
+// value greater than the input.
+template
+inline void calculate_floor_float32() {
+    for (int d = 0; d < ITERATIONS; d++) {
+        vFloat result = dst_reg[0];
+        vFloat v = result;
+        vInt tmp = float_to_int32(result);
+        result = int32_to_float(tmp, 0);
+        v_if(result > v) { result = result - 1; }
+        v_endif;
+        dst_reg[0] = result;
+        dst_reg++;
+    }
+}
+
 }  // namespace sfpu
 }  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
index ff0e6e96daf..26325252a0f 100644
--- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
+++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
@@ -23,4 +23,9 @@ inline void llk_math_eltwise_unary_sfpu_floor(uint dst_index, int vector_mode = 
         ckernel::sfpu::calculate_floor, dst_index, vector_mode);
 }
 
+template
+inline void llk_math_eltwise_unary_sfpu_floor_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) {
+    llk_math_eltwise_unary_sfpu_params(
+        ckernel::sfpu::calculate_floor_float32, dst_index, vector_mode);
+}
 }  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
index 30d18596ef4..ad167758a24 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h
@@ -15,12 +15,45 @@ using namespace sfpi;
 namespace ckernel {
 namespace sfpu {
 
+// Converts each lane of `in` to a 32-bit integer, dropping the fractional bits:
+//   - exponent < 0: result is 0
+//   - exponent > 30 (overflow): result saturates to the int32 max magnitude
+//   - otherwise: the mantissa from exman8 is shifted by (exponent - 23)
+// For negative inputs the sign bit is then set on the result with setsgn.
+inline vInt float_to_int32(vFloat in)
+{
+    vInt result;
+    vInt exp = exexp(in); // extract exponent
+    v_if (exp < 0) {
+        result = 0;
+    } v_elseif (exp > 30) {
+        // set to int32 max value in case of overflow
+        result = std::numeric_limits::max();
+        // check sign
+        v_if (in < 0) {
+            result = reinterpret(setsgn(reinterpret(result), 1));
+        } v_endif
+    } v_else {
+        // extract mantissa
+        vInt man = exman8(in);
+        // shift the mantissa by (23-exponent) to the right
+        vInt shift = exp - 23;
+        man = shft(reinterpret(man), shift);
+        // check sign
+        v_if (in < 0) {
+            man = reinterpret(setsgn(reinterpret(man), 1));
+        } v_endif
+        result = man;
+    } v_endif
+    return result;
+}
+
 template
 inline void calculate_floor() {
     for (int d = 0; d < ITERATIONS; d++) {
         vFloat result = dst_reg[0];
         vFloat v = result;
-        vInt tmp = float_to_int16(result, 0); // TODO: Replace float_to_int16 to float_to_int32 once it is available
+        vInt tmp = float_to_int16(result, 0);
         result = int32_to_float(tmp, 0);
         v_if(result > v) { result = result - 1; }
         v_endif;
@@ -31,5 +64,22 @@ inline void calculate_floor() {
     }
 }
 
+// Floor for float32 data: converts each value with float_to_int32, converts
+// it back with int32_to_float, and subtracts 1 when the round trip produced a
+// value greater than the input.
+template
+inline void calculate_floor_float32() {
+    for (int d = 0; d < ITERATIONS; d++) {
+        vFloat result = dst_reg[0];
+        vFloat v = result;
+        vInt tmp = float_to_int32(result);
+        result = int32_to_float(tmp, 0);
+        v_if(result > v) { result = result - 1; }
+        v_endif;
+        dst_reg[0] = result;
+        dst_reg++;
+    }
+}
+
 }  // namespace sfpu
 }  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
index ff0e6e96daf..26325252a0f 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_floor.h
@@ -23,4 +23,9 @@ inline void llk_math_eltwise_unary_sfpu_floor(uint dst_index, int vector_mode = 
         ckernel::sfpu::calculate_floor, dst_index, vector_mode);
 }
 
+template
+inline void llk_math_eltwise_unary_sfpu_floor_float32(uint dst_index, int vector_mode = (int)VectorMode::RC) {
+    llk_math_eltwise_unary_sfpu_params(
+        ckernel::sfpu::calculate_floor_float32, dst_index, vector_mode);
+}
 }  // namespace ckernel
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h b/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h
index ecad592ee52..fe45132ff0a 100644
--- a/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/floor.h
@@ -14,7 +14,6 @@
 #endif
 
 namespace ckernel {
-
 /**
  * Please refer to documentation for any_init.
  */
@@ -31,9 +30,23 @@ ALWI void floor_tile_init() { MATH((llk_math_eltwise_unary_sfpu_floor_init(idst))); }
 
+/**
+ * Performs an element-wise floor operation on the float32 tile held in the
+ * DST register at index idst. The DST register buffer must be in acquired
+ * state via *acquire_dst* call. This call is blocking and is only available
+ * on the compute engine.
+ *
+ * Return value: None
+ *
+ * | Argument | Description                                                             | Type     | Valid Range                                           | Required |
+ * |----------|-------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
+ * | idst     | The index of the tile in DST register buffer to perform floor operation | uint32_t | Must be less than the size of the DST register buffer | True     |
+ */
+ALWI void floor_tile_float32(uint32_t idst) { MATH((llk_math_eltwise_unary_sfpu_floor_float32(idst))); }
+
 }  // namespace ckernel
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
index 394647cdb68..da0ef55c936 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_types.hpp
@@ -80,6 +80,7 @@ enum class UnaryOpType {
     BITWISE_OR,
     RIGHT_SHIFT,
     FLOOR,
+    FLOOR_FLOAT32,
     CEIL,
     LEFT_SHIFT,
     REMAINDER,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
index f18aa992748..c1e40613c1c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp
@@ -49,6 +49,8 @@ void update_macro_defines(UnaryOpType op_type, std::map get_op_init_and_func_default(UnaryOpType op_type, std:
         case UnaryOpType::SIGNBIT:
             op_init_and_name = {"signbit_tile_init();", fmt::format("signbit_tile({});", idst)};
             break;
-        case UnaryOpType::FLOOR: op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)}; break;
         case UnaryOpType::CEIL: op_init_and_name = {"ceil_tile_init();", fmt::format("ceil_tile({});", idst)}; break;
         case UnaryOpType::SIN: op_init_and_name = {"sin_tile_init();", fmt::format("sin_tile({});", idst)}; break;
         case UnaryOpType::COS: op_init_and_name = {"cos_tile_init();", fmt::format("cos_tile({});", idst)}; break;
@@ -340,6 +340,12 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, std:
         case UnaryOpType::IDENTITY_UINT32:
             op_init_and_name = {"identity_tile_init();", fmt::format("identity_tile_uint32({});", idst)};
             break;
+        case UnaryOpType::FLOOR:
+            op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile({});", idst)};
+            break;
+        case UnaryOpType::FLOOR_FLOAT32:
+            op_init_and_name = {"floor_tile_init();", fmt::format("floor_tile_float32({});", idst)};
+            break;
         case UnaryOpType::RELU6:
             op_init_and_name = {"relu_max_tile_init();", fmt::format("relu_max_tile({}, 0x40c00000u);", idst)};
             break;
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
index 81eba41d570..c87dae81384 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
@@ -99,7 +99,6 @@ template struct ExecuteUnary;
 template struct ExecuteUnary;
 template struct ExecuteUnary;
 template struct ExecuteUnary;
-template struct ExecuteUnary;
 template struct ExecuteUnary;
 template struct ExecuteUnary;
 template struct ExecuteUnary;
@@ -337,6 +336,32 @@ Tensor Identity::invoke(
         DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor);
 }
 
+Tensor Floor::invoke(
+    uint8_t queue_id,
+    const Tensor& input_tensor,
+    const std::optional& memory_config,
+    const std::optional& optional_output_tensor) {
+    UnaryOpType op_type = UnaryOpType::FLOOR;
+    if (input_tensor.get_dtype() == DataType::FLOAT32) {
+        op_type = UnaryOpType::FLOOR_FLOAT32;
+    }
+
+    return detail::unary_impl(queue_id, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor);
+}
+
+Tensor Floor::invoke(
+    const Tensor& input_tensor,
+    const std::optional& memory_config,
+    const std::optional& optional_output_tensor) {
+    UnaryOpType op_type = UnaryOpType::FLOOR;
+    if (input_tensor.get_dtype() == DataType::FLOAT32) {
+        op_type = UnaryOpType::FLOOR_FLOAT32;
+    }
+
+    return detail::unary_impl(
+        DefaultQueueId, input_tensor, {UnaryWithParam{op_type}}, memory_config, optional_output_tensor);
+}
+
 Tensor Dropout::invoke(
     const Tensor& input,
     const uint32_t seed,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
index 79d5d22eb5a..7c034bfe66e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary.hpp
@@ -148,6 +148,20 @@ struct Identity {
         const std::optional& optional_output_tensor = std::nullopt);
 };
 
+// Floor op: invoke() selects UnaryOpType::FLOOR_FLOAT32 for DataType::FLOAT32 inputs and UnaryOpType::FLOOR otherwise.
+struct Floor {
+    static Tensor invoke(
+        uint8_t queue_id,
+        const Tensor& input_tensor,
+        const std::optional& memory_config = std::nullopt,
+        const std::optional& optional_output_tensor = std::nullopt);
+
+    static Tensor invoke(
+        const Tensor& input_tensor,
+        const std::optional& memory_config = std::nullopt,
+        const std::optional& optional_output_tensor = std::nullopt);
+};
+
 struct Dropout {
     static Tensor invoke(
         const Tensor& input,
@@ -281,7 +295,6 @@ REGISTER_UNARY_OPERATION(erfinv, ERFINV);
 REGISTER_UNARY_OPERATION(exp2, EXP2);
 REGISTER_UNARY_OPERATION(expm1, EXPM1);
 REGISTER_UNARY_OPERATION(eqz, EQZ);
-REGISTER_UNARY_OPERATION(floor, FLOOR);
 REGISTER_UNARY_OPERATION(ceil, CEIL);
 REGISTER_UNARY_OPERATION(gez, GEZ);
 REGISTER_UNARY_OPERATION(gtz, GTZ);
@@ -354,6 +367,8 @@ constexpr auto dropout =
     ttnn::register_operation_with_auto_launch_op<"ttnn::dropout", ttnn::operations::unary::Dropout>();
 constexpr auto identity =
     ttnn::register_operation_with_auto_launch_op<"ttnn::identity", ttnn::operations::unary::Identity>();
+constexpr auto floor =
+    ttnn::register_operation_with_auto_launch_op<"ttnn::floor", ttnn::operations::unary::Floor>();
 constexpr auto softplus =
     ttnn::register_operation_with_auto_launch_op<"ttnn::softplus", ttnn::operations::unary::Softplus>();
 constexpr auto prelu_sfpu =