#7783: Port ttnn.upsample to C++
xanderchin committed May 17, 2024
1 parent d29f733 commit 1f623e8
Showing 4 changed files with 136 additions and 6 deletions.
19 changes: 19 additions & 0 deletions ttnn/cpp/pybind11/operations/data_movement.hpp
@@ -61,6 +61,25 @@ Example::
)doc");

ttnn::bind_registered_operation(
module,
ttnn::upsample,
R"doc(
Upsamples the given multi-channel 2D (spatial) data.
The input data is assumed to be of the form [N, H, W, C].
The only upsampling algorithm currently available is 'nearest'.
Args:
* :attr:`input_tensor`: the input tensor
* :attr:`scale_factor`: multiplier for the spatial size. If it is a tuple, each element scales the corresponding trailing dimension, and the batch and channel factors must be 1.
)doc",
ttnn::pybind_arguments_t{
py::arg("input_tensor"),
py::arg("scale_factor"),
py::arg("memory_config") = std::nullopt
}
);
}

} // namespace data_movement
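
Once registered (see the C++ changes to ttnn/cpp/ttnn/operations/data_movement.hpp below), the operation is callable from C++ as well as through this Python binding. A minimal usage sketch, assuming `input` is a device tensor of shape [N, H, W, C]; device setup and tensor allocation are omitted:

    // Uniform scale: doubles both H and W.
    ttnn::Tensor out_a = ttnn::upsample(input, 2);
    // Per-dimension scale {scale_h, scale_w, scale_c}; the channel factor must be 1.
    ttnn::Tensor out_b = ttnn::upsample(input, std::vector<int>{2, 2, 1});
    // An optional memory config may be passed as a third argument; it defaults to DRAM.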
38 changes: 38 additions & 0 deletions ttnn/cpp/pybind11/operations/others.hpp
@@ -0,0 +1,38 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ttnn/operations/others.hpp"

namespace py = pybind11;

namespace ttnn {
namespace operations {
namespace others {

void py_module(py::module& module) {

module.def("upsample", &upsample,
py::arg("input_tensor"),
py::arg("scale_factor"),
py::arg("memory_config") = std::nullopt,
R"doc(
Upsamples the given multi-channel 2D (spatial) data.
The input data is assumed to be of the form [N, H, W, C].
The only upsampling algorithm currently available is 'nearest'.
Args:
* :attr:`input_tensor`: the input tensor
* :attr:`scale_factor`: multiplier for the spatial size. If it is a tuple, each element scales the corresponding trailing dimension, and the batch and channel factors must be 1.
)doc");

}

} // namespace others
} // namespace operations
} // namespace ttnn
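
Both bindings lean on pybind11's <pybind11/stl.h> to map a defaulted std::nullopt argument to Python's None. A self-contained toy sketch of that pattern; the module and function names here are illustrative, not part of this commit:

    #include <optional>
    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>  // converts std::optional <-> Python None

    namespace py = pybind11;

    // Toy function standing in for the bound operation (illustrative only).
    int scaled(int value, std::optional<int> factor) {
        return value * factor.value_or(1);
    }

    PYBIND11_MODULE(example, m) {
        m.def("scaled", &scaled,
              py::arg("value"),
              py::arg("factor") = std::nullopt,  // defaulted keyword maps to None
              "Multiply value by factor; factor defaults to 1 when None.");
    }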
70 changes: 70 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement.hpp
@@ -8,6 +8,7 @@
#include "ttnn/cpp/ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/permute/permute_op.hpp"
#include "tt_eager/tt_dnn/op_library/concat/concat_op.hpp"
#include "tt_eager/tt_dnn/op_library/upsample/upsample_op.hpp"

namespace ttnn {
namespace operations {
@@ -169,6 +170,75 @@ inline ttnn::Tensor concat(
return output_tensor;
}

struct UpSample {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
2, 4, {ttnn::bfloat16, ttnn::bfloat8_b}, {ttnn::TILE_LAYOUT, ttnn::ROW_MAJOR_LAYOUT}, true, false, false, false}};
}

template <typename... Args>
static auto input_tensors_to_validate(const ttnn::Tensor& input_tensor, Args&&... args) {
return std::make_tuple(input_tensor);
}

static ttnn::Tensor execute(
const ttnn::Tensor& input_tensor,
std::variant<int, std::vector<int>> scale_factor,
std::optional<MemoryConfig> output_mem_config = std::nullopt) {

MemoryConfig mem_config = output_mem_config.value_or(ttnn::DRAM_MEMORY_CONFIG);

int scale_h = 1;
int scale_w = 1;
if (std::holds_alternative<int>(scale_factor)) {
// A single integer scales both spatial dimensions.
scale_h = std::get<int>(scale_factor);
scale_w = std::get<int>(scale_factor);
} else if (std::holds_alternative<std::vector<int>>(scale_factor)) {
auto sf = std::get<std::vector<int>>(scale_factor);
if (sf.size() == 2) {
// {scale_w, scale_c}: the channel factor must be 1.
scale_w = sf.at(0);
int scale_c = sf.at(1);
TT_FATAL(scale_c == 1);
} else if (sf.size() == 3) {
// {scale_h, scale_w, scale_c}: the channel factor must be 1.
scale_h = sf.at(0);
scale_w = sf.at(1);
int scale_c = sf.at(2);
TT_FATAL(scale_c == 1);
} else if (sf.size() == 4) {
// {scale_n, scale_h, scale_w, scale_c}: the batch and channel factors must be 1.
int scale_n = sf.at(0);
scale_h = sf.at(1);
scale_w = sf.at(2);
int scale_c = sf.at(3);
TT_FATAL(scale_n == 1);
TT_FATAL(scale_c == 1);
} else {
TT_THROW("Unsupported scale factor");
}
}
if (input_tensor.is_sharded()) {
// Each shard must hold a whole number of input rows.
int shard_height = input_tensor.memory_config().shard_spec.value().shape[0];
const auto input_w = input_tensor.get_shape()[2];
TT_FATAL(shard_height % input_w == 0);
}

return tt::tt_metal::upsample(input_tensor, scale_h, scale_w, mem_config);
}
};

} // namespace data_movement
} // namespace operations
constexpr auto upsample = ttnn::register_operation<ttnn::operations::data_movement::UpSample>("ttnn::upsample");
} // namespace ttnn
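
The scale-factor normalization above can be exercised in isolation. A minimal standalone sketch mirroring UpSample::execute's logic; normalize_scale is a hypothetical helper written for illustration, not part of the commit, and assert stands in for TT_FATAL:

    #include <cassert>
    #include <stdexcept>
    #include <utility>
    #include <variant>
    #include <vector>

    // Normalize a scalar or per-dimension scale factor to {scale_h, scale_w}
    // for [N, H, W, C] input; batch and channel factors must be 1.
    std::pair<int, int> normalize_scale(const std::variant<int, std::vector<int>>& scale_factor) {
        int scale_h = 1;
        int scale_w = 1;
        if (std::holds_alternative<int>(scale_factor)) {
            // A scalar applies to both spatial dimensions.
            scale_h = scale_w = std::get<int>(scale_factor);
        } else {
            const auto& sf = std::get<std::vector<int>>(scale_factor);
            switch (sf.size()) {
            case 2:  // {scale_w, scale_c}
                scale_w = sf[0];
                assert(sf[1] == 1);
                break;
            case 3:  // {scale_h, scale_w, scale_c}
                scale_h = sf[0];
                scale_w = sf[1];
                assert(sf[2] == 1);
                break;
            case 4:  // {scale_n, scale_h, scale_w, scale_c}
                assert(sf[0] == 1);
                scale_h = sf[1];
                scale_w = sf[2];
                assert(sf[3] == 1);
                break;
            default:
                throw std::invalid_argument("Unsupported scale factor");
            }
        }
        return {scale_h, scale_w};
    }

    int main() {
        assert(normalize_scale(2) == std::make_pair(2, 2));
        assert(normalize_scale(std::vector<int>{3, 2, 1}) == std::make_pair(3, 2));
        return 0;
    }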
15 changes: 9 additions & 6 deletions ttnn/ttnn/operations/data_movement.py
@@ -522,13 +522,16 @@ def _golden_function(input_tensor: ttnn.Tensor, scale_factor: Tuple[float, float
ret = ret.permute(0, 2, 3, 1)
return ret


-@ttnn.register_operation(
-    name="ttnn.upsample",
-    validate_input_tensors=_upsample_validate_input_tensors,
-    golden_function=_golden_function,
-)
-def upsample(
+upsample = ttnn.register_operation(
+    golden_function=_golden_function,
+)(ttnn._ttnn.operations.data_movement.upsample)
+
+def _upsample(
input_tensor: ttnn.Tensor,
scale_factor: Union[float, Tuple[float, float], Tuple[float, float, float], Tuple[float, float, float, float]],
memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
