Skip to content

Commit

Permalink
#8835: cleaned up ttnn operation registration on C++ side
Browse files Browse the repository at this point in the history
  • Loading branch information
johanna-rock-tt authored and arakhmati committed Jul 4, 2024
1 parent 27cbc9f commit 30c6a63
Show file tree
Hide file tree
Showing 54 changed files with 421 additions and 1,261 deletions.
49 changes: 43 additions & 6 deletions docs/source/ttnn/ttnn/adding_new_ttnn_operation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Adding New ttnn Operation
Wormhole, or others).


FAQ
***

What is a ttnn operation?
-------------------------

Expand All @@ -25,11 +28,13 @@ What steps are needed to add ttnn operation in Python?
2. (Optional) Attach golden function to the operation using `ttnn.attach_golden_function`. This is useful for debugging and testing.


Example of Adding a new Device Operation
****************************************

C++ Implementation
------------------

Step 1: Implement device operation (Optional)
Step 1: Implement device operation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In order to add a new device operation, follow the directory structure shown below:
Expand Down Expand Up @@ -110,17 +115,49 @@ Finally, call the module defined in `examples/example/example_pybind.hpp` wherev



Step 2: Add golden function for the operation in Python
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Step 2: (Optional) Add golden function for the operation in Python
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A golden function can be added to an operation in order to compare its output with an equivalent `torch` implementation.

Add the following code in a python file:

.. code-block:: python
import ttnn
def example_golden_function(input_tensor, *args, **kwargs):
output_tensor = ...
# For the golden function, use the same signature as the operation
# Keep in mind that all `ttnn.Tensor`s are converted to `torch.Tensor`s
# And arguments not needed by torch can be ignored using `*args` and `**kwargs`
def golden_function(input_tensor: "torch.Tensor", *args, **kwargs):
output_tensor: "torch.Tensor" = ...
return output_tensor
ttnn.attach_golden_function(ttnn.example, example_golden_function)
# ttnn Tensors are automatically converted to torch tensors before the golden function is called
# And the outputs are converted back to ttnn Tensors
# But in some cases you may need to preprocess the inputs and postprocess the outputs manually
# In order to preprocess the inputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def preprocess_golden_function_inputs(args, kwargs):
# e.g.
ttnn_input_tensor = args[0]
return ttnn.to_torch(ttnn_input_tensor)
# In order to postprocess the outputs manually, use the following signature
# Note that the arguments are not packed into *args and **kwargs as in the golden function!!!
def postprocess_golden_function_outputs(args, kwargs, output):
# e.g.
ttnn_input_tensor = args[0]
torch_output_tensor = output[0]
return ttnn.from_torch(torch_output_tensor, dtype=ttnn_input_tensor.dtype, device=ttnn_input_tensor.device)
ttnn.attach_golden_function(
ttnn.example,
golden_function=golden_function,
preprocess_golden_function_inputs=preprocess_golden_function_inputs, # Optional
postprocess_golden_function_outputs=postprocess_golden_function_outputs # Optional
)
.. note::
`ttnn.example` is the name of the operation in Python because the operation was registered as `ttnn::example` in C++.
4 changes: 2 additions & 2 deletions tests/ttnn/unit_tests/test_deallocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

@pytest.mark.parametrize("h", [32])
@pytest.mark.parametrize("w", [2 * 32])
@pytest.mark.requires_fast_runtime_mode_off
def test_deallocate(device, h, w):
torch_input_tensor = torch.rand((h, w), dtype=torch.bfloat16)

Expand All @@ -27,4 +26,5 @@ def test_deallocate(device, h, w):
ttnn.deallocate(output_tensor)
with pytest.raises(RuntimeError) as exception:
output_tensor_reference + output_tensor_reference
assert "Tensor must be allocated!" in str(exception.value)

assert "MemoryConfig can only be obtained if the buffer is not null" in str(exception.value)
29 changes: 0 additions & 29 deletions tests/ttnn/unit_tests/test_validate_decorator.py

This file was deleted.

37 changes: 0 additions & 37 deletions tt_eager/tensor/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,45 +824,8 @@ static std::ostream &operator<<(std::ostream &os, const Shape &self) {
return os;
}


// Immutable description of the constraints attached to one tensor argument.
// NOTE(review): the per-field meanings below are inferred from the field names only;
// confirm against the code that consumes the schema.
struct TensorSchema {
    const std::size_t min_rank;      // presumably the lowest accepted rank -- confirm
    const std::size_t max_rank;      // presumably the highest accepted rank -- confirm
    const std::set<DataType> dtypes; // presumably the accepted data types -- confirm
    const std::set<Layout> layouts;  // presumably the accepted layouts -- confirm
    const bool can_be_on_device;
    const bool can_be_on_cpu;
    const bool can_be_scalar;
    const bool is_optional;

    // Returns the attribute names as a tuple of references to string literals,
    // listed in the same order as the values produced by attribute_values().
    static constexpr auto attribute_names() {
        return std::forward_as_tuple(
            "min_rank",
            "max_rank",
            "dtypes",
            "layouts",
            "can_be_on_device",
            "can_be_on_cpu",
            "can_be_scalar",
            "is_optional");
    }

    // Returns a tuple of const references to this schema's fields, ordered to match
    // attribute_names(). The tuple refers to members of *this, so it must not outlive
    // the schema it was obtained from.
    // The return type is intentionally not const-qualified: a top-level const on a
    // by-value return adds nothing for callers.
    auto attribute_values() const {
        return std::forward_as_tuple(
            min_rank,
            max_rank,
            dtypes,
            layouts,
            can_be_on_device,
            can_be_on_cpu,
            can_be_scalar,
            is_optional);
    }
};

} // namespace types

using types::Shape;
using types::TensorSchema;

} // namespace ttnn
64 changes: 6 additions & 58 deletions ttnn/cpp/pybind11/decorators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct pybind_overload_t {
};

template <typename registered_operation_t, typename concrete_operation_t, typename T, typename... py_args_t>
void add_operator_call(T& py_operation, const pybind_arguments_t<py_args_t...>& overload) {
void define_call_operator(T& py_operation, const pybind_arguments_t<py_args_t...>& overload) {
std::apply(
[&py_operation](auto... args) {
py_operation.def(
Expand All @@ -55,60 +55,12 @@ template <
typename T,
typename function_t,
typename... py_args_t>
void add_operator_call(T& py_operation, const pybind_overload_t<function_t, py_args_t...>& overload) {
void define_call_operator(T& py_operation, const pybind_overload_t<function_t, py_args_t...>& overload) {
std::apply(
[&py_operation, &overload](auto... args) { py_operation.def("__call__", overload.function, args...); },
overload.args.value);
}

// Produces the docstring for a registered operation.
//
// When the concrete operation exposes a non-empty collection of input tensor schemas,
// the result is `doc` followed by one reStructuredText `list-table` per schema: the
// first row holds the attribute names, the second row the attribute values of that
// schema. In every other case `doc` is returned unchanged.
template <auto id, typename concrete_operation_t>
std::string append_input_tensor_schemas_to_doc(
    const operation_t<id, concrete_operation_t>& operation, const std::string& doc) {
    if constexpr (!detail::has_input_tensor_schemas<concrete_operation_t>()) {
        return doc;
    } else if constexpr (std::tuple_size_v<decltype(concrete_operation_t::input_tensor_schemas())> == 0) {
        return doc;
    } else {
        std::stringstream rendered;

        // Emits each element of `row` on its own line; the first element opens the table row.
        const auto emit_row = [&rendered](const auto& row) {
            std::apply(
                [&rendered](const auto&... cells) {
                    bool first_cell = true;
                    const auto emit_cell = [&rendered, &first_cell](const auto& cell) {
                        rendered << " ";
                        rendered << (first_cell ? " * - " : " - ");
                        rendered << fmt::format("{}", cell);
                        rendered << "\n";
                        first_cell = false;
                    };
                    // Comma fold keeps the left-to-right emission order of the cells.
                    (emit_cell(cells), ...);
                },
                row);
        };

        rendered << doc << "\n\n";

        int tensor_index = 0;
        for (const auto& schema : concrete_operation_t::input_tensor_schemas()) {
            rendered << " .. list-table:: Input Tensor " << tensor_index << "\n\n";
            emit_row(ttnn::TensorSchema::attribute_names());
            emit_row(schema.attribute_values());
            rendered << "\n";
            ++tensor_index;
        }

        rendered << "\n";
        return rendered.str();
    }
}

auto bind_registered_operation_helper(
py::module& module, const auto& operation, const std::string& doc, auto attach_call_operator) {
using registered_operation_t = std::decay_t<decltype(operation)>;
Expand All @@ -117,15 +69,11 @@ auto bind_registered_operation_helper(

py::class_<registered_operation_t> py_operation(module, operation.class_name().c_str());

if constexpr (requires { append_input_tensor_schemas_to_doc(operation, doc); }) {
py_operation.doc() = append_input_tensor_schemas_to_doc(operation, doc).c_str();
} else {
py_operation.doc() = doc;
}
py_operation.doc() = doc;

py_operation.def_property_readonly(
"name",
[](const registered_operation_t& self) -> const std::string { return self.name(); },
[](const registered_operation_t& self) -> const std::string { return self.base_name(); },
"Shortened name of the api");

py_operation.def_property_readonly(
Expand All @@ -139,7 +87,7 @@ auto bind_registered_operation_helper(

attach_call_operator(py_operation);

module.attr(operation.name().c_str()) = operation; // Bind an instance of the operation to the module
module.attr(operation.base_name().c_str()) = operation; // Bind an instance of the operation to the module

return py_operation;
}
Expand All @@ -155,7 +103,7 @@ auto bind_registered_operation(
auto attach_call_operator = [&](auto& py_operation) {
(
[&py_operation](auto&& overload) {
add_operator_call<registered_operation_t, concrete_operation_t>(py_operation, overload);
define_call_operator<registered_operation_t, concrete_operation_t>(py_operation, overload);
}(overloads),
...);
};
Expand Down
5 changes: 2 additions & 3 deletions ttnn/cpp/pybind11/operations/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace detail {

void bind_global_typecast(py::module& module) {
auto doc = fmt::format(
R"doc({0}(input_tensor: ttnn.Tensor, dtype: ttnn.DataType, *, memory_config: Optional[ttnn.MemoryConfig] = None, output_tensor : Optional[ttnn.Tensor] = None, queue_id : Optional[int]) -> ttnn.Tensor
R"doc({0}(input_tensor: ttnn.Tensor, dtype: ttnn.DataType, *, memory_config: Optional[ttnn.MemoryConfig] = None, output_tensor : Optional[ttnn.Tensor] = None, queue_id : Optional[int]) -> ttnn.Tensor
Applies {0} to :attr:`input_tensor`.
Expand All @@ -40,8 +40,7 @@ Example::
>>> tensor = ttnn.typecast(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), ttnn.uint16)
)doc",
ttnn::typecast.name());

ttnn::typecast.base_name());

using TypecastType = decltype(ttnn::typecast);
bind_registered_operation(
Expand Down
10 changes: 5 additions & 5 deletions ttnn/cpp/pybind11/operations/creation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename creation_operation_t>
void bind_full_operation(py::module& module, const creation_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(shape: ttnn.Shape, fill_value: Union[int, float], dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc",
operation.name());
operation.base_name());

bind_registered_operation(
module,
Expand Down Expand Up @@ -66,7 +66,7 @@ template <typename creation_operation_t>
void bind_full_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(shape: ttnn.Shape, dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc",
operation.name());
operation.base_name());

bind_registered_operation(
module,
Expand All @@ -92,7 +92,7 @@ template <typename creation_operation_t>
void bind_full_like_operation(py::module& module, const creation_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(tensor: ttnn.Tensor, fill_value: Union[int, float], dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc",
operation.name());
operation.base_name());

bind_registered_operation(
module,
Expand Down Expand Up @@ -136,7 +136,7 @@ template <typename creation_operation_t>
void bind_full_like_operation_with_hard_coded_value(py::module& module, const creation_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(tensor: ttnn.Tensor, dtype: Optional[ttnn.DataType] = None, layout: Optional[ttnn.Layout] = None, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None)doc",
operation.name());
operation.base_name());

bind_registered_operation(
module,
Expand All @@ -162,7 +162,7 @@ template <typename creation_operation_t>
void bind_arange_operation(py::module& module, const creation_operation_t& operation) {
auto doc = fmt::format(
R"doc({0}(start: int = 0, stop: int, step: int = 1, dtype: ttnn.DataType = ttnn.bfloat16, device: ttnn.Device = None, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG)doc",
operation.name());
operation.base_name());

bind_registered_operation(
module,
Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/pybind11/operations/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void bind_fill_cache_for_user_(py::module& module, const kv_cache_operation_t& o
* :attr:`batch_index` (int): The index into the cache tensor.
)doc",
operation.name(),
operation.base_name(),
operation.python_fully_qualified_name());

bind_registered_operation(
Expand Down Expand Up @@ -64,7 +64,7 @@ void bind_update_cache_for_token_(py::module& module, const kv_cache_operation_t
* :attr:`batch_offset` (int): The batch_offset into the cache tensor.
)doc",
operation.name(),
operation.base_name(),
operation.python_fully_qualified_name());

bind_registered_operation(
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/operations/pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ void bind_global_avg_pool2d(py::module& module) {
>>> tensor = ttnn.from_torch(torch.randn((10, 3, 32, 32), dtype=ttnn.bfloat16), device=device)
>>> output = {1}(tensor)
)doc",
ttnn::global_avg_pool2d.name(),
ttnn::global_avg_pool2d.base_name(),
ttnn::global_avg_pool2d.python_fully_qualified_name());

bind_registered_operation(
Expand Down
Loading

0 comments on commit 30c6a63

Please sign in to comment.