#9755: move ttnn.concat to match the new file structure
- TODO: move the program factory, the ttlib C++ implementation, and the ttlib C++/Python references, and delete the tt_eager implementation
sjameelTT committed Jul 4, 2024
1 parent be24c41 commit fec343c
Showing 5 changed files with 184 additions and 154 deletions.
80 changes: 0 additions & 80 deletions ttnn/cpp/ttnn/operations/data_movement.hpp
@@ -106,86 +106,6 @@ inline ttnn::Tensor permute(const ttnn::Tensor& input_tensor, const std::vector<
    return output_tensor;
}

inline ttnn::Tensor concat(
    const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config_arg) {
    TT_FATAL(input_tensors.size() > 0, "ttnn.concat: expected a non-empty list of Tensors!");

    const auto memory_config = memory_config_arg.value_or(ttnn::DRAM_MEMORY_CONFIG);

    if (input_tensors.size() == 1) {
        return ttnn::to_memory_config(input_tensors.at(0), memory_config, std::nullopt);
    }

    // TODO: Issue #8426: Add validation for ttnn.concat for sharded inputs
    // const bool all_tensors_are_tile_layout_without_padding = std::all_of(input_tensors.begin(), input_tensors.end(),
    // [dim](const ttnn::Tensor& input_tensor){
    //    return input_tensor.get_layout() == ttnn::TILE_LAYOUT and not has_tile_padding(input_tensor, dim);
    //});
    // TT_FATAL(all_tensors_are_tile_layout_without_padding, "Not Implemented");

    const ttnn::Tensor& first_tensor = input_tensors.front();
    const int rank = first_tensor.get_shape().rank();

    // Wrap dim
    dim = dim < 0 ? rank + dim : dim;
    TT_FATAL(
        dim >= 0 and dim < rank,
        "ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}",
        dim,
        rank);

    const bool shapes_match =
        std::all_of(input_tensors.begin(), input_tensors.end(), [first_tensor, dim](const ttnn::Tensor& t) {
            const auto& ft_shape = first_tensor.get_shape();
            const auto& t_shape = t.get_shape();

            const bool ranks_match = ft_shape.rank() == t_shape.rank();
            bool non_concat_dims_match = true;
            for (int i = 0; i < ft_shape.rank(); i++) {
                non_concat_dims_match &= dim == i or t_shape[i] == ft_shape[i];
            }
            // bool non_concat_padded_dims_match = true;
            // for(int i = 0; i < ft_shape.rank(); i++) {
            //    non_concat_padded_dims_match &= dim == i or t_shape.with_tile_padding()[i] ==
            //    ft_shape.with_tile_padding()[i];
            // }
            return ranks_match and non_concat_dims_match;  // and non_concat_padded_dims_match;
        });

    TT_FATAL(
        shapes_match,
        "All dimensions must be the same size except for the dimension along which the concatenation is taking place.");

    std::vector<ttnn::Tensor> itensor;
    std::transform(
        input_tensors.begin(),
        input_tensors.end(),
        std::back_inserter(itensor),
        [rank](const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
            auto output = (rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
            return output;
        });
    // Convert dim after unsqueeze
    dim = dim + 4 - rank;
    auto output_tensor = tt::tt_metal::concat(itensor, dim, memory_config);
    while (output_tensor.get_shape().rank() > rank) {
        const auto shape = output_tensor.get_shape();
        const auto full_shape = output_tensor.get_shape().with_tile_padding();
        std::vector<uint32_t> shape_vec{};
        std::vector<uint32_t> full_shape_vec{};
        // int i = 0;
        // while(i < 3 and shape[i] == 1) i++;
        for (int i = 1; i < shape.rank(); i++) {
            shape_vec.push_back(shape[i]);
            full_shape_vec.push_back(full_shape[i]);
        }
        auto metal_shape = tt::tt_metal::Shape(shape_vec, full_shape_vec);
        output_tensor = ttnn::reshape(output_tensor, ttnn::Shape(metal_shape));
    }

    return output_tensor;
}

struct UpSample {
static inline const std::array<TensorSchema, 1> input_tensor_schemas() {
return {ttnn::TensorSchema{
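The `concat` removed above captures the pattern the op relies on: wrap a negative `dim`, check that all non-concat dimensions match, lift every input to rank 4 with `unsqueeze_to_4D`, concatenate, then reshape back down. A minimal pure-Python sketch of that shape bookkeeping (the helper name `concat_shapes` is illustrative, not part of ttnn, and covers only the rank <= 4 unsqueeze path):

```python
from typing import List

def concat_shapes(shapes: List[List[int]], dim: int) -> List[int]:
    rank = len(shapes[0])
    assert rank <= 4  # sketch covers the unsqueeze path only
    dim = dim if dim >= 0 else rank + dim  # wrap negative dim, as the C++ does
    assert 0 <= dim < rank, f"dim {dim} out of range for rank {rank}"
    # every non-concat dimension must match across inputs
    for s in shapes:
        assert len(s) == rank and all(s[i] == shapes[0][i] for i in range(rank) if i != dim)
    # unsqueeze_to_4D: left-pad with singleton dims so every shape has rank 4
    padded = [[1] * (4 - rank) + list(s) for s in shapes]
    dim += 4 - rank  # convert dim after unsqueeze
    out = list(padded[0])
    out[dim] = sum(s[dim] for s in padded)  # only the concat dim grows
    return out[4 - rank:]  # squeeze the padding back off

print(concat_shapes([[64, 32], [64, 32]], dim=-1))  # [64, 64]
```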
113 changes: 113 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat.hpp
@@ -0,0 +1,113 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "tt_eager/tensor/types.hpp"
#include "ttnn/cpp/ttnn/operations/core.hpp"
#include "tt_eager/tt_dnn/op_library/concat/concat_op.hpp"

#include <ranges>


namespace ttnn {
namespace operations {
namespace data_movement {

struct Concat {

    // Wrapper for TTDNN
    static inline ttnn::Tensor execute_on_worker_thread(uint8_t queue_id, const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config, std::optional<ttnn::Tensor>& optional_output_tensor) {
        TT_FATAL(input_tensors.size() > 0, "ttnn.concat: expected a non-empty list of Tensors!");
        TT_FATAL(!optional_output_tensor.has_value(), "optional output tensor currently unsupported!");
        const auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);  // should match input tensor memory config when unpopulated but causes CI errors for now

        if (input_tensors.size() == 1) {
            return ttnn::to_memory_config(input_tensors.at(0), mem_config, std::nullopt);
        }

        // TODO: Issue #8426: Add validation for ttnn.concat for sharded inputs
        // const bool all_tensors_are_tile_layout_without_padding = std::all_of(input_tensors.begin(), input_tensors.end(),
        // [dim](const ttnn::Tensor& input_tensor){
        //    return input_tensor.get_layout() == ttnn::TILE_LAYOUT and not has_tile_padding(input_tensor, dim);
        //});
        // TT_FATAL(all_tensors_are_tile_layout_without_padding, "Not Implemented");

        const ttnn::Tensor& first_tensor = input_tensors.front();
        const int rank = first_tensor.get_shape().rank();

        dim = first_tensor.get_legacy_shape().get_normalized_index(dim);

        TT_FATAL(
            dim >= 0 and dim < rank,
            "ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}",
            dim,
            rank);

        const bool shapes_match =
            std::all_of(input_tensors.begin(), input_tensors.end(), [first_tensor, dim](const ttnn::Tensor& t) {
                const auto& ft_shape = first_tensor.get_shape();
                const auto& t_shape = t.get_shape();

                const bool ranks_match = ft_shape.rank() == t_shape.rank();
                bool non_concat_dims_match = true;
                for (int i = 0; i < ft_shape.rank(); i++) {
                    non_concat_dims_match &= dim == i or t_shape[i] == ft_shape[i];
                }
                // bool non_concat_padded_dims_match = true;
                // for(int i = 0; i < ft_shape.rank(); i++) {
                //    non_concat_padded_dims_match &= dim == i or t_shape.with_tile_padding()[i] ==
                //    ft_shape.with_tile_padding()[i];
                // }
                return ranks_match and non_concat_dims_match;  // and non_concat_padded_dims_match;
            });

        TT_FATAL(
            shapes_match,
            "All dimensions must be the same size except for the dimension along which the concatenation is taking place.");

        std::vector<ttnn::Tensor> itensor;
        std::transform(
            input_tensors.begin(),
            input_tensors.end(),
            std::back_inserter(itensor),
            [rank](const ttnn::Tensor& input_tensor) -> ttnn::Tensor {
                auto output = (rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor;
                return output;
            });
        // Convert dim after unsqueeze
        dim = dim + 4 - rank;
        auto output_tensor = tt::tt_metal::concat(itensor, dim, mem_config);
        while (output_tensor.get_shape().rank() > rank) {
            const auto shape = output_tensor.get_shape();
            const auto full_shape = output_tensor.get_shape().with_tile_padding();
            std::vector<uint32_t> shape_vec{};
            std::vector<uint32_t> full_shape_vec{};
            // int i = 0;
            // while(i < 3 and shape[i] == 1) i++;
            for (int i = 1; i < shape.rank(); i++) {
                shape_vec.push_back(shape[i]);
                full_shape_vec.push_back(full_shape[i]);
            }
            output_tensor = ttnn::reshape(output_tensor, ttnn::Shape::from_vector(shape_vec, full_shape_vec));
        }

        return output_tensor;
    }

    static inline ttnn::Tensor execute_on_worker_thread(const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config, std::optional<ttnn::Tensor>& optional_output_tensor) {
        constexpr uint8_t DefaultQueueId = 0;
        return execute_on_worker_thread(DefaultQueueId, input_tensors, dim, memory_config, optional_output_tensor);
    }

};

} // namespace data_movement
} // namespace operations

constexpr auto concat = ttnn::register_operation<ttnn::operations::data_movement::Concat>("ttnn::concat");

} // namespace ttnn
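With `register_operation`, the struct becomes callable as `ttnn.concat` from both C++ and Python; the second overload supplies queue 0 when no `queue_id` is given. A hedged usage sketch (assumes a single-device setup; exact device helpers may differ across ttnn versions):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

a = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
b = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)

# memory_config falls back to DRAM when unset, per the wrapper above
out = ttnn.concat([a, b], dim=3, memory_config=ttnn.DRAM_MEMORY_CONFIG)
print(out.shape)  # [1, 1, 64, 64]

ttnn.close_device(device)
```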
66 changes: 66 additions & 0 deletions ttnn/cpp/ttnn/operations/data_movement/concat/concat_pybind.hpp
@@ -0,0 +1,66 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ttnn/cpp/pybind11/decorators.hpp"

#include "concat.hpp"

namespace ttnn::operations::data_movement::detail {
namespace py = pybind11;

void bind_concat(py::module& module) {
    const auto doc = R"doc(
Concats :attr:`tensors` in the given :attr:`dim`.

Args:
    * :attr:`tensors`: the tensors to be concatenated.
    * :attr:`dim`: the concatenating dimension.

Keyword Args:
    * :attr:`memory_config`: the memory configuration to use for the operation
    * :attr:`queue_id` (Optional[uint8]): command queue id
    * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor

Example:
    >>> tensor1 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> tensor2 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> output = ttnn.concat([tensor1, tensor2], dim=3)
    >>> print(output.shape)
    [1, 1, 64, 64]
)doc";

    using OperationType = decltype(ttnn::concat);
    ttnn::bind_registered_operation(
        module,
        ttnn::concat,
        doc,
        ttnn::pybind_overload_t{
            [](const OperationType& self,
               const std::vector<ttnn::Tensor>& tensors,
               const int dim,
               std::optional<ttnn::Tensor>& optional_output_tensor,
               std::optional<ttnn::MemoryConfig>& memory_config,
               uint8_t queue_id) {
                return self(queue_id, tensors, dim, memory_config, optional_output_tensor);
            },
            py::arg("tensors"),
            py::arg("dim") = 0,
            py::kw_only(),
            py::arg("output_tensor").noconvert() = std::nullopt,
            py::arg("memory_config") = std::nullopt,
            py::arg("queue_id") = 0,
        });
}


} // namespace ttnn::operations::data_movement::detail
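Compared with the old binding, this one threads through two extra keyword-only arguments. Roughly, assuming hypothetical tensors `a` and `b` already on a device:

```python
# queue_id selects the command queue (default 0). output_tensor is bound
# here too, but the C++ wrapper currently rejects it with TT_FATAL.
out = ttnn.concat([a, b], dim=3, memory_config=ttnn.DRAM_MEMORY_CONFIG, queue_id=0)
```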
35 changes: 2 additions & 33 deletions ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp
Expand Up @@ -10,6 +10,7 @@
#include "ttnn/cpp/pybind11/decorators.hpp"
#include "ttnn/operations/data_movement.hpp"
#include "ttnn/operations/data_movement/pad/pad_pybind.hpp"
#include "ttnn/operations/data_movement/concat/concat_pybind.hpp"

namespace py = pybind11;

@@ -42,38 +43,6 @@ Permutes :attr:`input_tensor` using :attr:`order`.
doc);
}

void bind_concat(py::module& module) {
    const auto doc = R"doc(
Concats :attr:`tensors` in the given :attr:`dim`.

Args:
    * :attr:`tensors`: the tensors to be concatenated.
    * :attr:`dim`: the concatenating dimension.

Keyword Args:
    * :attr:`memory_config`: the memory configuration to use for the operation

Example:
    >>> tensor1 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> tensor2 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> output = ttnn.concat([tensor1, tensor2], dim=3)
    >>> print(output.shape)
    [1, 1, 64, 64]
)doc";
    module.def(
        "concat",
        &concat,
        py::arg("input_tensor"),
        py::arg("dim") = 0,
        py::kw_only(),
        py::arg("memory_config") = std::nullopt,
        doc);
}

void bind_upsample(py::module& module) {
const auto doc = R"doc(
Upsamples a given multi-channel 2D (spatial) data.
@@ -173,7 +142,7 @@ Keyword Args:

void py_module(py::module& module) {
    bind_permute(module);
    bind_concat(module);
    detail::bind_concat(module);
    bind_upsample(module);
    detail::bind_pad(module);
    bind_repeat(module);
44 changes: 3 additions & 41 deletions ttnn/ttnn/operations/data_movement.py
@@ -119,48 +119,10 @@ def _golden_function(tensors, dim=0, **_):
    return torch.concat(tensors, dim)


def _concat_validate_input_tensors(operation_name, tensors, dim, *args, **kwargs):
    for input_tensor in tensors:
        ttnn.validate_input_tensor(
            operation_name,
            input_tensor,
            ranks=(2, 3, 4),
            dtypes=(ttnn.bfloat16, ttnn.bfloat8_b, ttnn.uint8, ttnn.uint16, ttnn.int32, ttnn.uint32),
            layouts=(ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT),
            can_be_on_device=True,
            can_be_on_cpu=False,
        )


doc = r"""
concat(tensors: List[ttnn.Tensor], dim: int = 0, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor

Concats :attr:`tensors` in the given :attr:`dim`.

Args:
    * :attr:`tensors`: the tensors to be concatenated.
    * :attr:`dim`: the concatenating dimension.

Keyword Args:
    * :attr:`memory_config`: the memory configuration to use for the operation

Example::

    >>> tensor1 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> tensor2 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> output = ttnn.concat([tensor1, tensor2], dim=3)
    >>> print(output.shape)
    [1, 1, 64, 64]
"""
ttnn.register_operation(
    name="ttnn.concat",
    validate_input_tensors=_concat_validate_input_tensors,
    golden_function=_golden_function,
    doc=doc,
)(ttnn._ttnn.operations.data_movement.concat)
ttnn.attach_golden_function(
    ttnn._ttnn.operations.data_movement.concat,
    golden_function=_golden_function,
)
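With the move, the Python-side validation and docstring are gone; `attach_golden_function` only wires up the torch reference. A sketch of how that golden path is typically exercised in a test (the helper below is illustrative, not from this commit):

```python
import torch
import ttnn

def check_concat_against_golden(torch_tensors, dim, device):
    # Reference result straight from torch, matching _golden_function above
    golden = torch.concat(torch_tensors, dim)

    # Device result through the C++-registered ttnn.concat
    tt_inputs = [ttnn.from_torch(t, device=device) for t in torch_tensors]
    actual = ttnn.to_torch(ttnn.concat(tt_inputs, dim=dim))

    return torch.allclose(actual, golden)
```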


def _golden_function(tensor, repeats, dim=0, **_):
