#15078: Update clamp_bw, clip_bw with min, max tensor (#15255)
### Ticket
#15078

### Problem description
Following up on issue #13234, `clamp_bw` and `clip_bw` need to be updated to support `min` and `max` passed as tensors.

### What's changed
Updated `clamp_bw` and `clip_bw` to accept `min` and `max` as tensors, alongside the existing scalar overloads.
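
A minimal usage sketch of the new capability (illustration only, not taken from this commit's tests): the backward ops can now take the bounds as device tensors. The device setup, shape, and dtype below are assumptions.

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)
shape = (1, 1, 32, 32)

def to_tt(t):
    # Standard ttnn host-to-device conversion; not part of this change.
    return ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

grad_tensor = to_tt(torch.ones(shape))
input_tensor = to_tt(torch.randn(shape))
min_tensor = to_tt(torch.full(shape, -0.5))  # tensor-valued lower bound
max_tensor = to_tt(torch.full(shape, 0.5))   # tensor-valued upper bound

# New in this change: min/max may be tensors (either one may be None, but not both).
out = ttnn.clamp_bw(grad_tensor, input_tensor, min_tensor, max_tensor)
out = ttnn.clip_bw(grad_tensor, input_tensor, min_tensor, max_tensor)

ttnn.close_device(device)
```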

### Checklist
- [x] [Post commit CI](https://github.com/tenstorrent/tt-metal/actions/runs/11933224437) passes
VirdhatchaniKN authored Nov 21, 2024
1 parent b057e09 commit 019a5cc
Showing 5 changed files with 103 additions and 10 deletions.

Changes to the `clamp_bw` backward test:

```diff
@@ -17,7 +17,7 @@
     ),
 )
 @pytest.mark.parametrize(
-    "min, max",
+    "min_val, max_val",
     [
         (-10.0, 10.0),
         (10.0, -10.0),
@@ -30,15 +30,32 @@
         (None, -0.5),
         (1.0, 0.0),
         (0.0, 1.0),
+        ("tensor", None),
+        (None, "tensor"),
+        ("tensor", "tensor"),
     ],
 )
-def test_unary_bw_clamp(input_shapes, min, max, device):
-    in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
-    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
+def test_unary_bw_clamp_ttnn(input_shapes, min_val, max_val, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device)
+    if min_val == "tensor":
+        min, min_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    elif min_val is None:
+        min, min_tensor = None, None
+    else:
+        min, min_tensor = min_val, min_val
+
+    if max_val == "tensor":
+        max, max_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    elif max_val is None:
+        max, max_tensor = None, None
+    else:
+        max, max_tensor = max_val, max_val
+
     if min is None and max is None:
         pytest.xfail("Only one of 'min' or 'max' can be None. Please provide one value")
     else:
-        tt_output_tensor_on_device = ttnn.clamp_bw(grad_tensor, input_tensor, min, max)
+        tt_output_tensor_on_device = ttnn.clamp_bw(grad_tensor, input_tensor, min_tensor, max_tensor)
         golden_function = ttnn.get_golden_function(ttnn.clamp_bw)
         golden_tensor = golden_function(grad_data, in_data, min, max)
         comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
```
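
For orientation, the golden comparison above follows the usual clamp backward rule: the upstream gradient passes through only where the input lies within the (possibly tensor-valued) bounds, elementwise. Stated as a formula (a standard result, not taken from this diff):

$$
\frac{\partial L}{\partial x_i} = g_i \cdot \mathbf{1}\left[\, m_i \le x_i \le M_i \,\right]
$$

where $g$ is the incoming gradient and $m$, $M$ are the (elementwise) lower and upper bounds; with only one bound supplied, the corresponding one-sided condition applies.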

Changes to the `clip_bw` backward test:

```diff
@@ -17,7 +17,7 @@
     ),
 )
 @pytest.mark.parametrize(
-    "min, max",
+    "min_val, max_val",
     [
         (-10.0, 10.0),
         (10.0, -10.0),
@@ -30,15 +30,32 @@
         (None, -0.5),
         (1.0, 0.0),
         (0.0, 1.0),
+        ("tensor", None),
+        (None, "tensor"),
+        ("tensor", "tensor"),
     ],
 )
-def test_unary_bw_clip(input_shapes, min, max, device):
-    in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
-    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
+def test_unary_bw_clip_ttnn(input_shapes, min_val, max_val, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device)
+    if min_val == "tensor":
+        min, min_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    elif min_val is None:
+        min, min_tensor = None, None
+    else:
+        min, min_tensor = min_val, min_val
+
+    if max_val == "tensor":
+        max, max_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    elif max_val is None:
+        max, max_tensor = None, None
+    else:
+        max, max_tensor = max_val, max_val
+
     if min is None and max is None:
         pytest.xfail("Only one of 'min' or 'max' can be None. Please provide one value")
     else:
-        tt_output_tensor_on_device = ttnn.clip_bw(grad_tensor, input_tensor, min, max)
+        tt_output_tensor_on_device = ttnn.clip_bw(grad_tensor, input_tensor, min_tensor, max_tensor)
         golden_function = ttnn.get_golden_function(ttnn.clip_bw)
         golden_tensor = golden_function(grad_data, in_data, min, max)
         comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
```
29 changes: 29 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp

```diff
@@ -50,11 +50,40 @@ std::vector<Tensor> ExecuteUnaryBackwardClamp::invoke(
     return grad_tensor;
 }
 
+std::vector<Tensor> ExecuteUnaryBackwardClamp::invoke(
+    const Tensor& grad, const Tensor& input, std::optional<Tensor> min, std::optional<Tensor> max, const std::optional<MemoryConfig>& output_mem_config) {
+    std::vector<Tensor> grad_tensor;
+    auto output_memory_config = output_mem_config.value_or(input.memory_config());  //TODO: Remove after ternary forward ops migration is completed
+    TT_FATAL((max.has_value() || min.has_value()), "Only one of 'min' or 'max' can be None. Please provide one value");
+    if (!max.has_value()) {
+        Tensor minT = ttnn::ge(input, min.value(), std::nullopt, output_mem_config);
+        Tensor in_grad = ttnn::multiply(grad, minT, std::nullopt, output_mem_config);
+        grad_tensor.emplace_back(in_grad);
+        return grad_tensor;
+    } else if (!min.has_value()) {
+        Tensor maxT = ttnn::le(input, max.value(), std::nullopt, output_mem_config);
+        Tensor in_grad = ttnn::multiply(grad, maxT, std::nullopt, output_mem_config);
+        grad_tensor.emplace_back(in_grad);
+        return grad_tensor;
+    }
+    Tensor minT = ttnn::le(input, min.value(), std::nullopt, output_memory_config);
+    Tensor maxT = ttnn::ge(input, max.value(), std::nullopt, output_memory_config);
+    Tensor result = ttnn::logical_and(minT, maxT, std::nullopt, output_memory_config);
+    result = ttnn::multiply(grad, result, std::nullopt, output_memory_config);
+    grad_tensor.emplace_back(result);
+    return grad_tensor;
+}
+
 std::vector<Tensor> ExecuteUnaryBackwardClip::invoke(
     const Tensor& grad, const Tensor& input, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
     return ExecuteUnaryBackwardClamp::invoke(grad, input, min, max, output_mem_config);
 }
 
+std::vector<Tensor> ExecuteUnaryBackwardClip::invoke(
+    const Tensor& grad, const Tensor& input, std::optional<Tensor> min, std::optional<Tensor> max, const std::optional<MemoryConfig>& output_mem_config) {
+    return ExecuteUnaryBackwardClamp::invoke(grad, input, min, max, output_mem_config);
+}
+
 // Hardtanh
 // result: torch.where((input <= min) | (input >= max), 0.0, grad)
 std::vector<Tensor> ExecuteUnaryBackwardHardtanh::invoke(
```
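
For context on the gradient a clamp/clip backward op is expected to produce when the bounds are tensors, here is a small plain-PyTorch check (illustration only, not repository code): `torch.clamp` accepts tensor bounds, and autograd passes the gradient only where the input was not clamped.

```python
import torch

x = torch.tensor([-2.0, -0.25, 0.25, 2.0], requires_grad=True)
lo = torch.full_like(x, -0.5)  # tensor-valued min
hi = torch.full_like(x, 0.5)   # tensor-valued max

y = torch.clamp(x, lo, hi)
y.backward(torch.ones_like(x))  # upstream gradient of ones

print(x.grad)  # tensor([0., 1., 1., 0.]); gradient flows only where x stayed inside the bounds
```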
14 changes: 14 additions & 0 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp

```diff
@@ -458,6 +458,13 @@ struct ExecuteUnaryBackwardClamp {
         std::optional<float> min = std::nullopt,
         std::optional<float> max = std::nullopt,
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
+
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        std::optional<Tensor> min = std::nullopt,
+        std::optional<Tensor> max = std::nullopt,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 struct ExecuteUnaryBackwardClip {
@@ -467,6 +474,13 @@ struct ExecuteUnaryBackwardClip {
         std::optional<float> min = std::nullopt,
         std::optional<float> max = std::nullopt,
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
+
+    static std::vector<Tensor> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        std::optional<Tensor> min = std::nullopt,
+        std::optional<Tensor> max = std::nullopt,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 struct ExecuteUnaryBackwardRdiv {
```

Changes to the unary backward Python bindings:

```diff
@@ -700,6 +700,22 @@ void bind_unary_backward_optional_float_params_with_default(
             py::arg(parameter_name_a.c_str()) = parameter_a_value,
             py::arg(parameter_name_b.c_str()) = parameter_b_value,
             py::kw_only(),
+            py::arg("memory_config") = std::nullopt},
+
+        ttnn::pybind_overload_t{
+            [](const unary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               std::optional<Tensor> parameter_a,
+               std::optional<Tensor> parameter_b,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(grad_tensor, input_tensor, parameter_a, parameter_b, memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor"),
+            py::arg(parameter_name_a.c_str()) = parameter_a_value,
+            py::arg(parameter_name_b.c_str()) = parameter_b_value,
+            py::kw_only(),
             py::arg("memory_config") = std::nullopt});
 }
```
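
At the Python layer, the added overload means the same `clamp_bw`/`clip_bw` entry points also accept tensors (or `None`) for the two bound parameters, with `memory_config` still keyword-only. A hedged call sketch; the keyword names `min`/`max`, the device setup, and `ttnn.DRAM_MEMORY_CONFIG` are assumptions for illustration:

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)
shape = (1, 1, 32, 32)
tt = lambda t: ttnn.from_torch(t, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# min given as a tensor, max omitted (None), exercising the new overload via keywords.
out = ttnn.clamp_bw(
    tt(torch.ones(shape)),   # grad_tensor
    tt(torch.randn(shape)),  # input_tensor
    min=tt(torch.full(shape, -0.5)),
    max=None,
    memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

ttnn.close_device(device)
```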

