# #14982: Update Unary examples docs (#15417)
### Ticket
#14982

### Problem description
The documentation examples, descriptions, and supported-dtype notes for several unary ops are out of date and need updating.

### What's changed
Updated the examples, descriptions, sweep tests, and supported dtypes for the
following ops:

- trunc
- rpow
- hardtanh
- rdiv
- softplus

### Checklist
- [x] [Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/12094193572)
<img width="966" alt="Screenshot 2024-11-25 at 4 23 58 PM"
src="https://github.com/user-attachments/assets/ae23a68d-e4e9-45e6-8f85-b65124e6e8df">
<img width="965" alt="Screenshot 2024-11-25 at 10 47 54 PM"
src="https://github.com/user-attachments/assets/535b1484-bf73-4fb3-b67b-b44a40e38e27">
<img width="982" alt="Screenshot 2024-11-25 at 10 51 23 PM"
src="https://github.com/user-attachments/assets/518c550a-851a-4e7e-977a-37942e0d778c">
<img width="965" alt="Screenshot 2024-11-25 at 11 10 14 PM"
src="https://github.com/user-attachments/assets/634900fc-c6f9-4448-a849-7070fa527dfd">
<img width="961" alt="Screenshot 2024-11-25 at 11 59 57 PM"
src="https://github.com/user-attachments/assets/12bb297b-4188-4a25-abc9-b1a243ee39f6">
VirdhatchaniKN authored Nov 30, 2024
1 parent c2132fa commit 35cd995
Showing 2 changed files with 49 additions and 19 deletions.
@@ -35,8 +35,8 @@
 # If invalidated, the vector will still be stored but will be skipped.
 # Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid.
 def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
-    if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_a_dtype"] == ttnn.bfloat8_b:
-        return True, "Row Major layout and bfloat8_b are not supported"
+    if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT:
+        return True, "Row Major layout is not supported"
     return False, None
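For context (not part of this commit), a minimal standalone sketch of how a sweep vector is screened by the updated `invalidate_vector`, assuming `ttnn` is installed; after this change only ROW_MAJOR layout is rejected, so bfloat8_b vectors now run:

```python
from typing import Optional, Tuple

import ttnn


def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
    # Mirrors the updated sweep-test check: only ROW_MAJOR layout invalidates.
    if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT:
        return True, "Row Major layout is not supported"
    return False, None


# A bfloat8_b vector is no longer skipped as long as the layout is TILE.
invalid, reason = invalidate_vector(
    {"input_a_layout": ttnn.TILE_LAYOUT, "input_a_dtype": ttnn.bfloat8_b}
)
print(invalid, reason)  # False None
```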
64 changes: 47 additions & 17 deletions ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -688,8 +688,9 @@ void bind_unary_rdiv(
         {9}
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
-            >>> output = {1}(tensor, {2}, {4} = {6})
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> {2} = 2
+            >>> output = {1}(tensor, {2}, {4} = None)
     )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
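As background for the new example (an assumption, not stated in this diff): rdiv is taken to perform reverse division, `value / input` element-wise, in the torch convention. A minimal reference sketch:

```python
import torch

# Assumed rdiv semantics: result[i] = value / input[i] (reverse division).
tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16)
value = 2
print(value / tensor)  # [[2.0, 1.0], [~0.667, 0.5]] up to bfloat16 rounding
```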
@@ -738,18 +739,31 @@ void bind_softplus(py::module& module, const unary_operation_t& operation) {
             input_tensor (ttnn.Tensor): the input tensor.
         Keyword Args:
-            beta (float): Scales the input before applying the Softplus function. By modifying :attr:`beta`, you can adjust the steepness of the function. A higher :attr:`beta` value makes the function steeper, approaching a hard threshold like the ReLU function for large values of :attr:`beta`.
-            threshold (float): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values.
+            beta (float, optional): Scales the input before applying the Softplus function. By modifying :attr:`beta`, you can adjust the steepness of the function. A higher :attr:`beta` value makes the function steeper, approaching a hard threshold like the ReLU function for large values of :attr:`beta`. Defaults to `1`.
+            threshold (float, optional): Used to switch to a linear function for large values to improve numerical stability. This avoids issues with floating-point representation for very large values. Defaults to `20`.
+            memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
+            output_tensor (ttnn.Tensor, optional): preallocated output tensor. Defaults to `None`.
+            queue_id (int, optional): command queue id. Defaults to `0`.
         Returns:
             ttnn.Tensor: the output tensor.
+        Note:
+            Supported dtypes, layouts, and ranks:
+            .. list-table::
+               :header-rows: 1
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16
+                 - TILE
+                 - 2, 3, 4
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
-            >>> output = {1}(tensor, beta=1.0, threshold=20.0)
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> output = {1}(tensor, beta = 1.0, threshold = 20.0)
     )doc",
         ttnn::softplus.base_name(),
         ttnn::softplus.python_fully_qualified_name());
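To make the beta/threshold description above concrete, here is a minimal reference sketch of the Softplus definition, assuming the standard PyTorch-style convention (not part of this commit):

```python
import math


def softplus_ref(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    # Switch to the linear (identity) branch for large inputs so that
    # exp(beta * x) cannot overflow; this is the numerical-stability
    # role of `threshold` described in the docstring.
    if beta * x > threshold:
        return x
    return (1.0 / beta) * math.log1p(math.exp(beta * x))


print(softplus_ref(1.0))    # ~1.3133
print(softplus_ref(100.0))  # 100.0, via the linear branch
```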
@@ -1270,7 +1284,7 @@ void bind_hardtanh(
     const std::string& parameter_name_b,
     const std::string& parameter_b_doc,
     float parameter_b_value,
-    const std::string& supported_dtype = "BFLOAT16",
+    const std::string& supported_dtype = "BFLOAT16, BFLOAT8_B",
     const std::string& info_doc = "") {
     auto doc = fmt::format(
         R"doc(
@@ -1303,8 +1317,10 @@
         {9}
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
-            >>> output = {1}(tensor, {2} = {4}, {5} = {7})
+            >>> tensor = ttnn.from_torch(input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> min = 2
+            >>> max = 8
+            >>> output = {1}(tensor, min, max)
     )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
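For reference, hardtanh clamps each element to `[min, max]`; with the example's `min = 2`, `max = 8`, values below 2 map to 2 and values above 8 map to 8. A minimal sketch of the standard definition:

```python
def hardtanh_ref(x: float, min_val: float = -1.0, max_val: float = 1.0) -> float:
    # Clamp x into the closed interval [min_val, max_val].
    return max(min_val, min(x, max_val))


print([hardtanh_ref(v, 2.0, 8.0) for v in (1.0, 5.0, 9.0)])  # [2.0, 5.0, 8.0]
```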
@@ -1465,7 +1481,7 @@ void bind_unary_composite_threshold(
 }

 template <typename unary_operation_t>
-void bind_unary_composite_operation(py::module& module, const unary_operation_t& operation, const std::string& description) {
+void bind_unary_composite_trunc(py::module& module, const unary_operation_t& operation, const std::string& description) {
     auto doc = fmt::format(
         R"doc(
         Applies {0} to :attr:`input_tensor` element-wise.
@@ -1484,8 +1500,21 @@
         Returns:
             ttnn.Tensor: the output tensor.
+        Note:
+            Supported dtypes, layouts, and ranks:
+            .. list-table::
+               :header-rows: 1
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16, BFLOAT8_B
+                 - TILE
+                 - 2, 3, 4
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
             >>> output = {1}(tensor)
     )doc",
         operation.base_name(),
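For reference, trunc rounds each element toward zero, i.e. it drops the fractional part (`math.trunc` semantics), which differs from floor for negative inputs:

```python
import math

# trunc rounds toward zero; floor rounds toward negative infinity.
print(math.trunc(2.7), math.trunc(-2.7))  # 2 -2
print(math.floor(2.7), math.floor(-2.7))  # 2 -3
```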
@@ -1581,7 +1610,7 @@ void bind_unary_composite_float_with_default(
 }

 template <typename unary_operation_t>
-void bind_unary_composite_float(
+void bind_unary_composite_rpow(
     py::module& module,
     const unary_operation_t& operation,
     const std::string& parameter_name_a,
@@ -1620,8 +1649,9 @@ void bind_unary_composite_float(
         {7}
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
-            >>> output = {1}(tensor, {2})
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> exponent = 2
+            >>> output = {1}(tensor, exponent)
     )doc",
         operation.base_name(),
         operation.python_fully_qualified_name(),
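As background (an assumption about the semantics, not stated in this diff): rpow treats the scalar as the base, `rpow(x, k) = k ** x = exp(x * ln(k))`, which is consistent with the note that non-positive values are unsupported, since `ln(k)` is undefined for `k <= 0`. A minimal sketch:

```python
import math


def rpow_ref(x: float, exponent: float) -> float:
    # Assumed semantics: exponent ** x, computed via exp(x * ln(exponent)).
    if exponent <= 0:
        raise ValueError("Non-positive values are not supported")
    return math.exp(x * math.log(exponent))


print(rpow_ref(3.0, 2.0))  # ~8.0 (2 ** 3)
```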
@@ -1898,7 +1928,7 @@ void py_module(py::module& module) {
     detail::bind_unary_composite(module, ttnn::normalize_global, R"doc(Performs normalize_global function on :attr:`input_tensor`.)doc", "", R"doc(BFLOAT16)doc",
         R"doc(ROW_MAJOR, TILE)doc", R"doc(4)doc", "", R"doc(torch.rand([1, 1, 32, 32], dtype=torch.bfloat16))doc");
     detail::bind_unary_composite(module, ttnn::frac, R"doc(Performs frac function on :attr:`input_tensor`.)doc", "", R"doc(BFLOAT16, BFLOAT8_B)doc");
-    detail::bind_unary_composite_operation(module, ttnn::trunc, R"doc(Not supported for grayskull.)doc");
+    detail::bind_unary_composite_trunc(module, ttnn::trunc, R"doc(Not supported for grayskull.)doc");

     detail::bind_unary_composite_floats_with_default(
         module,
@@ -1990,7 +2020,7 @@ void py_module(py::module& module) {
         "eps", "eps", 0.0f, R"doc(BFLOAT16)doc",
         R"doc(Not available for Wormhole_B0.)doc");

-    detail::bind_unary_composite_float(
+    detail::bind_unary_composite_rpow(
         module,
         ttnn::rpow,
         "exponent", "exponent value. Non-positive values are not supported.",
@@ -2009,7 +2039,7 @@ void py_module(py::module& module) {
         Output tensor will have BFLOAT16 data type.)doc",

-        R"doc(BFLOAT16, BFLOAT8_B)doc", R"doc(System memory is not supported.)doc");
+        R"doc(BFLOAT16)doc", R"doc(System memory is not supported.)doc");
 }

 } // namespace unary
