-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#9755: move ttnn.concat to match the new file structure
- TODO: move program factory, ttlib C++ implementation, ttlib C++/python references and delete tt_eager implementation
- Loading branch information
Showing
5 changed files
with
184 additions
and
154 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
ttnn/cpp/ttnn/operations/data_movement/concat/concat.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "tt_eager/tensor/types.hpp" | ||
#include "ttnn/cpp/ttnn/operations/core.hpp" | ||
#include "tt_eager/tt_dnn/op_library/concat/concat_op.hpp" | ||
|
||
#include <ranges> | ||
|
||
|
||
namespace ttnn { | ||
namespace operations { | ||
namespace data_movement { | ||
|
||
struct Concat { | ||
|
||
// Wrapper for TTDNN | ||
static inline ttnn::Tensor execute_on_worker_thread(uint8_t queue_id, const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config, std::optional<ttnn::Tensor> &optional_output_tensor) { | ||
TT_FATAL(input_tensors.size() > 0, "ttnn.concat: expected a non-empty list of Tensors!"); | ||
TT_FATAL(!optional_output_tensor.has_value(), "optional output tensor currently unsupported!"); | ||
const auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); // should match input tensor memory config when unpopulated but causes CI errors for now | ||
|
||
if (input_tensors.size() == 1) { | ||
return ttnn::to_memory_config(input_tensors.at(0), mem_config, std::nullopt); | ||
} | ||
|
||
// TODO: Issue #8426: Add validation for ttnn.concat for sharded inputs | ||
// const bool all_tensors_are_tile_layout_without_padding = std::all_of(input_tensors.begin(), input_tensors.end(), | ||
// [dim](const ttnn::Tensor& input_tensor){ | ||
// return input_tensor.get_layout() == ttnn::TILE_LAYOUT and not has_tile_padding(input_tensor, dim); | ||
//}); | ||
// TT_FATAL(all_tensors_are_tile_layout_without_padding, "Not Implemented"); | ||
|
||
const ttnn::Tensor& first_tensor = input_tensors.front(); | ||
const int rank = first_tensor.get_shape().rank(); | ||
|
||
dim = first_tensor.get_legacy_shape().get_normalized_index(dim); | ||
|
||
TT_FATAL( | ||
dim >= 0 and dim < rank, | ||
"ttnn: Dimension out of range: dim {} cannot be used for tensors of rank {}", | ||
dim, | ||
rank); | ||
|
||
const bool shapes_match = | ||
std::all_of(input_tensors.begin(), input_tensors.end(), [first_tensor, dim](const ttnn::Tensor& t) { | ||
const auto& ft_shape = first_tensor.get_shape(); | ||
const auto& t_shape = t.get_shape(); | ||
|
||
const bool ranks_match = ft_shape.rank() == t_shape.rank(); | ||
bool non_concat_dims_match = true; | ||
for (int i = 0; i < ft_shape.rank(); i++) { | ||
non_concat_dims_match &= dim == i or t_shape[i] == ft_shape[i]; | ||
} | ||
// bool non_concat_padded_dims_match = true; | ||
// for(int i = 0; i < ft_shape.rank(); i++) { | ||
// non_concat_padded_dims_match &= dim == i or t_shape.with_tile_padding()[i] == | ||
// ft_shape.with_tile_padding()[i]; | ||
// } | ||
return ranks_match and non_concat_dims_match; // and non_concat_padded_dims_match; | ||
}); | ||
|
||
TT_FATAL( | ||
shapes_match, | ||
"All dimensions must be the same size except for the dimension along which the contenation is taking place."); | ||
|
||
std::vector<ttnn::Tensor> itensor; | ||
std::transform( | ||
input_tensors.begin(), | ||
input_tensors.end(), | ||
std::back_inserter(itensor), | ||
[rank](const ttnn::Tensor& input_tensor) -> ttnn::Tensor { | ||
auto output = (rank < 4) ? ttnn::unsqueeze_to_4D(input_tensor) : input_tensor; | ||
return output; | ||
}); | ||
// Convert dim after unsqueeze | ||
dim = dim + 4 - rank; | ||
auto output_tensor = tt::tt_metal::concat(itensor, dim, mem_config); | ||
while (output_tensor.get_shape().rank() > rank) { | ||
const auto shape = output_tensor.get_shape(); | ||
const auto full_shape = output_tensor.get_shape().with_tile_padding(); | ||
std::vector<uint32_t> shape_vec{}; | ||
std::vector<uint32_t> full_shape_vec{}; | ||
// int i = 0; | ||
// while(i < 3 and shape[i] == 1) i++; | ||
for (int i = 1; i < shape.rank(); i++) { | ||
shape_vec.push_back(shape[i]); | ||
full_shape_vec.push_back(full_shape[i]); | ||
} | ||
output_tensor = ttnn::reshape(output_tensor, ttnn::Shape::from_vector(shape_vec, full_shape_vec)); | ||
} | ||
|
||
return output_tensor; | ||
|
||
} | ||
|
||
static inline ttnn::Tensor execute_on_worker_thread(const std::vector<ttnn::Tensor>& input_tensors, int dim, const std::optional<MemoryConfig>& memory_config, std::optional<ttnn::Tensor> &optional_output_tensor) { | ||
constexpr uint8_t DefaultQueueId = 0; | ||
return execute_on_worker_thread(DefaultQueueId, input_tensors, dim, memory_config, optional_output_tensor); | ||
|
||
} | ||
|
||
}; | ||
|
||
} // namespace data_movement | ||
} // namespace operations | ||
|
||
// Register ttnn::concat so the operation is invocable as `ttnn::concat` from
// C++ and exposed under the same name to the Python bindings.
constexpr auto concat = ttnn::register_operation<ttnn::operations::data_movement::Concat>("ttnn::concat");
|
||
} // namespace ttnn |
66 changes: 66 additions & 0 deletions
66
ttnn/cpp/ttnn/operations/data_movement/concat/concat_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
|
||
#include "concat.hpp" | ||
|
||
namespace ttnn::operations::data_movement::detail { | ||
namespace py = pybind11; | ||
|
||
// Bind ttnn.concat into the given pybind11 module.
//
// Doc fixes relative to the original: removed a garbled example line
// (unbalanced, self-nested `ttnn.concat(ttnn.from_torch(...` call), changed
// `dim=4` to `dim=3` (dim 4 is out of range for rank-4 tensors — the C++ side
// TT_FATALs on it), and corrected the printed shape: concatenating two
// (1, 1, 64, 32) tensors along dim 3 yields [1, 1, 64, 64].
void bind_concat(py::module& module) {
    const auto doc = R"doc(
Concats :attr:`tensors` in the given :attr:`dim`.
Args:
    * :attr:`tensors`: the tensors to be concatenated.
    * :attr:`dim`: the concatenating dimension.
Keyword Args:
    * :attr:`memory_config`: the memory configuration to use for the operation
    * :attr:`queue_id` (Optional[uint8]): command queue id
    * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
Example:
    >>> tensor1 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> tensor2 = ttnn.from_torch(torch.zeros((1, 1, 64, 32), dtype=torch.bfloat16), device=device)
    >>> output = ttnn.concat([tensor1, tensor2], dim=3)
    >>> print(output.shape)
    [1, 1, 64, 64]
)doc";

    using OperationType = decltype(ttnn::concat);
    ttnn::bind_registered_operation(
        module,
        ttnn::concat,
        doc,
        ttnn::pybind_overload_t{
            // Forward the python call into Concat::execute_on_worker_thread,
            // re-ordering the keyword args into the C++ parameter order.
            [] (const OperationType& self,
                const std::vector<ttnn::Tensor>& tensors,
                const int dim,
                std::optional<ttnn::Tensor> &optional_output_tensor,
                std::optional<ttnn::MemoryConfig>& memory_config,
                uint8_t queue_id) {
                    return self(queue_id, tensors, dim, memory_config, optional_output_tensor);
                },
            py::arg("tensors"),
            py::arg("dim") = 0,
            py::kw_only(),
            py::arg("output_tensor").noconvert() = std::nullopt,  // noconvert: reject implicit conversions into the output slot
            py::arg("memory_config") = std::nullopt,
            py::arg("queue_id") = 0,
        });
}
|
||
|
||
} // namespace ttnn::operations::data_movement::detail |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters