diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp index 754ea6e24b12..8326d1cb057e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp @@ -12,6 +12,7 @@ #include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp" #include "tt_metal/common/constants.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" #include "fold.hpp" @@ -221,7 +222,7 @@ std::vector fold_with_transpose_sharded_( // reshape n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], c = tt_output_tensor.shape()[2], h = tt_output_tensor.shape()[3]; - tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h}); + tt_output_tensor = ttnn::experimental::reshape(tt_output_tensor, ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h}); tt::log_debug("reshape_hc_output: {}", tt_output_tensor.shape()); @@ -234,7 +235,7 @@ std::vector fold_with_transpose_sharded_( // reshape n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], h = tt_output_tensor.shape()[2], c = tt_output_tensor.shape()[3]; - tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)}); + tt_output_tensor = ttnn::experimental::reshape(tt_output_tensor, ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)}); tt::log_debug("reshape_hw_output: {}", tt_output_tensor.shape()); diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp index 50b069c43c21..eebd1d652a62 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp @@ -10,6 +10,7 @@ #include "ttnn/operations/experimental/auto_format/auto_format.hpp" #include "ttnn/tensor/tensor_utils.hpp" #include "device/reshape_op.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" namespace ttnn::operations::data_movement { @@ -57,7 +58,7 @@ ttnn::Tensor ReshapeOperation::invoke( padded_output_shape[3] == input_tensor.get_padded_shape()[3])) { // Don't need to do a check here to see the H and W both divisible by 32 // since handled within the tensor reshape method - return input_tensor.reshape(output_shape); + return ttnn::experimental::reshape(input_tensor, output_shape); } if (input_tensor.get_padded_shape() == padded_output_shape) { return ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config( diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp index 6af8646a0078..775f6605642f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.cpp @@ -21,6 +21,7 @@ #include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp" #include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp" #include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" namespace ttnn::operations::data_movement { @@ -53,7 +54,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) //This function is due to embedding issue 15558, once the issue is fixed we want to delete it tt::log_warning("host_reshape is deprecated and will be removed in the near future"); if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) { - return tensor.reshape(shape); + return ttnn::experimental::reshape(tensor, shape); } auto tensor_shape = tensor.shape(); auto layout = tensor.layout(); @@ -73,7 +74,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape) host_tensor_4d = ttnn::slice(host_tensor_4d, begins, ends, step, std::nullopt); host_tensor = squeeze_from_4D(host_tensor_4d, tensor_shape.rank()); } - auto host_reshape_tensor = rm_tensor.reshape(shape); + auto host_reshape_tensor = ttnn::experimental::reshape(rm_tensor, shape); auto final_layout_tensor = ttnn::to_layout(host_reshape_tensor, layout, std::nullopt, std::nullopt, (Device*)nullptr); auto device_tensor = ttnn::data_transfer_to_device(final_layout_tensor, device, memory_config); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp index 8a37c4b6f7dd..099361491138 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_rm_factory.cpp @@ -4,6 +4,7 @@ #include "moreh_getitem_device_operation.hpp" #include "ttnn/operations/moreh/moreh_helper_functions.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" namespace { namespace CMAKE_UNIQUE_NAMESPACE { @@ -59,7 +60,7 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera uint32_t index_end_dim = index_dims.back(); Tensor input_5d = input; - input_5d = input_5d.reshape(input_5d_shape); + input_5d = ttnn::experimental::reshape(input_5d, input_5d_shape); auto input_5d_shape_without_padding = input_5d_shape.value.without_padding(); diff --git a/ttnn/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp b/ttnn/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp index 557559206121..ecf5b448d1fa 100644 --- a/ttnn/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp +++ b/ttnn/cpp/ttnn/operations/pool/global_avg_pool/global_avg_pool.cpp @@ -4,6 +4,7 @@ #include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp" #include "ttnn/operations/reduction/generic/generic_reductions.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" namespace tt { namespace tt_metal { @@ -40,7 +41,7 @@ Tensor global_avg_pool2d( input_padding.pad_value()); auto output_shape = tt::tt_metal::LegacyShape({in_shape[0], 1, in_shape[1] * in_shape[2], in_shape[3]}, output_padding); - output = output.reshape(output_shape); + output = ttnn::experimental::reshape(output, output_shape); output = pool_2d(output, memory_config, output_dtype); return output; diff --git a/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp b/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp index 4ac733211eb0..16fe003ad4aa 100644 --- a/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/split_query_key_value_and_split_heads/split_query_key_value_and_split_heads.cpp @@ -9,6 +9,7 @@ #include "ttnn/cpp/ttnn/operations/experimental/transformer/nlp_create_qkv_heads/nlp_create_qkv_heads.hpp" #include "ttnn/cpp/ttnn/operations/experimental/transformer/nlp_create_qkv_heads_falcon7b/nlp_create_qkv_heads_falcon7b.hpp" #include "ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/create_qkv_heads.hpp" +#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp" namespace ttnn::operations::transformer { @@ -104,7 +105,7 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv head_size, padded_head_size); - const auto input_4d = input_tensor.reshape(ttnn::SimpleShape{ + const auto input_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], @@ -168,7 +169,7 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv "Invalid operation: KV tensor should not be provided when the input tensor is sharded. Please ensure that " "the KV tensor is only used in non-sharded configurations."); - const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{ + const auto input_tensor_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], @@ -184,7 +185,7 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv sequence_size_padded, transpose_key); } else { - const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{ + const auto input_tensor_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{ input_shape.with_tile_padding()[0], 1, input_shape.with_tile_padding()[1], @@ -192,7 +193,7 @@ std::tuple SplitQueryKeyValueAndSplitHeadsOperation::inv std::optional input_tensor_kv_4d = std::nullopt; if (input_tensor_kv.has_value()) { auto padded_input_shape_kv = input_tensor_kv.value().get_shape().with_tile_padding(); - input_tensor_kv_4d = input_tensor_kv.value().reshape( + input_tensor_kv_4d = ttnn::experimental::reshape(input_tensor_kv.value(), ttnn::SimpleShape{padded_input_shape_kv[0], 1, padded_input_shape_kv[1], padded_input_shape_kv[2]}); } const auto outputs = ttnn::experimental::nlp_create_qkv_heads( diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index d6c80e1e92d4..05481d26ff7b 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -626,11 +626,6 @@ const bool Tensor::is_sharded() const { uint32_t Tensor::element_size() const { return tensor_impl::element_size_bytes(this->get_dtype()); } -Tensor Tensor::reshape(const ttnn::SimpleShape& new_shape) const { - return tensor_ops::tensor_reshape(*this, new_shape); -} - -Tensor Tensor::reshape(const ttnn::Shape& new_shape) const { return tensor_ops::tensor_reshape(*this, new_shape); } bool Tensor::is_allocated() const { ZoneScoped; diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 7a2976ac8f22..7ebf6905b957 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -182,8 +182,6 @@ struct Tensor { // ====================================================================================== // Low Level APIs // ====================================================================================== - Tensor reshape(const ttnn::SimpleShape& new_shape) const; - Tensor reshape(const ttnn::Shape& new_shape) const; // ====================================================================================== // Getters diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 8a46d676dc79..e4f2e2c736be 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -349,89 +349,5 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap return output; } -Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) { - ZoneScoped; - GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape); - const auto& new_padded_shape = new_shape.padded_shape(); - const auto tile = input_tensor.get_tensor_spec().tile(); - TT_ASSERT( - input_tensor.volume() == new_padded_shape.volume(), - "{} != {}", - input_tensor.volume(), - new_padded_shape.volume()); - if (input_tensor.get_layout() == Layout::TILE) { - TT_ASSERT( - new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 && - new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 && - "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!"); - } - auto output = std::visit( - [&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor { - using T = std::decay_t; - const auto& tensor = input_tensor; - if constexpr (std::is_same_v) { - auto updated_storage = std::get(tensor.get_storage()); - for (int i = 0; i < updated_storage.shapes.size(); i++) { - updated_storage.shapes[i] = new_shape; - } - return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } - if constexpr (std::is_same_v) { - MultiDeviceStorage updated_storage = std::get(tensor.get_storage()); - std::unordered_map new_shapes; - - for (auto device_id : updated_storage.ordered_device_ids) { - new_shapes.insert({device_id, new_shape}); - } - updated_storage.shapes = new_shapes; - return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } - if constexpr (std::is_same_v) { - if (input_tensor.get_layout() == Layout::ROW_MAJOR) { - if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { - DeviceStorage device_storage = std::get(tensor.get_storage()); - DeviceBuffer device_buffer = device_storage.get_buffer(); - device_buffer->set_page_size(new_shape[-1] * tensor.element_size()); - device_storage.insert_buffer(device_buffer); - return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } else { - DeviceStorage device_storage = std::get(tensor.get_storage()); - DeviceBuffer device_buffer = device_storage.get_buffer(); - ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec(); - - auto shard_spec = shard_spec_buffer.tensor_shard_spec; - auto shard_shape = shard_spec.shape; - - uint32_t mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1]) - : (shard_shape[1] / new_shape[-1]); - shard_spec.shape[0] = - new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div; - shard_spec.shape[1] = new_shape[-1]; - - shard_spec_buffer.page_shape = {1, new_shape[-1]}; - shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1}; - shard_spec_buffer.set_shard_spec(shard_spec); - - device_buffer->set_shard_spec(shard_spec_buffer); - device_storage.insert_buffer(device_buffer); - - return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } - } else { - return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } - } else { - return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); - } - }, - input_tensor.get_storage()); - output = tt::tt_metal::set_tensor_id(output); - GraphTracker::instance().track_function_end(output); - return output; -} - -Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape) { - return tensor_reshape(input_tensor, ttnn::Shape(new_shape.view())); -} } // namespace tt::tt_metal::tensor_ops diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 98f8103c151c..348bf4cb5f27 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -49,7 +49,4 @@ Tensor tensor_pad_to_tile(const Tensor& input_tensor, float pad_value); Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShape& output_tensor_shape); -Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape); -Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape); - } // namespace tt::tt_metal::tensor_ops