Commit 354c047

#10778: Update argmax op with ttnn support

bharane-ab committed Jul 31, 2024
1 parent 11b397b commit 354c047
Showing 8 changed files with 17 additions and 169 deletions.
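The commit retires the composite tt_lib argmax (the C++ implementation, its header declaration, and the tt_lib pybind/docs entries) and points callers at the ttnn op instead. As a rough before/after sketch of the call-site change, assuming a device opened through ttnn and a rank-4 bfloat16 tensor in TILE layout (ttnn.experimental.argmax is the spelling used by the updated test below; the profiling file calls ttnn.argmax directly):

# Hedged sketch of the call-site migration, not an exact reproduction of the test.
import torch
import ttnn

device = ttnn.open_device(device_id=0)
torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

# Before this commit: out = tt_lib.tensor.argmax(tt_input, dim=-1, all=False)
out = ttnn.experimental.argmax(tt_input, dim=-1, all=False)
ttnn.close_device(device)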
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
@@ -75,6 +75,7 @@ Pointwise Unary

ttnn/abs
ttnn/acos
ttnn/logical_not_
ttnn/acosh
ttnn/asin
ttnn/asinh
2 changes: 0 additions & 2 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -529,8 +529,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.repeat

.. autofunction:: tt_lib.tensor.argmax

Loss Functions
==============

6 changes: 6 additions & 0 deletions docs/source/ttnn/ttnn/ttnn/logical_not_.rst
@@ -0,0 +1,6 @@
.. _ttnn.logical_not_:

ttnn.logical_not_
###################

.. autofunction:: ttnn.logical_not_
@@ -31,7 +31,7 @@ def test_argmax(self, input_shapes, dim, all, device):
.to(tt_lib.tensor.Layout.TILE)
.to(device)
)
tt_output_tensor_on_device = tt_lib.tensor.argmax(input_tensor, dim=dim, all=all)
tt_output_tensor_on_device = ttnn.experimental.argmax(input_tensor, dim=dim, all=all)
tt_out_tensor = tt_output_tensor_on_device.cpu().to(tt_lib.tensor.Layout.ROW_MAJOR).to_torch()
if all:
golden_tensor = torch.argmax(input_data)
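The hunk cuts off after the all=True golden; presumably the test continues with a per-dim golden and a comparison of tt_out_tensor against it. A hedged sketch of that pattern follows — only the `if all:` branch and its golden are visible in the hunk above, the rest is an assumption, and the exact comparison helper and shape handling are test-specific:

# Hedged continuation sketch of the golden comparison (not shown in the hunk).
if all:
    golden_tensor = torch.argmax(input_data)           # single index into the flattened tensor
else:
    golden_tensor = torch.argmax(input_data, dim=dim)  # indices along the requested dim
# ... followed by a comparison of tt_out_tensor against golden_tensor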
18 changes: 9 additions & 9 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -1516,23 +1516,23 @@ def pow_float(x):


def argmax_1(x):
tt_lib.tensor.argmax(x, dim=-1)
ttnn.argmax(x, dim=-1)


def argmax_2(x):
tt_lib.tensor.argmax(x, dim=-2)
ttnn.argmax(x, dim=-2)


def argmax_3(x):
tt_lib.tensor.argmax(x, dim=-3)
ttnn.argmax(x, dim=-3)


def argmax_4(x):
tt_lib.tensor.argmax(x, dim=-4)
ttnn.argmax(x, dim=-4)


def argmax_all(x):
tt_lib.tensor.argmax(x, dim=-1, all=True)
ttnn.argmax(x, dim=-1, all=True)


def argmin_1(x):
@@ -2264,22 +2264,22 @@ def clone(x):
},
{
"op": argmax_1,
"name": "tt_lib.tensor.argmax_dim_3",
"name": "ttnn.argmax_dim_3",
"num_repeats": 2,
},
{
"op": argmax_2,
"name": "tt_lib.tensor.argmax_dim_2",
"name": "ttnn.argmax_dim_2",
"num_repeats": 2,
},
{
"op": argmax_3,
"name": "tt_lib.tensor.argmax_dim_1",
"name": "ttnn.argmax_dim_1",
"num_repeats": 2,
},
{
"op": argmax_all,
"name": "tt_lib.tensor.argmax_all",
"name": "ttnn.argmax_all",
"num_repeats": 2,
},
{
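The renamed entries label each call by the positive axis of a rank-4 tensor rather than by the negative dim passed in, so the dim=-1 call is profiled as ttnn.argmax_dim_3 and dim=-3 as ttnn.argmax_dim_1. A small sketch of that mapping, assuming rank-4 inputs as used elsewhere in this commit:

# Mapping between the negative dims in argmax_1..argmax_3 above and the
# positive axis index encoded in the profiling entry names (rank-4 inputs assumed).
rank = 4
for neg_dim in (-1, -2, -3):
    print(f"ttnn.argmax(x, dim={neg_dim}) is profiled as ttnn.argmax_dim_{neg_dim % rank}")
# dim=-1 -> argmax_dim_3, dim=-2 -> argmax_dim_2, dim=-3 -> argmax_dim_1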
127 changes: 0 additions & 127 deletions ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/composite/composite_ops.cpp
@@ -733,14 +733,6 @@ Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryCo
return operation::decorate_as_composite(__func__, _sfpu_eps)(shape, layout, device, output_mem_config);
}

Tensor triu(
const Tensor& input_a,
int32_t dim /* = -1 */,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
return operation::decorate_as_composite(__func__, _triu)(input_a, dim, output_mem_config);
}


Tensor create_mask(const Tensor& input_a, const MemoryConfig& output_mem_config) {
auto& padded_shape = input_a.get_legacy_shape();
auto& unpadded_shape = padded_shape.without_padding();
@@ -751,125 +743,6 @@ Tensor create_mask(const Tensor& input_a, const MemoryConfig& output_mem_config)
masked_input = ttnn::where(masked_input, input_a, t_inf, output_mem_config);
return masked_input;
}
// Argmax returns the index of maximum element in the tensor
Tensor _argmax(const Tensor& input_t, int64_t _dim, bool all, const MemoryConfig& output_mem_config) {
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_t}))};
operation::launch_with_autoformat(
[_dim, all, output_mem_config](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
const auto& input = input_tensors.at(0);
auto& input_shape = input.get_legacy_shape();
TT_FATAL(input_shape.rank() == 4, "supported for rank-4 tensors at this time");

Tensor input_a = create_mask(input, output_mem_config);

uint32_t dim = input_shape.get_normalized_index(_dim);
int size = input_a.volume();

if (!all) {
if ((dim == (input_shape.rank() - 1)) || (dim == (input_shape.rank() - 2))) {
bool is_width = (dim == (input_shape.rank() - 1));
Tensor max_val = max(input_a, dim, output_mem_config);
Tensor max_tensor = zeros_like(input_a, output_mem_config);
Tensor tindex = tt::numpy::index_width<bfloat16>(
input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
if (is_width) {
max_tensor = ttnn::add(max_tensor, max_val, std::nullopt, output_mem_config);
} else {
tindex = tt::numpy::index_height<bfloat16>(
input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
max_tensor = ttnn::add(max_tensor, max_val, std::nullopt, output_mem_config);
}
tindex = tindex.to(input_a.device());
max_val.deallocate();
Tensor cmp_results = ttnn::eq(input_a, max_tensor, std::nullopt, output_mem_config);
max_tensor.deallocate();
Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_mem_config);
cmp_results.deallocate();
Tensor result = ttnn::where(ttnn::eqz(max_indices), size, max_indices, output_mem_config);
max_indices.deallocate();
result = min(result, dim, output_mem_config);
Tensor res_index = zeros_like(result, output_mem_config);
result = ttnn::where(ttnn::eq(result, size), res_index, result, output_mem_config);
std::vector<int64_t> permute_dims = {3, 0, 1, 2};
if (is_width) {
res_index = ttnn::add(res_index, result, std::nullopt, output_mem_config);
} else {
res_index = ttnn::add(res_index, result, std::nullopt, output_mem_config);
permute_dims[0] = 2;
permute_dims[3] = 3;
}
result.deallocate();
Tensor transpose_res = ttnn::permute(res_index, permute_dims, output_mem_config);
return {transpose_res};
} else if ((dim == (input_shape.rank() - 3)) || (dim == (input_shape.rank() - 4))) {
bool is_channel = (dim == (input_shape.rank() - 3));
Tensor max_val = max(input_a, dim, output_mem_config);
int repeat = input.get_shape()[dim];
std::vector<Tensor> combined_tensors;
for (int cid = 0; cid < repeat; cid++) combined_tensors.emplace_back(max_val);
max_val.deallocate();
Tensor concat_out = concat(combined_tensors, dim, output_mem_config);
// Needed till `max` stops autoformatting output
concat_out = ttnn::reshape(concat_out, input_a.get_shape());
Tensor cmp_results = ttnn::eq(input_a, concat_out, std::nullopt, output_mem_config);
concat_out.deallocate();
Tensor tindex = tt::numpy::index_channel<bfloat16>(
input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
if (!is_channel) {
tindex = tt::numpy::index_batch<bfloat16>(
input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
}
tindex = tindex.to(input_a.device());
Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_mem_config);
cmp_results.deallocate();
Tensor midx = full_like(max_indices, size);
Tensor result = ttnn::where(ttnn::eqz(max_indices), midx, max_indices, output_mem_config);
max_indices.deallocate();
result = min(result, dim, output_mem_config);
Tensor res_index = zeros_like(result, output_mem_config);
result = ttnn::where(ttnn::eq(result, full_like(result, size)), res_index, result, output_mem_config);
res_index.deallocate();
if (is_channel) {
std::vector<int64_t> permute_dims = {1, 0, 2, 3};
Tensor transpose_res = ttnn::permute(result, permute_dims, output_mem_config);
return {transpose_res};
} else {
return {result};
}
}
}
// TODO: Fix the index generation code. With the fix the code will work for argmax that return entire
// maximum value index
Tensor tindex = tt::numpy::index_all<bfloat16>(
input_shape, DataType::BFLOAT16, Layout::TILE, input_a.device(), output_mem_config);
Tensor max_val = global_max(input_a, output_mem_config);
Tensor max_tensor = zeros_like(input_a, output_mem_config);
max_tensor = ttnn::add(max_tensor, max_val, std::nullopt, output_mem_config);
max_val.deallocate();
Tensor cmp_results = ttnn::eq(input_a, max_tensor, std::nullopt, output_mem_config);
max_tensor.deallocate();
Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_mem_config);
cmp_results.deallocate();
Tensor result = ttnn::where(ttnn::eqz(max_indices), size, max_indices, output_mem_config);
max_indices.deallocate();
result = global_min(result, output_mem_config);
return {result};
},
{input_t},
output_tensors);
return output_tensors.at(0);
}

Tensor argmax(
const Tensor& input_a,
int64_t dim,
bool all,
const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
return operation::decorate_as_composite(__func__, _argmax)(input_a, dim, all, output_mem_config);
}

} // namespace tt_metal

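For context on what is being deleted: the composite _argmax above has no dedicated argmax kernel; it builds an index tensor along the reduced dim, compares the input against its per-dim maximum, and takes the minimum matching index, so ties resolve to the smallest index. A rough PyTorch sketch of that trick (an illustration of the algorithm, not of the removed device code):

import torch

def argmax_via_masking(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Index tensor along `dim`, playing the role of tt::numpy::index_width/index_height.
    shape = [1] * x.ndim
    shape[dim % x.ndim] = -1
    idx = torch.arange(x.shape[dim]).view(shape).expand_as(x)
    max_val = x.amax(dim=dim, keepdim=True)       # the per-dim `max` reduction
    matches = x == max_val                        # the ttnn::eq comparison
    sentinel = torch.full_like(idx, x.numel())    # the `size` sentinel used with ttnn::where
    candidates = torch.where(matches, idx, sentinel)
    return candidates.amin(dim=dim)               # the final `min` reduction

x = torch.randn(1, 1, 32, 32)
assert torch.equal(argmax_via_masking(x, dim=-1), torch.argmax(x, dim=-1))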
@@ -276,12 +276,6 @@ Tensor logical_ori(
// on-device tensor creation with shape and filled with value
Tensor sfpu_eps(const Shape shape, Layout layout, Device* device, const MemoryConfig& output_mem_config);

Tensor argmax(
const Tensor& input_a,
int64_t dim = 0,
bool all = false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
@@ -118,30 +118,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
R"doc(Perform an eltwise logical OR (``{0} || {1}``) on input tensor and immediate value.)doc",
R"doc("Scalar", "float", "")doc");

m_tensor.def(
"argmax",
&argmax,
py::arg("input").noconvert(),
py::arg("dim"),
py::arg("all") = false,
py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
R"doc(
Returns the indices of the maximum value of elements in the ``input`` tensor
If ``all`` is set to ``true`` irrespective of given dimension it will return the indices of maximum value of all elements in given ``input``
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Tensor argmax is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"dim", "Dimension to perform argmax", "int", "", "Yes"
"all", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def(
"lerp",
py::overload_cast<const Tensor&, const Tensor&, float, const MemoryConfig&>(&lerp),