diff --git a/tests/ttnn/unit_tests/test_model_preprocessing.py b/tests/ttnn/unit_tests/test_model_preprocessing.py index 416d891133a..602e0ff956f 100644 --- a/tests/ttnn/unit_tests/test_model_preprocessing.py +++ b/tests/ttnn/unit_tests/test_model_preprocessing.py @@ -326,7 +326,10 @@ def forward(self, x): output_tensor = ttnn.to_device(output_tensor, device=device) output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT) output_tensor = ttnn.relu(output_tensor) + print(f"Output shape is {output_tensor.shape}") output_tensor = ttnn.permute(output_tensor, (0, 3, 1, 2)) + print(f"Expected shape after permute is ttnn.Shape([1, 128, 28, 28[32]]") + print(f"Actual shape is {output_tensor.shape}") output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) output_tensor = ttnn.get_fallback_function(ttnn.reshape)(output_tensor, (-1, num_output_channels)) output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 4345ded2fc0..66a56755dad 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -813,6 +813,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); } +uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); } + Tensor create_device_tensor( const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) { ZoneScoped; diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index 226f7913e13..af49d700bbb 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -311,6 +311,7 @@ struct Tensor { StorageType storage_type() const; const Shape strides() const; uint32_t volume() const; + uint32_t intended_volume() const; bool is_allocated() const; diff --git a/tt_eager/tt_dnn/op_library/auto_format.cpp b/tt_eager/tt_dnn/op_library/auto_format.cpp index 56ad92cc639..744124b9763 100644 --- a/tt_eager/tt_dnn/op_library/auto_format.cpp +++ b/tt_eager/tt_dnn/op_library/auto_format.cpp @@ -40,6 +40,24 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo } } +// This code is a workaround for cases where we need to remove autoformat but other dependent ops +// are not quite ready. So here we basically just put the tensor back on device. +// Used in backward_ops.cpp +// See: Remove auto format within permute_op.cpp #9404 +Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional target_mem_config){ + const auto intended_shape = input.get_shape(); + const auto device_shape = input.get_legacy_shape(); + const auto new_intended_shape = std::array{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]}; + const auto new_device_shape = std::array{ + device_shape[0], + device_shape[1], + (device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]), + (device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1]) + }; + const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape); + return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config); +} + Tensor AutoFormat::format_input_tensor( const Tensor& input, Device* device, diff --git a/tt_eager/tt_dnn/op_library/auto_format.hpp b/tt_eager/tt_dnn/op_library/auto_format.hpp index c2de0e0542f..0e6f9056ae3 100644 --- a/tt_eager/tt_dnn/op_library/auto_format.hpp +++ b/tt_eager/tt_dnn/op_library/auto_format.hpp @@ -34,10 +34,10 @@ class AutoFormat { static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) { - auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0]; - auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1]; - auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2]; - auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3]; + auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1; + auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1; + auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2]; + auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1]; Shape padded_shape = {n, c, h, w}; return padded_shape; } @@ -83,6 +83,12 @@ class AutoFormat { return false; } + // This code is a workaround for cases where we need to remove autoformat but other dependent ops + // are not quite ready. So here we basically just put the tensor back on device. + // Used in backward_ops.cpp + // See: Remove auto format within permute_op.cpp #9404 + static Tensor move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional target_mem_config); + static Tensor move_tensor_to_device(const Tensor &input, Device * device, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); static Tensor move_tensor_to_mem_config(const Tensor &input, const MemoryConfig& mem_config); diff --git a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp index 4a4208c1a80..7720c9918cc 100644 --- a/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp +++ b/tt_eager/tt_dnn/op_library/backward/backward_ops.cpp @@ -2170,6 +2170,11 @@ std::vector _prod_bw( std::vector after_permute_dims = {0, 2, 3, 1}; Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config); Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config); + + // put the tensor back on device because permute throws it off device + // See: Remove auto format within permute_op.cpp #9404 + tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(),tensor_1.get_layout(), tensor_1.memory_config()); + after_permute_dims = {0, 3, 1, 2}; Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), @@ -2202,6 +2207,11 @@ std::vector _prod_bw( std::vector after_permute_dims = {3, 1, 2, 0}; Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config); Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config); + + // put the tensor back on device because permute throws it off device + // See: Remove auto format within permute_op.cpp #9404 + tensor_2 = AutoFormat::move_tensor_to_device_and_pad(tensor_2, tensor_1.device(),tensor_1.get_layout(), tensor_1.memory_config()); + Tensor result = permute( bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config), after_permute_dims, diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp index d0e4526412d..da72d4c5890 100644 --- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp +++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp @@ -73,7 +73,7 @@ inline Tensor bcast( const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; - operation::launch_with_autoformat( + operation::launch_op( [bcast_op, bcast_dim, output_mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { using tt::constants::TILE_HEIGHT; using tt::constants::TILE_WIDTH; @@ -109,7 +109,7 @@ inline Tensor bcast( input_tensor_b.get_legacy_shape()[-1] == TILE_WIDTH)); } } - return operation::run_with_autoformat( + return operation::run( EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b}); }, {input_tensor_a, input_tensor_b}, output_tensors); return output_tensors.at(0); diff --git a/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp b/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp index 06885ce922b..fd9fe860a62 100644 --- a/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp +++ b/tt_eager/tt_dnn/op_library/bcast/multi_core_h/bcast_op_sharded_h.cpp @@ -22,8 +22,8 @@ namespace tt_metal { operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){ const auto ashape = a.get_legacy_shape(); const auto bshape = b.get_legacy_shape(); - uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3]; - uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3]; + uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1]; + uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1]; uint32_t NC = N*C; diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index cf9e418618f..a01a83129f7 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -19,24 +19,27 @@ #include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp" #include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp" #include "tt_numpy/functions.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" namespace tt { namespace tt_metal { Tensor mk_zero_tensor_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_config) { - // Tensor zero_like = bcast(reference_tensor, , BcastOpMath::MUL, BcastOpDim::HW); - Tensor zero = mk_tiled_scalar(0.0f, reference_tensor.get_dtype()); + Tensor zero = ttnn::operations::creation::create_scalar(0.0f, reference_tensor.get_dtype(),Layout::TILE, reference_tensor.device()); Tensor zero_like = bcast(reference_tensor, zero, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + zero.deallocate(); return zero_like; } // TODO: enable zeroes(), ones() and eye() type functions on-device using this type of logic template Tensor mk_filled_tensor_like(const Tensor& reference_tensor, T val, const MemoryConfig& output_mem_config) { - Tensor k = mk_tiled_scalar(val, reference_tensor.get_dtype()); + Tensor k = ttnn::operations::creation::create_scalar(val, reference_tensor.get_dtype(), Layout::TILE, reference_tensor.device()); Tensor zero_like = mk_zero_tensor_like(reference_tensor, output_mem_config); Tensor result = bcast(zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + k.deallocate(); + zero_like.deallocate(); return result; } @@ -255,11 +258,11 @@ Tensor mish(const Tensor& a, const MemoryConfig& output_mem_config) { Tensor _selu(const Tensor& x, const float scale, const float alpha, const MemoryConfig& output_mem_config) { // term 2 Tensor x_Exp = exp(x, output_mem_config); - Tensor minus_one = mk_tiled_scalar(-1.0f); + Tensor minus_one = ttnn::operations::creation::create_scalar(-1.0f,x.get_dtype(),Layout::TILE, x.device()); Tensor x_Exp_minus_1 = bcast(x_Exp, minus_one, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); x_Exp.deallocate(); minus_one.deallocate(); - Tensor t_alpha = mk_tiled_scalar(alpha); + Tensor t_alpha = ttnn::operations::creation::create_scalar(alpha,x.get_dtype(),Layout::TILE, x.device()); Tensor result_t2_ = bcast(x_Exp_minus_1, t_alpha, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); x_Exp_minus_1.deallocate(); t_alpha.deallocate(); @@ -267,7 +270,7 @@ Tensor _selu(const Tensor& x, const float scale, const float alpha, const Memory result_t2_.deallocate(); // term 1 - Tensor t_scale = mk_tiled_scalar(scale); + Tensor t_scale = ttnn::operations::creation::create_scalar(scale,x.get_dtype(),Layout::TILE, x.device()); Tensor x_relu = relu(x, output_mem_config); Tensor result_term1 = bcast(x_relu, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); t_scale.deallocate(); @@ -288,7 +291,10 @@ Tensor selu(const Tensor& x, const float scale, const float alpha, const MemoryC Tensor rpow(const Tensor& a, float k, const MemoryConfig& output_mem_config) { TT_ASSERT(k > 0.0, "rpow cannot be calcualted for non-positive numbers"); float log_k = logf(k); - Tensor result = bcast(a, mk_tiled_scalar(log_k), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + + Tensor scalar = ttnn::operations::creation::create_scalar(log_k,a.get_dtype(),Layout::TILE, a.device()); + Tensor result = bcast(a, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); return exp(result, output_mem_config); } @@ -349,13 +355,20 @@ Tensor _polyval(const Tensor& input_tensor, std::vector coeffs, const Mem return mk_filled_tensor_like(input_tensor, coeffs[0], output_mem_config); } + Tensor scalar = ttnn::operations::creation::create_scalar(coeffs[0], input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); Tensor result = - bcast(input_tensor, mk_tiled_scalar(coeffs[0]), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + bcast(input_tensor, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); for (int idx = 1; idx < coeffs.size() - 1; idx++) { - result = bcast(result, mk_tiled_scalar(coeffs[idx]), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + Tensor scalar = ttnn::operations::creation::create_scalar(coeffs[idx], input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); + result = bcast(result, scalar, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); result = mul(input_tensor, result, std::nullopt, output_mem_config); } - return bcast(result, mk_tiled_scalar(coeffs.back()), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + Tensor last_coeffs = ttnn::operations::creation::create_scalar(coeffs.back(), input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); + Tensor final_tensor = bcast(result, last_coeffs, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config); + last_coeffs.deallocate(); + return final_tensor; } Tensor polyval(const Tensor& input_tensor, std::vector coeffs, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _polyval)(input_tensor, coeffs, output_mem_config); @@ -365,9 +378,9 @@ Tensor polyval(const Tensor& input_tensor, std::vector coeffs, const Memo // compute multiply-accumulate: y = a * b + c, over various 8 combinations of a, b, c // being a scalar or tensor Tensor _mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig& output_mem_config) { - bool a_is_scalar = a.volume() == 1; - bool b_is_scalar = b.volume() == 1; - bool c_is_scalar = c.volume() == 1; + bool a_is_scalar = a.intended_volume() == 1; + bool b_is_scalar = b.intended_volume() == 1; + bool c_is_scalar = c.intended_volume() == 1; const auto dim = BcastOpDim::HW; if (!a_is_scalar && !b_is_scalar && !c_is_scalar) { @@ -405,9 +418,12 @@ Tensor mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig } Tensor _mac_overload(const Tensor& a, float b, float c, const MemoryConfig& output_mem_config) { - Tensor t_b = mk_scalar(b); - Tensor t_c = mk_scalar(c); - return mac(a, t_b, t_c, output_mem_config); + Tensor t_b = ttnn::operations::creation::create_scalar(b, a.get_dtype(), Layout::TILE, a.device()); + Tensor t_c = ttnn::operations::creation::create_scalar(c, a.get_dtype(), Layout::TILE, a.device()); + Tensor return_tensor = mac(a, t_b, t_c, output_mem_config); + t_b.deallocate(); + t_c.deallocate(); + return return_tensor; } Tensor mac(const Tensor& input_a, float b, float c, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _mac_overload)(input_a, b, c, output_mem_config); @@ -451,7 +467,9 @@ Tensor _sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) { Tensor nr_term = sub(e_pos_x, e_neg_x, std::nullopt, output_mem_config); e_pos_x.deallocate(); e_neg_x.deallocate(); - return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device()); + return bcast(nr_term, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); } Tensor sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _sinh)(input_a, output_mem_config); @@ -464,7 +482,9 @@ Tensor _cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) { Tensor nr_term = add(e_pos_x, e_neg_x, std::nullopt, output_mem_config); e_pos_x.deallocate(); e_neg_x.deallocate(); - return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device()); + return bcast(nr_term, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); } Tensor cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) { return operation::decorate_as_composite(__func__, _cosh)(input_a, output_mem_config); @@ -509,12 +529,14 @@ Tensor _acosh(const Tensor& input_a, const MemoryConfig& output_mem_config) { // To handle inputs <= 1 // input < 1, output is nan // input > 1, output is acosh(input) + Tensor scalar = ttnn::operations::creation::create_scalar(std::nanf(""), input_a.get_dtype(), Layout::TILE, input_a.device()); Tensor nan_res = bcast( lte(input_a, t_one, std::nullopt, output_mem_config), - mk_tiled_scalar(std::nanf("")), + scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); + scalar.deallocate(); t_result = mul(gt(input_a, t_one, std::nullopt, output_mem_config), ln_res, std::nullopt, output_mem_config); t_result = add(nan_res, t_result, std::nullopt, output_mem_config); } @@ -553,7 +575,7 @@ Tensor atanh(const Tensor& input_a, const MemoryConfig& output_mem_config) { // lerp(input, end, weight) = start + weight * (end - start) Tensor _lerp(const Tensor& input_a, const Tensor& input_b, float value, const MemoryConfig& output_mem_config) { - Tensor t_value = mk_tiled_scalar(value); + Tensor t_value = ttnn::operations::creation::create_scalar(value,input_a.get_dtype(), Layout::TILE, input_a.device()); Tensor t_diff = sub(input_b, input_a, std::nullopt, output_mem_config); Tensor t_mul = bcast(t_diff, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); Tensor result = add(input_a, t_mul, std::nullopt, output_mem_config); @@ -826,7 +848,7 @@ Tensor _addcmul( const Tensor& input_c, float value, const MemoryConfig& output_mem_config) { - Tensor t_value = mk_tiled_scalar(value); + Tensor t_value = ttnn::operations::creation::create_scalar(value,input_a.get_dtype(), Layout::TILE, input_a.device()); Tensor t_mul = mul(input_b, input_c, std::nullopt, output_mem_config); Tensor t_factor = bcast(t_mul, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); t_mul.deallocate(); @@ -850,7 +872,7 @@ Tensor _addcdiv( const Tensor& input_c, float value, const MemoryConfig& output_mem_config) { - Tensor t_value = mk_tiled_scalar(value); + Tensor t_value = ttnn::operations::creation::create_scalar(value,input_a.get_dtype(), Layout::TILE, input_a.device()); Tensor t_div = mul(input_b, recip(input_c, output_mem_config), std::nullopt, output_mem_config); Tensor t_factor = bcast(t_div, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); t_div.deallocate(); @@ -1294,14 +1316,14 @@ Tensor scatter(const Tensor& input_a, const Tensor& input_b, const MemoryConfig& } // threshold(a,t,v) = (a <= t)*v + (a > t)*a -Tensor _threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) { - Tensor t_threshold = mk_tiled_scalar(threshold, input_a.get_dtype()); - Tensor t0 = bcast(input_a, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config); +Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const MemoryConfig& output_mem_config) { + Tensor t_threshold = ttnn::operations::creation::create_scalar(threshold,input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); + Tensor t0 = bcast(input_tensor, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config); t_threshold.deallocate(); - Tensor t_value = mk_tiled_scalar(value, input_a.get_dtype()); + Tensor t_value = ttnn::operations::creation::create_scalar(value,input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); Tensor t1 = bcast(lez(t0), t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); t_value.deallocate(); - Tensor t2 = mul(gtz(t0, output_mem_config), input_a, std::nullopt, output_mem_config); + Tensor t2 = mul(gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config); return add(t1, t2, std::nullopt, output_mem_config); } Tensor threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) { @@ -1352,16 +1374,16 @@ Tensor digamma(const Tensor& input_a, const MemoryConfig& output_mem_config) { // cbrt(a) = pow(a,1/3) or (cbrt(a))**3 = a. // = exp[ (1/3)*log[a] ] -Tensor _cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) { +Tensor _cbrt(const Tensor& input_tensor, const MemoryConfig& output_mem_config) { constexpr float scale = (float)(1.0 / 3.0); - Tensor t_scale = mk_tiled_scalar(scale); - Tensor t_ln_input = log(abs(input_a, output_mem_config), output_mem_config); // negative log is not useful here + Tensor t_scale = ttnn::operations::creation::create_scalar(scale,input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); + Tensor t_ln_input = log(abs(input_tensor, output_mem_config), output_mem_config); // negative log is not useful here Tensor t1 = bcast(t_ln_input, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config); t_scale.deallocate(); t_ln_input.deallocate(); Tensor t2 = exp(t1, output_mem_config); t1.deallocate(); - Tensor t3 = mul(t2, sign(input_a, output_mem_config), std::nullopt, output_mem_config); + Tensor t3 = mul(t2, sign(input_tensor, output_mem_config), std::nullopt, output_mem_config); return t3; } Tensor cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) { diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index e6bad24e0f3..271da5b69c4 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -22,44 +22,6 @@ using binary_tensor_op_t = Tensor(const Tensor& a, const Tensor& b); // Note: inline doesn't allow pybind to work well so we keep few function not inlined. -template -Tensor mk_scalar(T value) { - assert(std::is_scalar::value && "T should be scalar"); - std::array shape = {1, 1, 1, 1}; - auto buffer = owned_buffer::create(std::vector{bfloat16(value)}); - Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::ROW_MAJOR); - return scalar; -} - -template -Tensor mk_tiled_scalar(T value) { - assert(std::is_scalar::value && "T should be scalar"); - std::array shape = {1, 1, TILE_HEIGHT, TILE_WIDTH}; - std::vector buffer_vec(TILE_HW, bfloat16(0)); - buffer_vec[0] = bfloat16(value); - auto buffer = owned_buffer::create(std::move(buffer_vec)); - Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE); - return scalar; -} - -template -Tensor mk_tiled_scalar(T value, DataType dtype) { - assert(std::is_scalar::value && "T should be scalar"); - std::array shape = {1, 1, TILE_HEIGHT, TILE_WIDTH}; - if(dtype == DataType::BFLOAT8_B) - { - std::vector buffer_vec(TILE_HW, float(0)); - buffer_vec[0] = float(value); - auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(buffer_vec, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), shape, DataType::BFLOAT8_B, Layout::TILE); - } - std::vector buffer_vec(TILE_HW, bfloat16(0)); - buffer_vec[0] = bfloat16(value); - auto buffer = owned_buffer::create(std::move(buffer_vec)); - Tensor scalar = Tensor(OwnedStorage{buffer}, shape, DataType::BFLOAT16, Layout::TILE); - return scalar; -} // Function: softshrink // Ref: https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html Tensor softshrink( diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index f72b98b6977..8af966bb43e 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -11,6 +11,7 @@ #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" using namespace tt::constants; @@ -433,8 +434,10 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector Tensor tie_binop_to_unary(const Tensor& input_tensor, float value, const MemoryConfig& output_mem_config) { - Tensor t_value = mk_tiled_scalar(value, input_tensor.get_dtype()); - return bcast(input_tensor, t_value, OP, BcastOpDim::HW); + Tensor t_value = ttnn::operations::creation::create_scalar(value,input_tensor.get_dtype(), Layout::TILE, input_tensor.device()); + Tensor tensor = bcast(input_tensor, t_value, OP, BcastOpDim::HW); + t_value.deallocate(); + return tensor; } Tensor lte_unary(const Tensor& input_tensor, float value, const MemoryConfig& output_mem_config) { diff --git a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp index 33197420599..f87918b615e 100644 --- a/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp @@ -8,6 +8,7 @@ #include #include "tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp" +#include "ttnn/cpp/ttnn/operations/creation.hpp" namespace tt { @@ -225,7 +226,9 @@ Tensor moreh_clip_grad_norm_impl( // max_norm / (total_norm + 1e-6) const auto &clip_coef = div_unary(max_norm, add_unary(total_norm, 1e-6f)); // min(clip_coef, 1.0f) - const auto &clip_coef_clamped = min(clip_coef, mk_tiled_scalar(1.0f)); + Tensor scalar = ttnn::operations::creation::create_scalar(1.0f,inputs.at(0).get_dtype(),Layout::TILE, inputs.at(0).device()); + const auto &clip_coef_clamped = min(clip_coef, scalar); + scalar.deallocate(); // Inplace update inputs(inputs *= clip_coef_clamped) moreh_clip_grad_norm_step3(inputs, clip_coef_clamped); diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index 41bb024b840..f082bb2ae99 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -17,6 +17,42 @@ namespace ttnn { namespace operations { namespace creation { +template +Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device){ + static_assert(rank >=2, "Rank must be at least 2 when creating a tensor with TILE_LAYOUT"); + std::array intended_shape = {}; + intended_shape.fill(1); + std::array device_shape = {}; + device_shape.fill(1); + + if(layout == Layout::ROW_MAJOR){ + device_shape[device_shape.size() - 2] = 2; + auto host_buffer = owned_buffer::create<::bfloat16>(static_cast(2)); + host_buffer[0] = scalar; + Tensor scalar_tensor_host = Tensor( + OwnedStorage{host_buffer}, + ttnn::Shape(intended_shape, device_shape), + data_type, + Layout::ROW_MAJOR); + return scalar_tensor_host.to(device); + } + else if(layout == Layout::TILE){ + device_shape[device_shape.size() - 2] = TILE_HEIGHT; + device_shape[device_shape.size() - 1] = TILE_WIDTH; + auto host_buffer = owned_buffer::create<::bfloat16>(static_cast(TILE_HEIGHT * TILE_WIDTH)); + host_buffer[0] = scalar; + Tensor scalar_tensor_host = Tensor( + OwnedStorage{host_buffer}, + ttnn::Shape(intended_shape, device_shape), + data_type, + Layout::TILE); + return scalar_tensor_host.to(device); + } + else{ + throw std::runtime_error("Unsupported layout"); + } +} + template inline ttnn::Tensor full( const ttnn::Shape& shape, diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index bae3685e09e..3d4b3ff1d70 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -222,6 +222,18 @@ Tensor triu( return tt::tt_metal::triu(input_tensor, diag, memory_config.value_or(input_tensor.memory_config())); } +// TODO +// Tensor hardshrink( +// const Tensor& input_tensor, +// float32 lambda, +// const std::optional& memory_config = std::nullopt) { +// return tt::tt_metal::hardshrink(input_tensor, lambda, memory_config.value_or(input_tensor.memory_config())); +// } +// ("hardshrink", ttl.tensor.hardshrink, "lambda"), # composite +// ("celu", ttl.tensor.celu, "alpha"), # composite +// ("softshrink", ttl.tensor.softshrink, "lambda"), # composite + + } // namespace unary } // namespace operations