#10778: Update Argmin with ttnn support
bharane-ab committed Jul 31, 2024
1 parent 0f35c21 commit 11b397b
Showing 6 changed files with 11 additions and 55 deletions.
3 changes: 0 additions & 3 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -531,9 +531,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.argmax

-.. autofunction:: tt_lib.tensor.argmin


Loss Functions
==============

18 changes: 9 additions & 9 deletions tests/ttnn/profiling/ops_for_profiling.py
@@ -1536,23 +1536,23 @@ def argmax_all(x):


def argmin_1(x):
-    tt_lib.tensor.argmin(x, dim=-1)
+    ttnn.argmin(x, dim=-1)


def argmin_2(x):
-    tt_lib.tensor.argmin(x, dim=-2)
+    ttnn.argmin(x, dim=-2)


def argmin_3(x):
-    tt_lib.tensor.argmin(x, dim=-3)
+    ttnn.argmin(x, dim=-3)


def argmin_4(x):
-    tt_lib.tensor.argmin(x, dim=-4)
+    ttnn.argmin(x, dim=-4)


def argmin_all(x):
-    tt_lib.tensor.argmin(x, dim=-1, all=True)
+    ttnn.argmin(x, dim=-1, all=True)


def primary_moreh_softmax_0(x):
@@ -2284,22 +2284,22 @@ def clone(x):
    },
    {
        "op": argmin_1,
-        "name": "tt_lib.tensor.argmin_dim_3",
+        "name": "ttnn.argmin_dim_3",
        "num_repeats": 2,
    },
    {
        "op": argmin_2,
-        "name": "tt_lib.tensor.argmin_dim_2",
+        "name": "ttnn.argmin_dim_2",
        "num_repeats": 2,
    },
    {
        "op": argmin_3,
-        "name": "tt_lib.tensor.argmin_dim_1",
+        "name": "ttnn.argmin_dim_1",
        "num_repeats": 2,
    },
    {
        "op": argmin_all,
-        "name": "tt_lib.tensor.argmin_all",
+        "name": "ttnn.argmin_all",
        "num_repeats": 2,
    },
    {
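For orientation, here is a minimal sketch of calling the migrated op outside the profiling harness. Only the ttnn.argmin calls themselves are taken from the diff above; the device setup (ttnn.open_device, ttnn.from_torch, ttnn.to_torch), the input shape, and the dtype are assumptions, not part of this commit.

# Hedged sketch: exercising ttnn.argmin the way the profiling entries above do.
# Everything except the ttnn.argmin calls is assumed, not taken from this commit.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.rand(1, 1, 32, 64, dtype=torch.bfloat16)
tt_input = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)

tt_idx = ttnn.argmin(tt_input, dim=-1)                # per-dimension, as in argmin_1..argmin_4
tt_idx_all = ttnn.argmin(tt_input, dim=-1, all=True)  # global minimum, as in argmin_all

print(ttnn.to_torch(tt_idx).shape, ttnn.to_torch(tt_idx_all).shape)

ttnn.close_device(device)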
@@ -871,17 +871,6 @@ Tensor argmax(
    return operation::decorate_as_composite(__func__, _argmax)(input_a, dim, all, output_mem_config);
}

-Tensor _argmin(const Tensor& input_a, int64_t _dim, bool all, const MemoryConfig& output_mem_config) {
-    Tensor neg_input = ttnn::neg(input_a, output_mem_config);
-    return (argmax(neg_input, _dim, all, output_mem_config));
-}
-Tensor argmin(
-    const Tensor& input_a,
-    int64_t dim,
-    bool all,
-    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
-    return operation::decorate_as_composite(__func__, _argmin)(input_a, dim, all, output_mem_config);
-}
} // namespace tt_metal

} // namespace tt
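The composite removed above implemented argmin by negating the input and reusing argmax. A hedged, torch-only sketch of that identity follows; the shapes and names are illustrative and nothing here is taken from the ttnn implementation.

# Hedged sketch: argmin(x, dim) == argmax(-x, dim), the identity the deleted
# tt_metal composite relied on (neg followed by argmax). Pure torch, illustrative only.
import torch

x = torch.randn(2, 3, 32, 32)
for dim in (-1, -2, -3, -4):
    assert torch.equal(torch.argmin(x, dim=dim), torch.argmax(-x, dim=dim))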
@@ -282,12 +282,6 @@ Tensor argmax(
    bool all = false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

-Tensor argmin(
-    const Tensor& input_a,
-    int64_t dim = 0,
-    bool all = false,
-    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
@@ -142,30 +142,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def(
"argmin",
&argmin,
py::arg("input").noconvert(),
py::arg("dim"),
py::arg("all") = false,
py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
R"doc(
Returns the indices of the minimum value of elements in the ``input`` tensor
If ``all`` is set to ``true`` irrespective of given dimension it will return the indices of minimum value of all elements in given ``input``
Input tensor must have BFLOAT16 data type.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"input", "Tensor argmin is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"dim", "Dimension to perform argmin", "int", "", "Yes"
"all", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def(
"lerp",
py::overload_cast<const Tensor&, const Tensor&, float, const MemoryConfig&>(&lerp),
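The docstring removed above states that all=True ignores dim and returns the index of the global minimum over every element. A hedged, torch-only sketch of that documented behaviour follows; it is illustrative and does not call the ttnn op.

# Hedged sketch of the "all" semantics described in the docstring removed above:
# with all=True the dim argument is ignored and the index of the minimum over
# every element is returned. Pure torch, illustrative only.
import torch

x = torch.rand(1, 1, 32, 32)
per_dim = torch.argmin(x, dim=-1)       # what dim=-1 computes on its own
global_idx = torch.argmin(x.flatten())  # what all=True is documented to compute
assert x.flatten()[global_idx] == x.min()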
@@ -1257,7 +1257,7 @@ std::vector<Tensor> _digamma_bw(const Tensor& grad, const Tensor& input, const s
    auto output_memory_config = output_mem_config.value_or(input.memory_config());
    float t_inf = std::numeric_limits<float>::infinity();
    float t_nan = std::nanf("");
-    Tensor grad_a = ttnn::multiply(grad, polygamma(input, 1, output_memory_config), std::nullopt, output_mem_config);
+    Tensor grad_a = ttnn::multiply(grad, ttnn::polygamma(input, 1, output_mem_config), std::nullopt, output_mem_config);
    grad_a = where(
        ttnn::logical_and(ttnn::eqz(input, output_mem_config), ttnn::eqz(grad, output_mem_config), std::nullopt, output_mem_config),
        t_nan,
@@ -1286,7 +1286,7 @@ std::vector<Tensor> _polygamma_bw(
    if (n == 2 || n == 4 || n == 6 || n == 8 || n == 10) {
        pos_neg = -1.0f;
    }
-    Tensor grad_a = ttnn::multiply(grad, polygamma(input, (n + 1), output_memory_config), std::nullopt, output_mem_config);
+    Tensor grad_a = ttnn::multiply(grad, ttnn::polygamma(input, (n + 1), output_mem_config), std::nullopt, output_mem_config);
    grad_a = where(
        ttnn::logical_and(
            ttnn::le(input, 0.0, std::nullopt, output_mem_config), ttnn::eqz(grad, output_mem_config), std::nullopt, output_mem_config),
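Both backward hunks rely on the identity that differentiating digamma (and, more generally, the n-th polygamma) yields the next polygamma order, which is why the gradient is grad multiplied by polygamma(input, n + 1). A hedged torch sketch of that relationship follows; it is illustrative only and involves no ttnn code.

# Hedged sketch: d/dx digamma(x) = polygamma(1, x); more generally
# d/dx polygamma(n, x) = polygamma(n + 1, x), the identity behind
# grad * polygamma(input, n + 1) in the hunks above. Pure torch, illustrative only.
import torch

x = torch.tensor([0.7, 1.3, 2.9], requires_grad=True)
torch.digamma(x).backward(torch.ones_like(x))
assert torch.allclose(x.grad, torch.special.polygamma(1, x.detach()))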
