#7783: Port ttnn.upsample to C++ #8237

Merged: 1 commit, May 24, 2024
19 changes: 19 additions & 0 deletions ttnn/cpp/pybind11/operations/data_movement.hpp
@@ -61,6 +61,25 @@ Example::

)doc");

ttnn::bind_registered_operation(
module,
ttnn::upsample,
R"doc(
Upsamples the given multi-channel 2D (spatial) data.
The input data is assumed to be of the form [N, H, W, C].

The only upsampling algorithm currently available is 'nearest'.

Args:
* :attr:`input_tensor`: the input tensor
* :attr:`scale_factor`: multiplier for the spatial size. Accepts a single int, or a tuple of 2, 3, or 4 per-dimension scales over the trailing dimensions (the batch and channel scales must be 1).
* :attr:`memory_config`: memory configuration for the output tensor; defaults to DRAM when not provided.
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("scale_factor"),
py::arg("memory_config") = std::nullopt
}
);
}

} // namespace data_movement
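
For reference, a minimal usage sketch of the Python API this binding exposes (assuming a ttnn device handle and the standard ttnn.from_torch helper; shapes follow the [N, H, W, C] convention from the docstring):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    # [N, H, W, C] input, per the docstring's layout assumption.
    torch_input = torch.rand(1, 32, 32, 64, dtype=torch.bfloat16)
    input_tensor = ttnn.from_torch(torch_input, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)

    # Uniform scale factor: H and W are both doubled -> [1, 64, 64, 64].
    output = ttnn.upsample(input_tensor, 2)

    # Per-dimension scales (scale_h, scale_w, scale_c); the channel scale must be 1.
    output = ttnn.upsample(input_tensor, (2, 3, 1))

    ttnn.close_device(device)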
38 changes: 38 additions & 0 deletions ttnn/cpp/pybind11/operations/others.hpp
@@ -0,0 +1,38 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ttnn/operations/others.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace others {

void py_module(py::module& module) {

module.def("upsample", &upsample,
py::arg("input_tensor"),
py::arg("scale_factor"),
py::arg("memory_config") = std::nullopt,
R"doc(
Upsamples the given multi-channel 2D (spatial) data.
The input data is assumed to be of the form [N, H, W, C].

The only upsampling algorithm currently available is 'nearest'.

Args:
* :attr:`input_tensor`: the input tensor
* :attr:`scale_factor`: multiplier for the spatial size. Accepts a single int, or a tuple of 2, 3, or 4 per-dimension scales over the trailing dimensions (the batch and channel scales must be 1).
* :attr:`memory_config`: memory configuration for the output tensor; defaults to DRAM when not provided.
)doc");

}

} // namespace others
} // namespace operations
} // namespace ttnn
196 changes: 141 additions & 55 deletions ttnn/cpp/ttnn/operations/data_movement.hpp
@@ -4,21 +4,22 @@

#pragma once


#include "ttnn/cpp/ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/permute/permute_op.hpp"
#include "tt_eager/tt_dnn/op_library/concat/concat_op.hpp"
#include "tt_eager/tt_dnn/op_library/permute/permute_op.hpp"
#include "tt_eager/tt_dnn/op_library/upsample/upsample_op.hpp"
#include "ttnn/cpp/ttnn/operations/core.hpp"

namespace ttnn {
namespace operations {
namespace data_movement {

inline bool is_on_device(const Tensor& t) {
return t.storage_type() == tt::tt_metal::StorageType::DEVICE or t.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE;
return t.storage_type() == tt::tt_metal::StorageType::DEVICE or
t.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE;
}

inline bool has_tile_padding(const Tensor& t) {
if(t.get_shape().rank() > 1) {
if (t.get_shape().rank() > 1) {
auto the_shape = t.get_shape();
auto the_shape_with_padding = t.get_shape().with_tile_padding();
return the_shape[-1] != the_shape_with_padding[-1] or the_shape[-2] != the_shape_with_padding[-2];
@@ -30,135 +31,145 @@ inline bool has_tile_padding(const Tensor& t, int dim) {
int rank = t.get_shape().rank();
// Wrap dim
dim = dim < 0 ? rank + dim : dim;
TT_FATAL(dim >= 0 and dim < rank, "ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}", dim, rank);
TT_FATAL(
dim >= 0 and dim < rank,
"ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}",
dim,
rank);

if(dim < rank) {
if (dim < rank) {
auto the_shape = t.get_shape();
auto the_shape_with_padding = t.get_shape().with_tile_padding();
return the_shape[dim] != the_shape_with_padding[dim];
}
return false;
}

inline ttnn::Tensor permute(
const ttnn::Tensor& input_tensor,
const std::vector<int>& order
) {
inline ttnn::Tensor permute(const ttnn::Tensor& input_tensor, const std::vector<int>& order) {
const bool initial_input_tensor_on_device = is_on_device(input_tensor);
const auto input_layout = input_tensor.get_layout();
const auto input_rank = input_tensor.get_shape().rank();

TT_FATAL(input_rank <= 4);
TT_FATAL(input_rank == order.size(), "The number of dimensions in the tensor input does not match the length of the desired ordering");
TT_FATAL(
input_rank == order.size(),
"The number of dimensions in the tensor input does not match the length of the desired ordering");

auto adjust_order = [](const std::vector<int>& order) {
std::vector<std::int64_t> new_order;
auto adjust_order = [](const std::vector<int>& order) {
std::vector<std::int64_t> new_order;
TT_FATAL(order.size() <= 4);
int additional_ranks = 4 - order.size();
for (int i = 0; i < additional_ranks; i++) {
new_order.push_back(i);
}
for (int i = 0; i < order.size(); i++) {
new_order.push_back(order.at(i)+additional_ranks);
new_order.push_back(order.at(i) + additional_ranks);
}
return new_order;
};
auto itensor = (input_tensor.get_shape().rank() < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
auto iorder = adjust_order(order);
auto iorder = adjust_order(order);

if(has_tile_padding(itensor)) {
if (has_tile_padding(itensor)) {
itensor = ttnn::to_layout(itensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
}

TT_FATAL(is_on_device(itensor) and itensor.get_shape().rank() == 4);
auto output_tensor = tt::tt_metal::permute(itensor, iorder, ttnn::DRAM_MEMORY_CONFIG);
output_tensor = ttnn::to_layout(output_tensor, input_layout, std::nullopt, std::nullopt, (Device*)nullptr);

if(input_rank < 4){
if (input_rank < 4) {
const auto shape = output_tensor.get_shape();
const auto full_shape = output_tensor.get_shape().with_tile_padding();
std::vector<uint32_t> shape_vec{};
std::vector<uint32_t> full_shape_vec{};
int i = 0;
while(i < 3 and shape[i] == 1) i++;
for(i; i < shape.rank(); i++) {
while (i < 3 and shape[i] == 1) i++;
for (; i < shape.rank(); i++) {
shape_vec.push_back(shape[i]);
full_shape_vec.push_back(full_shape[i]);
}
auto metal_shape = tt::tt_metal::Shape(shape_vec, full_shape_vec);
output_tensor = ttnn::reshape(output_tensor, ttnn::Shape(metal_shape));
}

if(initial_input_tensor_on_device and not is_on_device(output_tensor)) {
if (initial_input_tensor_on_device and not is_on_device(output_tensor)) {
output_tensor = ttnn::to_device(output_tensor, input_tensor.device(), ttnn::DRAM_MEMORY_CONFIG);
}

return output_tensor;
}

inline ttnn::Tensor concat(
const std::vector<ttnn::Tensor> & input_tensors,
int dim,
const std::optional<MemoryConfig>& memory_config_arg
) {
const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config_arg) {
TT_FATAL(input_tensors.size() > 0, "ttnn.concat: expected a non-empty list of Tensors!");

const auto memory_config = memory_config_arg.value_or(ttnn::DRAM_MEMORY_CONFIG);

if(input_tensors.size() == 1) {
if (input_tensors.size() == 1) {
return ttnn::to_memory_config(input_tensors.at(0), memory_config, std::nullopt);
}

// TODO: Issue #8426: Add validation for ttnn.concat for sharded inputs
//const bool all_tensors_are_tile_layout_without_padding = std::all_of(input_tensors.begin(), input_tensors.end(), [dim](const ttnn::Tensor& input_tensor){
// const bool all_tensors_are_tile_layout_without_padding = std::all_of(input_tensors.begin(), input_tensors.end(),
// [dim](const ttnn::Tensor& input_tensor){
// return input_tensor.get_layout() == ttnn::TILE_LAYOUT and not has_tile_padding(input_tensor, dim);
//});
//TT_FATAL(all_tensors_are_tile_layout_without_padding, "Not Implemented");
// TT_FATAL(all_tensors_are_tile_layout_without_padding, "Not Implemented");

const ttnn::Tensor& first_tensor = input_tensors.front();
const int rank = first_tensor.get_shape().rank();

// Wrap dim
dim = dim < 0 ? rank + dim : dim;
TT_FATAL(dim >= 0 and dim < rank, "ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}", dim, rank);
TT_FATAL(
dim >= 0 and dim < rank,
"ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}",
dim,
rank);

const bool shapes_match = std::all_of(input_tensors.begin(), input_tensors.end(), [first_tensor, dim](const ttnn::Tensor& t){
const bool shapes_match =
std::all_of(input_tensors.begin(), input_tensors.end(), [first_tensor, dim](const ttnn::Tensor& t) {
const auto& ft_shape = first_tensor.get_shape();
const auto& t_shape = t.get_shape();

const auto& ft_shape = first_tensor.get_shape();
const auto& t_shape = t.get_shape();
const bool ranks_match = ft_shape.rank() == t_shape.rank();
bool non_concat_dims_match = true;
for (int i = 0; i < ft_shape.rank(); i++) {
non_concat_dims_match &= dim == i or t_shape[i] == ft_shape[i];
}
// bool non_concat_padded_dims_match = true;
// for(int i = 0; i < ft_shape.rank(); i++) {
// non_concat_padded_dims_match &= dim == i or t_shape.with_tile_padding()[i] ==
// ft_shape.with_tile_padding()[i];
// }
return ranks_match and non_concat_dims_match; // and non_concat_padded_dims_match;
});

const bool ranks_match = ft_shape.rank() == t_shape.rank();
bool non_concat_dims_match = true;
for(int i = 0; i < ft_shape.rank(); i++) {
non_concat_dims_match &= dim == i or t_shape[i] == ft_shape[i];
}
//bool non_concat_padded_dims_match = true;
//for(int i = 0; i < ft_shape.rank(); i++) {
// non_concat_padded_dims_match &= dim == i or t_shape.with_tile_padding()[i] == ft_shape.with_tile_padding()[i];
//}
return ranks_match and non_concat_dims_match; // and non_concat_padded_dims_match;
});

TT_FATAL(shapes_match, "All dimensions must be the same size except for the dimension along which the concatenation is taking place.");
TT_FATAL(
shapes_match,
"All dimensions must be the same size except for the dimension along which the concatenation is taking place.");

std::vector<ttnn::Tensor> itensor;
std::transform(input_tensors.begin(), input_tensors.end(), std::back_inserter(itensor),
std::transform(
input_tensors.begin(),
input_tensors.end(),
std::back_inserter(itensor),
[rank](const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
auto output = (rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
return output;
}
);
});
// Convert dim after unsqueeze
dim = dim + 4 - rank;
auto output_tensor = tt::tt_metal::concat(itensor, dim, memory_config);
while(output_tensor.get_shape().rank() > rank) {
while (output_tensor.get_shape().rank() > rank) {
const auto shape = output_tensor.get_shape();
const auto full_shape = output_tensor.get_shape().with_tile_padding();
std::vector<uint32_t> shape_vec{};
std::vector<uint32_t> full_shape_vec{};
//int i = 0;
//while(i < 3 and shape[i] == 1) i++;
for(int i = 1; i < shape.rank(); i++) {
// int i = 0;
// while(i < 3 and shape[i] == 1) i++;
for (int i = 1; i < shape.rank(); i++) {
shape_vec.push_back(shape[i]);
full_shape_vec.push_back(full_shape[i]);
}
@@ -169,6 +180,81 @@ inline ttnn::Tensor concat(
return output_tensor;
}

} // namespace data_movement
} // namespace operations
} // namespace ttnn
struct UpSample {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
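// Assumed TensorSchema field order: min rank, max rank, allowed dtypes, allowed
// layouts, can_be_on_device, can_be_on_cpu, can_be_scalar, is_optional.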
return {ttnn::TensorSchema{
2,
4,
{ttnn::bfloat16, ttnn::bfloat8_b},
{ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT},
true,
false,
false,
false}};
}

template <typename... Args>
static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) {
return std::make_tuple(input_tensor);
}

static ttnn::Tensor execute_on_worker_thread(
const ttnn::Tensor& input_tensor,
std::variant<int, std::array<int, 2>, std::array<int, 3>, std::array<int, 4>> scale_factor,
std::optional<MemoryConfig> output_mem_config = std::nullopt) {
MemoryConfig mem_config = output_mem_config.value_or(ttnn::DRAM_MEMORY_CONFIG);

int scale_h = 1;
int scale_w = 1;
std::visit(
[&scale_h, &scale_w](auto&& sf) {
using T = std::decay_t<decltype(sf)>;
if constexpr (std::is_same_v<T, int>) {
scale_h = sf;
scale_w = sf;
} else if constexpr (std::is_same_v<T, std::array<int, 2>>) {
scale_w = sf.at(0);
int scale_c = sf.at(1);
TT_FATAL(scale_c == 1);
} else if constexpr (std::is_same_v<T, std::array<int, 3>>) {
scale_h = sf.at(0);
scale_w = sf.at(1);
int scale_c = sf.at(2);
TT_FATAL(scale_c == 1);
} else if constexpr (std::is_same_v<T, std::array<int, 4>>) {
int scale_n = sf.at(0);
scale_h = sf.at(1);
scale_w = sf.at(2);
int scale_c = sf.at(3);
TT_FATAL(scale_n == 1);
TT_FATAL(scale_c == 1);
} else {
// Dependent-false assert: fires only if this branch is instantiated with an unsupported type.
static_assert(sizeof(T) == 0, "Unsupported scale factor type.");
}
},
scale_factor);

// DEBUG
// fmt::print("scale_h: {}, scale_w: {}\n", scale_h, scale_w);

if (input_tensor.is_sharded()) {
// TT_FATAL(not input_tensor.is_sharded());
int shard_height = input_tensor.memory_config().shard_spec.value().shape[0];
const auto batch_size = input_tensor.get_shape()[0];
const auto input_h = input_tensor.get_shape()[1];
const auto input_w = input_tensor.get_shape()[2];
const auto num_channels = input_tensor.get_shape()[3];
// The shard height must cover whole input rows.
TT_FATAL(shard_height % input_w == 0);
}

return tt::tt_metal::upsample(input_tensor, scale_h, scale_w, mem_config);
}
};

} // namespace data_movement
} // namespace operations
constexpr auto upsample = ttnn::register_operation<ttnn::operations::data_movement::UpSample>("ttnn::upsample");
} // namespace ttnn
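
A plain-Python restatement of the std::visit dispatch in UpSample::execute_on_worker_thread above, to make the accepted scale_factor forms explicit (a hypothetical helper for illustration, not part of ttnn):

    def resolve_scale_factor(sf):
        # Mirrors the C++ dispatch: every accepted form reduces to (scale_h, scale_w).
        if isinstance(sf, int):
            return sf, sf                                  # uniform H and W scaling
        if len(sf) == 2:                                   # (scale_w, scale_c)
            scale_w, scale_c = sf
            assert scale_c == 1
            return 1, scale_w
        if len(sf) == 3:                                   # (scale_h, scale_w, scale_c)
            scale_h, scale_w, scale_c = sf
            assert scale_c == 1
            return scale_h, scale_w
        if len(sf) == 4:                                   # (scale_n, scale_h, scale_w, scale_c)
            scale_n, scale_h, scale_w, scale_c = sf
            assert scale_n == 1 and scale_c == 1
            return scale_h, scale_w
        raise ValueError("Unsupported scale factor")

    # An input of shape [N, H, W, C] then produces [N, H * scale_h, W * scale_w, C].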