#0: squash this
tt-nshanker committed Feb 22, 2024
1 parent 7da23d1 commit 7b7feaf
Showing 3 changed files with 10 additions and 15 deletions.
6 changes: 4 additions & 2 deletions ttnn/cpp/pybind11/operations/binary.hpp
@@ -29,11 +29,13 @@ void py_module(py::module& m_binary) {
 
     m_binary.def(
         "add",
-        static_cast<ttnn::Tensor (*)(const ttnn::Tensor&, const float, const tt::tt_metal::MemoryConfig&)>(&ttnn::operations::binary::add),
+        static_cast<ttnn::Tensor (*)(const ttnn::Tensor&, const float, const tt::tt_metal::MemoryConfig&, std::optional<DataType>)>(&ttnn::operations::binary::add),
         py::arg("input_tensor_a"),
         py::arg("input_tensor_b"),
         py::kw_only(),
-        py::arg("memory_config") = DRAM_MEMORY_CONFIG);
+        py::arg("memory_config") = DRAM_MEMORY_CONFIG,
+        py::arg("dtype") = std::nullopt
+    );
 }
 
 } // namespace binary
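Note: a minimal sketch of how the rebound scalar overload looks from the Python side, mirroring the wrapper call in ttnn/ttnn/operations/binary.py below. The helper name and the unwrapped-tensor argument are illustrative assumptions; only the keyword names memory_config and dtype come from this commit.

import ttnn

def call_scalar_add_binding(unwrapped_tensor):
    # Hypothetical helper: `unwrapped_tensor` is assumed to be a device tensor
    # already unwrapped from the public ttnn.Tensor wrapper (setup omitted).
    # The scalar-add binding now also takes a keyword-only `dtype`, defaulting
    # to None (std::nullopt on the C++ side), so omitting it keeps the old
    # behaviour.
    return ttnn._ttnn.operations.binary.add(
        unwrapped_tensor,
        2.0,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        dtype=None,
    )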
6 changes: 5 additions & 1 deletion ttnn/cpp/ttnn/operations/binary.hpp
@@ -69,7 +69,11 @@ inline ttnn::Tensor add(
 }
 
 inline ttnn::Tensor add(
-    const ttnn::Tensor& input_tensor_a, const float input_tensor_b, const tt::tt_metal::MemoryConfig& memory_config) {
+    const ttnn::Tensor& input_tensor_a, const float input_tensor_b, const tt::tt_metal::MemoryConfig& memory_config,
+    std::optional<DataType> dtype = std::nullopt) {
+    if (dtype.has_value()) {
+        TT_THROW("ttnn.add: cannot change dtype when broadcasting");
+    }
     const auto original_shape = input_tensor_a.ttnn_shape();
 
     auto input_tensor_a_4D = ttnn::unsqueeze_to_4D(input_tensor_a);
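Note: with this guard, supplying a dtype together with a scalar operand now fails in C++ rather than in the Python wrapper (see the removed TypeError below). A hedged sketch of that contract, assuming pytest is available, that input_tensor is an existing device tensor, and that the TT_THROW message reaches Python intact; none of these names are from this commit.

import pytest
import ttnn

def check_scalar_add_dtype_guard(input_tensor):
    # `input_tensor` is assumed to be a ttnn.Tensor already placed on a device.
    # Scalar (broadcast) add without an explicit dtype still works as before.
    ttnn.add(input_tensor, 2.0)

    # Asking for an output dtype on the broadcast path is rejected by the new
    # C++ overload; the TT_THROW message is assumed to surface as a Python
    # exception carrying the quoted text.
    with pytest.raises(Exception, match="cannot change dtype when broadcasting"):
        ttnn.add(input_tensor, 2.0, dtype=ttnn.bfloat8_b)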
13 changes: 1 addition & 12 deletions ttnn/ttnn/operations/binary.py
@@ -164,19 +164,8 @@ def add(
     """
     input_tensor_a = input_tensor_a.value
-    run_with_dtype = False
-    if dtype is not None:
-        if isinstance(input_tensor_b, ttnn.Tensor):
-            run_with_dtype = True
-        else:
-            raise TypeError("ttnn.add: cannot change dtype when broadcasting")
     input_tensor_b = input_tensor_b.value if isinstance(input_tensor_b, ttnn.Tensor) else input_tensor_b
-    if run_with_dtype:
-        output = ttnn._ttnn.operations.binary.add(
-            input_tensor_a, input_tensor_b, memory_config=memory_config, dtype=dtype
-        )
-    else:
-        output = ttnn._ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config)
+    output = ttnn._ttnn.operations.binary.add(input_tensor_a, input_tensor_b, memory_config=memory_config, dtype=dtype)
     return ttnn.Tensor(output)


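Note: after the simplification the wrapper forwards dtype unconditionally, and the tensor-with-tensor path continues to accept it. A hedged end-to-end sketch, assuming device setup via ttnn.open_device and ttnn.from_torch behaves as elsewhere in the library and that ttnn.bfloat8_b is a valid output dtype; nothing below is taken from this commit.

import torch
import ttnn

# Assumed setup (not part of this commit): open a device and move two torch
# tensors onto it in tile layout.
device = ttnn.open_device(device_id=0)
a = ttnn.from_torch(torch.rand(32, 32), layout=ttnn.TILE_LAYOUT, device=device)
b = ttnn.from_torch(torch.rand(32, 32), layout=ttnn.TILE_LAYOUT, device=device)

# Tensor + tensor add with an explicit output dtype: the simplified wrapper
# passes dtype straight through to the binding.
c = ttnn.add(a, b, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b)

ttnn.close_device(device)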
