Skip to content

Commit

Permalink
#12750: Replace zeros_like with empty_like in backward ops (#12934)
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW authored Sep 21, 2024
1 parent 4994d18 commit 7a2ca61
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
),
)
# Pytorch Reference
# - name: fill.Tensor(Tensor self, Tensor value) -> Tensor
# - name: fill.Scalar(Tensor self, Scalar value) -> Tensor
# self: zeros_like(grad)
# value: grad.sum()
# result: at::fill(self_t, value_t)
# result: at::fill(self_t, 0)
def test_bw_fill(input_shapes, device):
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1, 1, device)
in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device, True)
Expand All @@ -34,7 +33,7 @@ def test_bw_fill(input_shapes, device):
golden_function = ttnn.get_golden_function(ttnn.fill_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_all_close(tt_output_tensor_on_device, golden_tensor, atol=150, rtol=1e-6)
comp_pass = compare_all_close(tt_output_tensor_on_device, golden_tensor, atol=0, rtol=0)
assert comp_pass


Expand All @@ -61,5 +60,5 @@ def test_bw_fill_opt_tensor(input_shapes, device):
golden_tensor = golden_function(grad_data, in_data)

tt_output_tensor_on_device = [input_grad]
comp_pass = compare_all_close(tt_output_tensor_on_device, golden_tensor, atol=150, rtol=1e-6)
comp_pass = compare_all_close(tt_output_tensor_on_device, golden_tensor, atol=0, rtol=0)
assert comp_pass
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteBackwardMul::invoke(
uint8_t queue_id, const Tensor& grad, const Tensor& input, float scalar, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> result;
if(!input_grad.has_value()){
input_grad = ttnn::zeros_like(grad);
input_grad = ttnn::empty_like(grad);
}
ttnn::multiply(queue_id, grad, scalar, std::nullopt, output_mem_config, input_grad);
result.push_back(input_grad);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ std::vector<Tensor> _sigmoid_bw(
std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardRsqrt::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> result;
if(!input_grad.has_value()){
input_grad = ttnn::zeros_like(grad);
input_grad = ttnn::empty_like(grad);
}
float t_inf = std::numeric_limits<float>::infinity();
float t_nan = std::nanf("");
Expand All @@ -415,8 +415,8 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardRsqrt::invoke(uint8
return result;
}

std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardRsqrt::invoke(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
return ExecuteUnaryBackwardRsqrt::invoke(DefaultQueueId, grad, input, output_mem_config);
// Non-queue-id overload: forwards to the queue-aware implementation on the
// default queue, passing through the optional preallocated gradient tensor.
std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardRsqrt::invoke(
    const Tensor& grad,
    const Tensor& input,
    const std::optional<MemoryConfig>& output_mem_config,
    std::optional<Tensor> input_grad) {
    return ExecuteUnaryBackwardRsqrt::invoke(DefaultQueueId, grad, input, output_mem_config, input_grad);
}

std::vector<std::optional<Tensor>> ExecuteUnaryBackwardNeg::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
Expand All @@ -426,27 +426,32 @@ std::vector<std::optional<Tensor>> ExecuteUnaryBackwardNeg::invoke(uint8_t queue
return result;
}

// Non-queue-id overload: delegates to the queue-aware variant using the
// default queue, forwarding the optional preallocated gradient tensor.
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardNeg::invoke(
    const Tensor& grad,
    const Tensor& input,
    const std::optional<MemoryConfig>& output_mem_config,
    std::optional<Tensor> input_grad) {
    return ExecuteUnaryBackwardNeg::invoke(DefaultQueueId, grad, input, output_mem_config, input_grad);
}

std::vector<Tensor> _relu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::multiply(ttnn::gtz(input, output_mem_config), grad, std::nullopt, output_mem_config);
grad_tensor.emplace_back(result);
return grad_tensor;
}

// fill_bw — backward of fill.
// PyTorch reference derivative (tools/autograd/derivatives.yaml):
//   name: fill.Scalar(Tensor self, Scalar value) -> Tensor
//   self: zeros_like(grad)
//   result: at::fill(self_t, 0)
// i.e. fill overwrites `self`, so the gradient w.r.t. `self` is all zeros.
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardFill::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
// NOTE(review): output_memory_config is never read in this body — the add
// below passes output_mem_config directly. Confirm and drop if unused.
auto output_memory_config = output_mem_config.value_or(input.memory_config());
std::vector<std::optional<Tensor>> result = {std::nullopt};
// Materialize a zero-initialized gradient if the caller did not preallocate one.
input_grad = input_grad.value_or(ttnn::zeros_like(input));

// Sum the upstream gradient and add it to a zero tensor, presumably writing
// into input_grad as the preallocated output — TODO confirm against ttnn::add.
Tensor val = grad;
val = ttnn::sum(val);
Tensor result_val = ttnn::full_like(grad, 0.0f);
ttnn::add(queue_id, result_val, val, std::nullopt, output_mem_config, input_grad);

result[0] = input_grad;
// NOTE(review): result[0] is immediately overwritten here with zeros_like(grad),
// which matches the PyTorch derivative quoted above; the sum/add computation
// before this line therefore looks like dead work — confirm and remove if so.
result[0] = input_grad.has_value() ? ttnn::zeros_like(grad, std::nullopt, std::nullopt, std::nullopt, std::nullopt, input_grad) : ttnn::zeros_like(grad);
return result;
}

// Non-queue-id overload: routes to the queue-aware implementation on the
// default queue, forwarding the optional preallocated gradient tensor.
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardFill::invoke(
    const Tensor& grad,
    const Tensor& input,
    const std::optional<MemoryConfig>& output_mem_config,
    std::optional<Tensor> input_grad) {
    return ExecuteUnaryBackwardFill::invoke(DefaultQueueId, grad, input, output_mem_config, input_grad);
}

std::vector<Tensor> _hardsigmoid_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_a = ttnn::where(
Expand Down Expand Up @@ -1354,7 +1359,7 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> result;
if(!input_grad.has_value()){
input_grad = ttnn::zeros_like(grad);
input_grad = ttnn::empty_like(grad);
}

auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Expand Down Expand Up @@ -1404,8 +1409,9 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
const Tensor& grad,
const Tensor& input,
string approximate,
const std::optional<MemoryConfig>& output_mem_config) {
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config);
const std::optional<MemoryConfig>& output_mem_config,
std::optional<Tensor> input_grad) {
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config, input_grad);
}

std::vector<Tensor> _repeat_bw(
Expand Down
20 changes: 17 additions & 3 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ struct ExecuteUnaryBackwardNeg {
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);

static std::vector<std::optional<Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);
};

struct ExecuteUnaryBackwardThreshold {
Expand Down Expand Up @@ -110,7 +116,8 @@ struct ExecuteUnaryBackwardRsqrt {
static std::vector<std::optional<Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);
};

struct ExecuteUnaryBackwardClamp {
Expand Down Expand Up @@ -243,6 +250,12 @@ struct ExecuteUnaryBackwardFill {
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);

static std::vector<std::optional<Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);
};

struct ExecuteUnaryBackwardProd {
Expand Down Expand Up @@ -283,13 +296,14 @@ struct ExecuteUnaryBackwardAbs {

struct ExecuteUnaryBackwardGelu{
static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
string parameter_a,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);

static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
string parameter_a,
Expand Down
10 changes: 5 additions & 5 deletions ttnn/ttnn/operations/unary_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,14 +831,14 @@ def _golden_function(grad_tensor, input_tensor, sizes, *args, **kwargs):
# Bind the _golden_function defined immediately above (the repeat backward
# reference) to ttnn.repeat_bw so tests can retrieve it for comparison.
ttnn.attach_golden_function(ttnn.repeat_bw, golden_function=_golden_function)


def _golden_function(grad_tensor, input_tensor, *args, **kwargs):
def _golden_function(grad_tensor, input_tensor, *args, value=2.0, **kwargs):
import torch

pyt_y = torch.zeros_like(grad_tensor)
grad_sum = grad_tensor.sum()
pyt_y.fill_(grad_sum)
input_tensor.retain_grad()
pyt_y = torch.fill(input_tensor, value)
pyt_y.backward(gradient=grad_tensor)

return [pyt_y]
return [input_tensor.grad]


# Register the autograd-based reference above so tests can fetch it via
# ttnn.get_golden_function(ttnn.fill_bw) and compare against device output.
ttnn.attach_golden_function(ttnn.fill_bw, golden_function=_golden_function)
Expand Down

0 comments on commit 7a2ca61

Please sign in to comment.