#12146: Add optional output tensor to silu (#12209)
* #12146: Add optional output tensor to silu

* #12146: Remove are_required_outputs param

* #12137: remove if else condition in backward ops

* #12146: Use invoke function

* #12146: Use value_or function

* #12146: Use invoke function

* #0: Rebased

---------

Co-authored-by: Bharane AB <[email protected]>
Aswinmcw and bharane-ab authored Sep 10, 2024
1 parent 247621a commit 09f1bd1
Showing 9 changed files with 97 additions and 184 deletions.
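This commit adds an optional preallocated output tensor (input_grad) to ttnn.silu_bw, matching the pattern already used by exp_bw, pow_bw, sqrt_bw, and tanh_bw, and drops the now-redundant are_required_outputs parameter. A minimal usage sketch, mirroring the new test added below; data_gen_with_range and compare_pcc are the repository's backward-test helpers (their import path is not shown here), and the opened device is assumed:

import torch
import ttnn

# data_gen_with_range / compare_pcc come from the backward-test utilities
# (see the new test file below); `device` is an already-opened ttnn device.
shape = torch.Size([1, 1, 32, 32])
in_data, input_tensor = data_gen_with_range(shape, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(shape, -100, 100, device)

# Preallocate the gradient tensor; silu_bw writes into it in place
# instead of allocating a fresh output.
_, input_grad = data_gen_with_range(shape, -1, 1, device)
ttnn.silu_bw(grad_tensor, input_tensor, input_grad=input_grad)

# Check against the golden (torch) reference, exactly as the new test does.
golden_function = ttnn.get_golden_function(ttnn.silu_bw)
golden_tensor = golden_function(grad_data, in_data)
assert compare_pcc([input_grad], golden_tensor)

Callers that still want the op to allocate its own output simply omit input_grad; the rewritten invoke falls back to ttnn::zeros_like via value_or.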
@@ -48,7 +48,6 @@ def test_bw_exp_output(input_shapes, device):
tt_output_tensor_on_device = ttnn.exp_bw(
grad_tensor,
input_tensor,
are_required_outputs=[True],
input_grad=input_grad,
queue_id=cq_id,
)
@@ -157,7 +157,6 @@ def test_bw_unary_pow_output(input_shapes, exponent_and_pcc, device):
grad_tensor,
input_tensor,
exponent=exponent,
are_required_outputs=[True],
input_grad=input_grad,
queue_id=cq_id,
)
27 changes: 27 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_silu.py
@@ -28,3 +28,30 @@ def test_bw_silu(input_shapes, device):
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_silu_opt_tensor(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -1e4, 1e4, device)

_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)

pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.silu_bw(grad_tensor, input_tensor, input_grad=input_grad)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

golden_function = ttnn.get_golden_function(ttnn.silu_bw)
golden_tensor = golden_function(grad_data, in_data)

tt_output_tensor_on_device = [input_grad]
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
@@ -47,7 +47,6 @@ def test_bw_sqrt_output(input_shapes, device):
tt_output_tensor_on_device = ttnn.sqrt_bw(
grad_tensor,
input_tensor,
are_required_outputs=[True],
input_grad=input_grad,
queue_id=cq_id,
)
@@ -50,7 +50,6 @@ def test_bw_tanh_with_output(input_shapes, device):
tt_output_tensor_on_device = ttnn.tanh_bw(
grad_tensor,
input_tensor,
are_required_outputs=[True],
input_grad=input_grad,
queue_id=cq_id,
)
@@ -152,9 +152,9 @@ std::vector<Tensor> _rdiv_bw(

// unary_pow:
// grad_input = grad * exponent * torch.pow(input, exponent - 1)
std::vector<std::optional<Tensor>> _pow_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardPow::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, float exponent, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> grad_tensor;
TT_FATAL(are_required_outputs.at(0) , "input_grad derivative is required output");
input_grad = input_grad.value_or(ttnn::zeros_like(input));
const float ZERO_THRESHOLD = std::numeric_limits<float>::epsilon() * 10.0f;
TT_FATAL(exponent >= 0.0, "negative exponents are not supported; use recip(pow(input,abs(exponent)))");
if (std::abs(exponent) < ZERO_THRESHOLD) {
@@ -173,76 +173,53 @@ std::vector<std::optional<Tensor>> _pow_bw(uint8_t queue_id, const Tensor& grad,
Tensor final_result = ttnn::multiply(queue_id, result, grad, std::nullopt, output_mem_config);
result.deallocate();
Tensor temp = where(queue_id, ttnn::le(queue_id, final_result, -3.4e+38, std::nullopt, output_mem_config), -std::numeric_limits<float>::infinity(), final_result, output_mem_config);
if(input_grad.has_value()){
where(queue_id, ttnn::ge(queue_id, final_result, 3.4e+38, std::nullopt, output_mem_config), std::numeric_limits<float>::infinity(), temp, output_mem_config, input_grad);
} else {
input_grad = where(queue_id, ttnn::ge(queue_id, final_result, 3.4e+38, std::nullopt, output_mem_config), std::numeric_limits<float>::infinity(), temp, output_mem_config);
}
grad_tensor.emplace_back(input_grad);
return grad_tensor;
}
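For reference, the formula in the comment above is the standard power rule: for f(x) = x^e,

\frac{d}{dx}\,x^{e} = e\,x^{e-1} \quad\Rightarrow\quad \text{grad\_input} = \text{grad}\cdot e\cdot x^{e-1},

with the guards visible in the diff: negative exponents are rejected via TT_FATAL, exponents near zero are special-cased, and results that saturate past ±3.4e+38 are mapped to ±infinity.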

std::vector<std::optional<Tensor>> _exp_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardExp::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> grad_tensor;
TT_FATAL(are_required_outputs.at(0), "input_grad derivative is a required output");

input_grad = input_grad.value_or(ttnn::zeros_like(input));
float t_inf = std::numeric_limits<float>::infinity();
Tensor exp_result = ttnn::exp(queue_id, input, false, output_mem_config);
Tensor result = ttnn::multiply(queue_id, grad, exp_result, std::nullopt, output_mem_config);
result = where(queue_id, ttnn::ge(queue_id, result, 1e+38, std::nullopt, output_mem_config), t_inf, result, output_mem_config);
result = where(queue_id, ttnn::ge(queue_id, result, -1e+38, std::nullopt, output_mem_config), -t_inf, result, output_mem_config);
if(input_grad.has_value()){
where(queue_id,
ttnn::logical_and(
ttnn::ge(queue_id, ttnn::abs(queue_id, exp_result, output_mem_config), 1e+38, std::nullopt, output_mem_config),
ttnn::ltz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config, input_grad);
} else {
input_grad = where(queue_id,
ttnn::logical_and(
ttnn::ge(queue_id, ttnn::abs(queue_id, exp_result, output_mem_config), 1e+38, std::nullopt, output_mem_config),
ttnn::ltz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config);
}
where(queue_id,
ttnn::logical_and(
ttnn::ge(queue_id, ttnn::abs(queue_id, exp_result, output_mem_config), 1e+38, std::nullopt, output_mem_config),
ttnn::ltz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config), -t_inf, result, output_mem_config, input_grad);

grad_tensor.emplace_back(input_grad);
return grad_tensor;
}

std::vector<std::optional<Tensor>> _tanh_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardTanh::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> grad_tensor;
TT_FATAL(are_required_outputs.at(0), "input_grad derivative is required output");

input_grad = input_grad.value_or(ttnn::zeros_like(input));
Tensor tanh_res = ttnn::tanh(queue_id, input, output_mem_config);
tanh_res = ttnn::square(queue_id, tanh_res, output_mem_config);
tanh_res = ttnn::rsub(queue_id, tanh_res, 1.0f, output_mem_config);
if(input_grad.has_value()){
ttnn::multiply(queue_id, grad, tanh_res, std::nullopt, output_mem_config, input_grad);
} else {
input_grad = ttnn::multiply(queue_id, grad, tanh_res, std::nullopt, output_mem_config);
}
ttnn::multiply(queue_id, grad, tanh_res, std::nullopt, output_mem_config, input_grad);
grad_tensor.emplace_back(input_grad);
return grad_tensor;
}

std::vector<std::optional<Tensor>> _sqrt_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardSqrt::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> grad_tensor;
TT_FATAL(are_required_outputs.at(0), "input_grad derivative is required output");

float t_nan = std::nanf("");
float t_inf = std::numeric_limits<float>::infinity();

if(input_grad.has_value()){
input_grad = input_grad.value_or(ttnn::zeros_like(input));
ttnn::sqrt(queue_id, input, output_mem_config, input_grad);
ttnn::multiply(queue_id, grad, ttnn::reciprocal(queue_id, ttnn::multiply(queue_id, input_grad.value(), 2.0, std::nullopt, output_mem_config), output_mem_config),std::nullopt,output_mem_config, input_grad);
where(queue_id, ttnn::lez(queue_id, input, output_mem_config), t_nan, input_grad.value(), output_mem_config, input_grad);
where(queue_id,ttnn::logical_and(queue_id, ttnn::eqz(queue_id, input, output_mem_config), ttnn::ltz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config), -t_inf,input_grad.value(),output_mem_config,input_grad);
where(queue_id, ttnn::logical_and(queue_id, ttnn::eqz(queue_id, input, output_mem_config), ttnn::gtz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config), t_inf,input_grad.value(),output_mem_config,input_grad);
} else {
Tensor sqrt_result = ttnn::sqrt(queue_id, input, output_mem_config);
Tensor result = ttnn::multiply(queue_id, grad, ttnn::reciprocal(queue_id, ttnn::multiply(queue_id, sqrt_result, 2.0, std::nullopt, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
sqrt_result.deallocate();
input_grad = where(queue_id, ttnn::lez(queue_id, input, output_mem_config), t_nan, result, output_mem_config);
input_grad = where(queue_id, ttnn::logical_and(queue_id, ttnn::eqz(queue_id, input, output_mem_config), ttnn::ltz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config),-t_inf, input_grad.value(),output_mem_config);
input_grad = where(queue_id, ttnn::logical_and(queue_id, ttnn::eqz(queue_id, input, output_mem_config), ttnn::gtz(queue_id, grad, output_mem_config), std::nullopt, output_mem_config),t_inf, input_grad.value(), output_mem_config);
}
grad_tensor.emplace_back(input_grad);
return grad_tensor;
}
@@ -744,21 +721,23 @@ std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const std::

// Silu
// result: grad * sigmoid_result * (1 + input * (1 - sigmoid_result))
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_sigmoid = ttnn::multiply(grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
Tensor add_sub = ttnn::add(
ttnn::multiply(ttnn::subtract(ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config),
std::vector<std::optional<Tensor>> ExecuteUnaryBackwardSilu::invoke(uint8_t queue_id, const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config, std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> result = {std::nullopt};

input_grad = input_grad.value_or(ttnn::zeros_like(input));
Tensor grad_sigmoid = ttnn::multiply(queue_id, grad, ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config);
Tensor add_sub = ttnn::add(queue_id,
ttnn::multiply(queue_id, ttnn::subtract(queue_id, ttnn::full_like(input, 1.0f) , ttnn::sigmoid(input, output_mem_config), std::nullopt, output_mem_config),
input,
std::nullopt,
output_mem_config),
1.0f,
std::nullopt,
output_mem_config);
Tensor grad_result = ttnn::multiply(grad_sigmoid, add_sub, std::nullopt, output_mem_config);
ttnn::multiply(queue_id, grad_sigmoid, add_sub, std::nullopt, output_mem_config, input_grad);

grad_tensor.emplace_back(grad_result);
return grad_tensor;
result[0] = input_grad;
return result;
}
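For reference, the gradient in the comment above follows directly from differentiating silu(x) = x * sigmoid(x):

\frac{d}{dx}\bigl(x\,\sigma(x)\bigr) = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr) = \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr),

so grad_input = grad * sigmoid(x) * (1 + x * (1 - sigmoid(x))), which the rewritten invoke now writes directly into input_grad.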

// Selu
@@ -19,10 +19,6 @@ enum class UnaryBackwardOpType {
SOFTPLUS_BW,
DIV_BW,
RDIV_BW,
POW_BW,
TANH_BW,
EXP_BW,
SQRT_BW,
ASSIGN_BW,
MULTIGAMMALN_BW,
ADD_BW,
@@ -59,7 +55,6 @@ enum class UnaryBackwardOpType {
LOG_BW,
RELU6_BW,
ABS_BW,
SILU_BW,
SELU_BW,
SQUARE_BW,
HARDSWISH_BW,
@@ -129,7 +124,6 @@ std::vector<Tensor> _log1p_bw(const Tensor& grad, const Tensor& input, const std
std::vector<Tensor> _erfc_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _relu6_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _abs_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _silu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _selu_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);

@@ -174,12 +168,6 @@ std::vector<Tensor> _gelu_bw( const Tensor& grad, const Tensor& input, string ap

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::Shape& shape, const std::optional<MemoryConfig>& output_mem_config);

std::vector<std::optional<Tensor>> _pow_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config , const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad);

std::vector<std::optional<Tensor>> _exp_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad);
std::vector<std::optional<Tensor>> _tanh_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad);
std::vector<std::optional<Tensor>> _sqrt_bw(uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad);

std::vector<Tensor> _prod_bw( const Tensor& grad, const Tensor& input, bool all_dimensions = true, int64_t dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);

@@ -460,13 +448,6 @@ struct OpHandler<UnaryBackwardOpType::ABS_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::SILU_BW> {
static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
return _silu_bw(grad, input, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::SELU_BW> {
static std::vector<Tensor> handle(const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config) {
@@ -656,34 +637,6 @@ struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::POW_BW> {
static std::vector<std::optional<Tensor>> handle( uint8_t queue_id, const Tensor& grad, const Tensor& input, float exponent, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad ) {
return _pow_bw(queue_id, grad, input, exponent, output_mem_config, are_required_outputs, input_grad);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::EXP_BW> {
static std::vector<std::optional<Tensor>> handle( uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad ) {
return _exp_bw(queue_id, grad, input, output_mem_config, are_required_outputs, input_grad);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::TANH_BW> {
static std::vector<std::optional<Tensor>> handle( uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad ) {
return _tanh_bw(queue_id, grad, input, output_mem_config, are_required_outputs, input_grad);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::SQRT_BW> {
static std::vector<std::optional<Tensor>> handle( uint8_t queue_id, const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config, const std::vector<bool>& are_required_outputs, std::optional<Tensor> input_grad ) {
return _sqrt_bw(queue_id, grad, input, output_mem_config, are_required_outputs, input_grad);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::GELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, string approximate, const std::optional<MemoryConfig>& output_mem_config ) {
