#13745: remove tensor.reshape
nardoTT committed Dec 5, 2024
1 parent 2091153 commit 7963384
Showing 10 changed files with 17 additions and 105 deletions.
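The change is mechanical at every call site: the removed Tensor::reshape member function is replaced by the free function ttnn::experimental::reshape, which takes the tensor as its first argument. A minimal sketch of the migration pattern, using an illustrative tensor and shape rather than any one call site below:

    #include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

    // Before (member function, removed by this commit):
    //     tensor = tensor.reshape(ttnn::SimpleShape{n, 1, h, w});
    // After (free function; the tensor becomes the first argument):
    tensor = ttnn::experimental::reshape(tensor, ttnn::SimpleShape{n, 1, h, w});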
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/operations/data_movement/fold/fold.cpp
@@ -12,6 +12,7 @@
#include "ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.hpp"
#include "ttnn/cpp/ttnn/operations/data_movement/pad/pad.hpp"
#include "tt_metal/common/constants.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

#include "fold.hpp"

@@ -221,7 +222,7 @@ std::vector<Tensor> fold_with_transpose_sharded_(
// reshape
n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], c = tt_output_tensor.shape()[2],
h = tt_output_tensor.shape()[3];
-tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h});
+tt_output_tensor = ttnn::experimental::reshape(tt_output_tensor, ttnn::SimpleShape{n, (w / stride_w), (c * stride_w), h});

tt::log_debug("reshape_hc_output: {}", tt_output_tensor.shape());

@@ -234,7 +235,7 @@ std::vector<Tensor> fold_with_transpose_sharded_(
// reshape
n = tt_output_tensor.shape()[0], w = tt_output_tensor.shape()[1], h = tt_output_tensor.shape()[2],
c = tt_output_tensor.shape()[3];
-tt_output_tensor = tt_output_tensor.reshape(ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)});
+tt_output_tensor = ttnn::experimental::reshape(tt_output_tensor, ttnn::SimpleShape{n, w, (h / stride_h), (c * stride_h)});

tt::log_debug("reshape_hw_output: {}", tt_output_tensor.shape());

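For intuition, the two reshapes in this file regroup stride-sized blocks between the spatial and channel dimensions (the transpose between them is collapsed in this view). A worked example of the shape arithmetic with hypothetical sizes:

    #include <cstdint>

    // Hypothetical fold sizes, purely illustrative.
    constexpr uint32_t n = 1, w = 8, c = 16, h = 32, stride_w = 2, stride_h = 2;
    // reshape_hc: {n, w, c, h} -> {n, w / stride_w, c * stride_w, h} = {1, 4, 32, 32}
    static_assert(n * w * c * h == n * (w / stride_w) * (c * stride_w) * h, "volume preserved");
    // reshape_hw: {n, w, h, c} -> {n, w, h / stride_h, c * stride_h}
    static_assert(n * w * h * c == n * w * (h / stride_h) * (c * stride_h), "volume preserved");

Both reshapes preserve the element count, which is exactly what the TT_ASSERT in the removed tensor_reshape (shown further down) enforces.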
@@ -10,6 +10,7 @@
#include "ttnn/operations/experimental/auto_format/auto_format.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "device/reshape_op.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

namespace ttnn::operations::data_movement {

@@ -57,7 +58,7 @@ ttnn::Tensor ReshapeOperation::invoke(
padded_output_shape[3] == input_tensor.get_padded_shape()[3])) {
// Don't need to do a check here to see the H and W both divisible by 32
// since handled within the tensor reshape method
-return input_tensor.reshape(output_shape);
+return ttnn::experimental::reshape(input_tensor, output_shape);
}
if (input_tensor.get_padded_shape() == padded_output_shape) {
return ttnn::operations::experimental::auto_format::AutoFormat::move_tensor_to_mem_config(
@@ -21,6 +21,7 @@
#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp"
#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp"
#include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

namespace ttnn::operations::data_movement {

@@ -53,7 +54,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape)
//This function is due to embedding issue 15558, once the issue is fixed we want to delete it
tt::log_warning("host_reshape is deprecated and will be removed in the near future");
if (!ttnn::has_storage_type_of(tensor, ttnn::StorageType::DEVICE)) {
-return tensor.reshape(shape);
+return ttnn::experimental::reshape(tensor, shape);
}
auto tensor_shape = tensor.shape();
auto layout = tensor.layout();
@@ -73,7 +74,7 @@ ttnn::Tensor host_reshape(const ttnn::Tensor& tensor, const ttnn::Shape& shape)
host_tensor_4d = ttnn::slice(host_tensor_4d, begins, ends, step, std::nullopt);
host_tensor = squeeze_from_4D(host_tensor_4d, tensor_shape.rank());
}
-auto host_reshape_tensor = rm_tensor.reshape(shape);
+auto host_reshape_tensor = ttnn::experimental::reshape(rm_tensor, shape);
auto final_layout_tensor =
ttnn::to_layout(host_reshape_tensor, layout, std::nullopt, std::nullopt, (Device*)nullptr);
auto device_tensor = ttnn::data_transfer_to_device(final_layout_tensor, device, memory_config);
@@ -4,6 +4,7 @@

#include "moreh_getitem_device_operation.hpp"
#include "ttnn/operations/moreh/moreh_helper_functions.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

namespace {
namespace CMAKE_UNIQUE_NAMESPACE {
@@ -59,7 +60,7 @@ MorehGetItemOperation::MorehGetItemRmFactory::cached_program_t MorehGetItemOpera
uint32_t index_end_dim = index_dims.back();

Tensor input_5d = input;
-input_5d = input_5d.reshape(input_5d_shape);
+input_5d = ttnn::experimental::reshape(input_5d, input_5d_shape);

auto input_5d_shape_without_padding = input_5d_shape.value.without_padding();

@@ -4,6 +4,7 @@

#include "ttnn/operations/pool/global_avg_pool/global_avg_pool.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

namespace tt {
namespace tt_metal {
@@ -40,7 +41,7 @@ Tensor global_avg_pool2d(
input_padding.pad_value());
auto output_shape =
tt::tt_metal::LegacyShape({in_shape[0], 1, in_shape[1] * in_shape[2], in_shape[3]}, output_padding);
-output = output.reshape(output_shape);
+output = ttnn::experimental::reshape(output, output_shape);

output = pool_2d<PoolType::AVG>(output, memory_config, output_dtype);
return output;
@@ -9,6 +9,7 @@
#include "ttnn/cpp/ttnn/operations/experimental/transformer/nlp_create_qkv_heads/nlp_create_qkv_heads.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/transformer/nlp_create_qkv_heads_falcon7b/nlp_create_qkv_heads_falcon7b.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/transformer/create_qkv_heads/create_qkv_heads.hpp"
#include "ttnn/cpp/ttnn/operations/experimental/reshape/reshape.hpp"

namespace ttnn::operations::transformer {

@@ -104,7 +105,7 @@ std::tuple<Tensor, Tensor, Tensor> SplitQueryKeyValueAndSplitHeadsOperation::inv
head_size,
padded_head_size);

-const auto input_4d = input_tensor.reshape(ttnn::SimpleShape{
+const auto input_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{
input_shape.with_tile_padding()[0],
1,
input_shape.with_tile_padding()[1],
@@ -168,7 +169,7 @@ std::tuple<Tensor, Tensor, Tensor> SplitQueryKeyValueAndSplitHeadsOperation::inv
"Invalid operation: KV tensor should not be provided when the input tensor is sharded. Please ensure that "
"the KV tensor is only used in non-sharded configurations.");

-const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{
+const auto input_tensor_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{
input_shape.with_tile_padding()[0],
1,
input_shape.with_tile_padding()[1],
@@ -184,15 +185,15 @@ std::tuple<Tensor, Tensor, Tensor> SplitQueryKeyValueAndSplitHeadsOperation::inv
sequence_size_padded,
transpose_key);
} else {
-const auto input_tensor_4d = input_tensor.reshape(ttnn::SimpleShape{
+const auto input_tensor_4d = ttnn::experimental::reshape(input_tensor, ttnn::SimpleShape{
input_shape.with_tile_padding()[0],
1,
input_shape.with_tile_padding()[1],
input_shape.with_tile_padding()[2]});
std::optional<Tensor> input_tensor_kv_4d = std::nullopt;
if (input_tensor_kv.has_value()) {
auto padded_input_shape_kv = input_tensor_kv.value().get_shape().with_tile_padding();
-input_tensor_kv_4d = input_tensor_kv.value().reshape(
+input_tensor_kv_4d = ttnn::experimental::reshape(input_tensor_kv.value(),
ttnn::SimpleShape{padded_input_shape_kv[0], 1, padded_input_shape_kv[1], padded_input_shape_kv[2]});
}
const auto outputs = ttnn::experimental::nlp_create_qkv_heads(
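All three branches of this operation lift the 3-D attention input to 4-D by inserting a unit dimension, sized from the tile-padded shape so the dimension product is unchanged. A minimal sketch of that pattern, assuming an input_tensor with a padded 3-D shape:

    // Sketch of the 3-D -> 4-D lift used above; input_tensor is illustrative.
    const auto padded = input_tensor.get_shape().with_tile_padding();
    const auto input_4d = ttnn::experimental::reshape(
        input_tensor, ttnn::SimpleShape{padded[0], 1, padded[1], padded[2]});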
5 changes: 0 additions & 5 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -626,11 +626,6 @@ const bool Tensor::is_sharded() const {

uint32_t Tensor::element_size() const { return tensor_impl::element_size_bytes(this->get_dtype()); }

-Tensor Tensor::reshape(const ttnn::SimpleShape& new_shape) const {
-    return tensor_ops::tensor_reshape(*this, new_shape);
-}
-
-Tensor Tensor::reshape(const ttnn::Shape& new_shape) const { return tensor_ops::tensor_reshape(*this, new_shape); }

bool Tensor::is_allocated() const {
ZoneScoped;
2 changes: 0 additions & 2 deletions ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -182,8 +182,6 @@ struct Tensor {
// ======================================================================================
// Low Level APIs
// ======================================================================================
-    Tensor reshape(const ttnn::SimpleShape& new_shape) const;
-    Tensor reshape(const ttnn::Shape& new_shape) const;

// ======================================================================================
// Getters
84 changes: 0 additions & 84 deletions ttnn/cpp/ttnn/tensor/tensor_ops.cpp
@@ -349,89 +349,5 @@ Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShap
return output;
}

-Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape) {
-    ZoneScoped;
-    GraphTracker::instance().track_function_start("Tensor::reshape", input_tensor, new_shape);
-    const auto& new_padded_shape = new_shape.padded_shape();
-    const auto tile = input_tensor.get_tensor_spec().tile();
-    TT_ASSERT(
-        input_tensor.volume() == new_padded_shape.volume(),
-        "{} != {}",
-        input_tensor.volume(),
-        new_padded_shape.volume());
-    if (input_tensor.get_layout() == Layout::TILE) {
-        TT_ASSERT(
-            new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 &&
-                new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 &&
-                "Expected a multiple of 32 for H, W (or -1 evaluating to such) in Tensor::reshape()!");
-    }
-    auto output = std::visit(
-        [&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor {
-            using T = std::decay_t<decltype(storage)>;
-            const auto& tensor = input_tensor;
-            if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
-                auto updated_storage = std::get<T>(tensor.get_storage());
-                for (int i = 0; i < updated_storage.shapes.size(); i++) {
-                    updated_storage.shapes[i] = new_shape;
-                }
-                return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-            }
-            if constexpr (std::is_same_v<T, MultiDeviceStorage>) {
-                MultiDeviceStorage updated_storage = std::get<T>(tensor.get_storage());
-                std::unordered_map<int, ttnn::Shape> new_shapes;
-
-                for (auto device_id : updated_storage.ordered_device_ids) {
-                    new_shapes.insert({device_id, new_shape});
-                }
-                updated_storage.shapes = new_shapes;
-                return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-            }
-            if constexpr (std::is_same_v<T, DeviceStorage>) {
-                if (input_tensor.get_layout() == Layout::ROW_MAJOR) {
-                    if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
-                        DeviceStorage device_storage = std::get<T>(tensor.get_storage());
-                        DeviceBuffer device_buffer = device_storage.get_buffer();
-                        device_buffer->set_page_size(new_shape[-1] * tensor.element_size());
-                        device_storage.insert_buffer(device_buffer);
-                        return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-                    } else {
-                        DeviceStorage device_storage = std::get<T>(tensor.get_storage());
-                        DeviceBuffer device_buffer = device_storage.get_buffer();
-                        ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec();
-
-                        auto shard_spec = shard_spec_buffer.tensor_shard_spec;
-                        auto shard_shape = shard_spec.shape;
-
-                        uint32_t mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1])
-                                                                          : (shard_shape[1] / new_shape[-1]);
-                        shard_spec.shape[0] =
-                            new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div;
-                        shard_spec.shape[1] = new_shape[-1];
-
-                        shard_spec_buffer.page_shape = {1, new_shape[-1]};
-                        shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1};
-                        shard_spec_buffer.set_shard_spec(shard_spec);
-
-                        device_buffer->set_shard_spec(shard_spec_buffer);
-                        device_storage.insert_buffer(device_buffer);
-
-                        return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-                    }
-                } else {
-                    return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-                }
-            } else {
-                return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile);
-            }
-        },
-        input_tensor.get_storage());
-    output = tt::tt_metal::set_tensor_id(output);
-    GraphTracker::instance().track_function_end(output);
-    return output;
-}
-
-Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape) {
-    return tensor_reshape(input_tensor, ttnn::Shape(new_shape.view()));
-}

} // namespace tt::tt_metal::tensor_ops
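The subtlest part of the removed tensor_reshape is the row-major height-sharded branch, which rescales the shard shape so each shard keeps the same volume when the width of a row changes. A worked example of that arithmetic with hypothetical values:

    #include <cstdint>

    // Shard rescale from the removed height-sharded branch; values are illustrative.
    constexpr uint32_t shard_h = 64, shard_w = 32;  // old shard shape {64, 32}
    constexpr uint32_t new_w = 64;                  // new_shape[-1]
    constexpr uint32_t mul_div = new_w > shard_w ? new_w / shard_w : shard_w / new_w;          // 2
    constexpr uint32_t new_shard_h = new_w > shard_w ? shard_h / mul_div : shard_h * mul_div;  // 32
    static_assert(shard_h * shard_w == new_shard_h * new_w, "shard volume preserved");
    // New shard shape {32, 64}; the page shape becomes {1, 64} to match the new row width.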
3 changes: 0 additions & 3 deletions ttnn/cpp/ttnn/tensor/tensor_ops.hpp
@@ -49,7 +49,4 @@ Tensor tensor_pad_to_tile(const Tensor& input_tensor, float pad_value);

Tensor tensor_unpad_from_tile(const Tensor& input_tensor, const ttnn::SimpleShape& output_tensor_shape);

-Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::SimpleShape& new_shape);
-Tensor tensor_reshape(const Tensor& input_tensor, const ttnn::Shape& new_shape);

} // namespace tt::tt_metal::tensor_ops
