#9302: Add additional composite ops
eyonland committed Jun 12, 2024
1 parent 2a4a22c commit 6012f8a
Showing 13 changed files with 140 additions and 80 deletions.
3 changes: 3 additions & 0 deletions tests/ttnn/unit_tests/test_model_preprocessing.py
@@ -326,7 +326,10 @@ def forward(self, x):
output_tensor = ttnn.to_device(output_tensor, device=device)
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
output_tensor = ttnn.relu(output_tensor)
print(f"Output shape is {output_tensor.shape}")
output_tensor = ttnn.permute(output_tensor, (0, 3, 1, 2))
print(f"Expected shape after permute is ttnn.Shape([1, 128, 28, 28[32]]")
print(f"Actual shape is {output_tensor.shape}")
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.get_fallback_function(ttnn.reshape)(output_tensor, (-1, num_output_channels))
output_tensor = ttnn.to_layout(output_tensor, ttnn.TILE_LAYOUT)
2 changes: 2 additions & 0 deletions tt_eager/tensor/tensor.cpp
@@ -813,6 +813,8 @@ const Shape Tensor::strides() const { return detail::compute_strides(this->get_l

uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); }

uint32_t Tensor::intended_volume() const { return tt::tt_metal::compute_volume(this->get_shape()); }

Tensor create_device_tensor(
const Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config) {
ZoneScoped;
1 change: 1 addition & 0 deletions tt_eager/tensor/tensor.hpp
@@ -311,6 +311,7 @@ struct Tensor {
StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;
uint32_t intended_volume() const;

bool is_allocated() const;

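Why the new accessor matters: volume() counts elements of the tile-padded legacy shape, while intended_volume() counts the logical shape the user asked for. A standalone sketch of the distinction, using plain C++ stand-ins rather than the tt-metal Shape/Tensor API:

```cpp
// Standalone illustration (not tt-metal API): why volume() and
// intended_volume() disagree for a TILE-layout tensor. round_up and the
// 32x32 tile size mirror the constants used elsewhere in this commit.
#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t TILE = 32;
uint32_t round_up(uint32_t v, uint32_t m) { return ((v + m - 1) / m) * m; }

uint32_t volume(const std::vector<uint32_t>& shape) {
    uint32_t v = 1;
    for (auto d : shape) v *= d;
    return v;
}

int main() {
    std::vector<uint32_t> logical = {1, 1, 1, 1};  // scalar the user created
    std::vector<uint32_t> padded = {1, 1, round_up(1, TILE), round_up(1, TILE)};
    std::cout << volume(padded) << "\n";   // 1024 -- what legacy volume() sees
    std::cout << volume(logical) << "\n";  // 1    -- what intended_volume() sees
    // _mac in composite_ops.cpp below switches to intended_volume() so a
    // tile-padded scalar is still recognized as a scalar.
}
```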
8 changes: 4 additions & 4 deletions tt_eager/tt_dnn/op_library/auto_format.hpp
@@ -34,10 +34,10 @@ class AutoFormat {


static Shape pad_to_tile_shape(const Shape& unpadded_shape, bool pad_c=false, bool pad_n=false, bool pad_h=true, bool pad_w=true) {
auto n = pad_n ? round_up(unpadded_shape[0], TILE_HEIGHT) : unpadded_shape[0];
auto c = pad_c ? round_up(unpadded_shape[1], TILE_WIDTH) : unpadded_shape[1];
auto h = pad_h ? round_up(unpadded_shape[2], TILE_HEIGHT) : unpadded_shape[2];
auto w = pad_w ? round_up(unpadded_shape[3], TILE_WIDTH) : unpadded_shape[3];
auto n = pad_n ? round_up(unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1, TILE_HEIGHT) : unpadded_shape.rank() >= 4 ? unpadded_shape[-4] : 1;
auto c = pad_c ? round_up(unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1, TILE_WIDTH) : unpadded_shape.rank() >= 3 ? unpadded_shape[-3] : 1;
auto h = pad_h ? round_up(unpadded_shape[-2], TILE_HEIGHT) : unpadded_shape[-2];
auto w = pad_w ? round_up(unpadded_shape[-1], TILE_WIDTH) : unpadded_shape[-1];
Shape padded_shape = {n, c, h, w};
return padded_shape;
}
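The rank checks above keep sub-4D shapes from indexing out of range: missing leading dims are treated as 1. A standalone sketch of the same logic under the default flags (pad_h and pad_w true, pad_n and pad_c false); round_up and the tile constants are stand-ins for the tt-metal ones:

```cpp
// Standalone sketch of pad_to_tile_shape's rank handling: missing leading
// dims default to 1, and H/W round up to the 32x32 tile grid.
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;
uint32_t round_up(uint32_t v, uint32_t m) { return ((v + m - 1) / m) * m; }

std::array<uint32_t, 4> pad_to_tile_shape(const std::vector<uint32_t>& s) {
    auto rank = s.size();
    uint32_t n = rank >= 4 ? s[rank - 4] : 1;  // shape[-4], or 1 if absent
    uint32_t c = rank >= 3 ? s[rank - 3] : 1;  // shape[-3], or 1 if absent
    uint32_t h = round_up(s[rank - 2], TILE_HEIGHT);
    uint32_t w = round_up(s[rank - 1], TILE_WIDTH);
    return {n, c, h, w};
}

int main() {
    auto p = pad_to_tile_shape({28, 28});  // rank-2 input no longer underflows
    for (auto d : p) std::cout << d << ' ';  // prints: 1 1 32 32
}
```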
16 changes: 16 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -2202,6 +2202,22 @@ std::vector<Tensor> _prod_bw(
std::vector<int64_t> after_permute_dims = {3, 1, 2, 0};
Tensor tensor_1 = permute(tensor_1_temp, after_permute_dims, output_mem_config);
Tensor tensor_2 = permute(temp, after_permute_dims, output_mem_config);

// put the tensor back on device because permute throws it off device
// See: Remove auto format within permute_op.cpp #9404

const auto intended_shape = tensor_2.get_shape();
const auto device_shape = tensor_2.get_legacy_shape();
const auto new_intended_shape = std::array<std::uint32_t, 4>{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]};
const auto new_device_shape = std::array<std::uint32_t, 4>{
device_shape[0],
device_shape[1],
(device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]),
(device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])
};
const auto new_shape = tt_metal::Shape(new_intended_shape, new_device_shape);
tensor_2 = AutoFormat::format_input_tensor(tensor_2, tensor_1.device(), new_shape, 0.0, tensor_1.get_layout());

Tensor result = permute(
bcast(tensor_1, tensor_2, BcastOpMath::MUL, BcastOpDim::W, output_mem_config),
after_permute_dims,
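The two ternaries in new_device_shape are a round-up to the tile grid. A standalone restatement (round_up is a stand-in; for the dimensions involved the behaviour is identical to the ternary form):

```cpp
// Standalone restatement of the new_device_shape computation above:
// (d % T != 0 ? (d / T + 1) * T : d) is exactly a round-up to the tile grid.
#include <array>
#include <cassert>
#include <cstdint>

constexpr uint32_t TILE_HEIGHT = 32, TILE_WIDTH = 32;
uint32_t round_up(uint32_t v, uint32_t m) { return ((v + m - 1) / m) * m; }

std::array<uint32_t, 4> tile_padded_device_shape(const std::array<uint32_t, 4>& d) {
    return {d[0], d[1], round_up(d[2], TILE_HEIGHT), round_up(d[3], TILE_WIDTH)};
}

int main() {
    auto s = tile_padded_device_shape({3, 1, 30, 20});
    assert(s[2] == 32 && s[3] == 32);  // 30 -> 32, 20 -> 32; N and C untouched
}
```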
4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -73,7 +73,7 @@ inline Tensor bcast(
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
operation::launch_with_autoformat(
operation::launch_op(
[bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
using tt::constants::TILE_HEIGHT;
using tt::constants::TILE_WIDTH;
@@ -109,7 +109,7 @@ inline Tensor bcast(
input_tensor_b.get_legacy_shape()[-1] == TILE_WIDTH));
}
}
return operation::run_with_autoformat(
return operation::run(
EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b});
}, {input_tensor_a, input_tensor_b}, output_tensors);
return output_tensors.at(0);
@@ -22,8 +22,8 @@ namespace tt_metal {
operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b, const Tensor& output, BcastOpMath bcast_math/*, BcastOpDim bcast_dim*/){
const auto ashape = a.get_legacy_shape();
const auto bshape = b.get_legacy_shape();
uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3];
uint32_t bN = bshape[0], bC = bshape[1], bH = bshape[2], bW = bshape[3];
uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1, C = ashape.rank() >= 3 ? ashape[-3] : 1, H = ashape[-2], W = ashape[-1];
uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1, bC = bshape.rank() >= 3 ? bshape[-3] : 1, bH = bshape[-2], bW = bshape[-1];
uint32_t NC = N*C;


84 changes: 53 additions & 31 deletions tt_eager/tt_dnn/op_library/composite/composite_ops.cpp
@@ -19,24 +19,27 @@
#include "tt_eager/tt_dnn/op_library/pad/pad_op.hpp"
#include "tt_eager/tt_dnn/op_library/unpad/unpad_op.hpp"
#include "tt_numpy/functions.hpp"
#include "ttnn/cpp/ttnn/operations/creation.hpp"

namespace tt {

namespace tt_metal {

Tensor mk_zero_tensor_like(const Tensor& reference_tensor, const MemoryConfig& output_mem_config) {
// Tensor zero_like = bcast(reference_tensor, , BcastOpMath::MUL, BcastOpDim::HW);
Tensor zero = mk_tiled_scalar(0.0f, reference_tensor.get_dtype());
Tensor zero = ttnn::operations::creation::create_scalar(0.0f, reference_tensor.get_dtype(), Layout::TILE, reference_tensor.device());
Tensor zero_like = bcast(reference_tensor, zero, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
zero.deallocate();
return zero_like;
}

// TODO: enable zeroes(), ones() and eye() type functions on-device using this type of logic
template <typename T>
Tensor mk_filled_tensor_like(const Tensor& reference_tensor, T val, const MemoryConfig& output_mem_config) {
Tensor k = mk_tiled_scalar(val, reference_tensor.get_dtype());
Tensor k = ttnn::operations::creation::create_scalar(val, reference_tensor.get_dtype(), Layout::TILE, reference_tensor.device());
Tensor zero_like = mk_zero_tensor_like(reference_tensor, output_mem_config);
Tensor result = bcast(zero_like, k, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
k.deallocate();
zero_like.deallocate();
return result;
}
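Each call site this commit touches now follows the same create/use/deallocate pattern for scalar tensors. A hypothetical RAII guard, not part of ttnn, could make the deallocate automatic; a minimal sketch under that assumption:

```cpp
// Hypothetical helper, not part of ttnn: frees a scalar tensor when it
// leaves scope, so call sites cannot forget the deallocate() call.
#include <utility>

template <typename TensorT>
struct ScopedScalar {
    TensorT t;
    explicit ScopedScalar(TensorT tensor) : t(std::move(tensor)) {}
    ~ScopedScalar() { t.deallocate(); }
    ScopedScalar(const ScopedScalar&) = delete;
    ScopedScalar& operator=(const ScopedScalar&) = delete;
};
```

mk_zero_tensor_like above could then hold `ScopedScalar zero{ttnn::operations::creation::create_scalar(...)}` and pass `zero.t` to bcast, with no explicit deallocate in the body.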

@@ -255,19 +258,19 @@ Tensor mish(const Tensor& a, const MemoryConfig& output_mem_config) {
Tensor _selu(const Tensor& x, const float scale, const float alpha, const MemoryConfig& output_mem_config) {
// term 2
Tensor x_Exp = exp(x, output_mem_config);
Tensor minus_one = mk_tiled_scalar(-1.0f);
Tensor minus_one = ttnn::operations::creation::create_scalar(-1.0f, x.get_dtype(), Layout::TILE, x.device());
Tensor x_Exp_minus_1 = bcast(x_Exp, minus_one, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
x_Exp.deallocate();
minus_one.deallocate();
Tensor t_alpha = mk_tiled_scalar(alpha);
Tensor t_alpha = ttnn::operations::creation::create_scalar(alpha, x.get_dtype(), Layout::TILE, x.device());
Tensor result_t2_ = bcast(x_Exp_minus_1, t_alpha, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
x_Exp_minus_1.deallocate();
t_alpha.deallocate();
Tensor result_term2 = mul(gtz(result_t2_, output_mem_config), result_t2_, std::nullopt, output_mem_config);
result_t2_.deallocate();

// term 1
Tensor t_scale = mk_tiled_scalar(scale);
Tensor t_scale = ttnn::operations::creation::create_scalar(scale, x.get_dtype(), Layout::TILE, x.device());
Tensor x_relu = relu(x, output_mem_config);
Tensor result_term1 = bcast(x_relu, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
t_scale.deallocate();
@@ -288,7 +291,10 @@ Tensor selu(const Tensor& x, const float scale, const float alpha, const MemoryC
Tensor rpow(const Tensor& a, float k, const MemoryConfig& output_mem_config) {
TT_ASSERT(k > 0.0, "rpow cannot be calculated for non-positive numbers");
float log_k = logf(k);
Tensor result = bcast(a, mk_tiled_scalar(log_k), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);

Tensor scalar = ttnn::operations::creation::create_scalar(log_k, a.get_dtype(), Layout::TILE, a.device());
Tensor result = bcast(a, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
scalar.deallocate();
return exp(result, output_mem_config);
}

@@ -349,13 +355,20 @@ Tensor _polyval(const Tensor& input_tensor, std::vector<float> coeffs, const Mem
return mk_filled_tensor_like(input_tensor, coeffs[0], output_mem_config);
}

Tensor scalar = ttnn::operations::creation::create_scalar(coeffs[0], input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
Tensor result =
bcast(input_tensor, mk_tiled_scalar(coeffs[0]), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
bcast(input_tensor, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
scalar.deallocate();
for (int idx = 1; idx < coeffs.size() - 1; idx++) {
result = bcast(result, mk_tiled_scalar(coeffs[idx]), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
Tensor scalar = ttnn::operations::creation::create_scalar(coeffs[idx], input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
result = bcast(result, scalar, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
scalar.deallocate();
result = mul(input_tensor, result, std::nullopt, output_mem_config);
}
return bcast(result, mk_tiled_scalar(coeffs.back()), BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
Tensor last_coeffs = ttnn::operations::creation::create_scalar(coeffs.back(), input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
Tensor final_tensor = bcast(result, last_coeffs, BcastOpMath::ADD, BcastOpDim::HW, output_mem_config);
last_coeffs.deallocate();
return final_tensor;
}
Tensor polyval(const Tensor& input_tensor, std::vector<float> coeffs, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _polyval)(input_tensor, coeffs, output_mem_config);
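_polyval evaluates the coefficients with Horner's rule on tensors. A standalone scalar check of the same recurrence (the single-coefficient case is handled earlier via mk_filled_tensor_like):

```cpp
// Scalar version of the Horner recurrence in _polyval:
// result = c0*x; then result = (result + c_i)*x for each middle
// coefficient; finally result + c_last.
#include <cassert>
#include <vector>

float polyval(const std::vector<float>& coeffs, float x) {
    float result = coeffs[0] * x;
    for (size_t i = 1; i + 1 < coeffs.size(); ++i) {
        result = (result + coeffs[i]) * x;
    }
    return result + coeffs.back();
}

int main() {
    // 2x^2 + 3x + 4 at x = 5: 2*25 + 15 + 4 = 69
    assert(polyval({2.0f, 3.0f, 4.0f}, 5.0f) == 69.0f);
}
```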
Expand All @@ -365,9 +378,9 @@ Tensor polyval(const Tensor& input_tensor, std::vector<float> coeffs, const Memo
// compute multiply-accumulate: y = a * b + c, over various 8 combinations of a, b, c
// being a scalar or tensor
Tensor _mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig& output_mem_config) {
bool a_is_scalar = a.volume() == 1;
bool b_is_scalar = b.volume() == 1;
bool c_is_scalar = c.volume() == 1;
bool a_is_scalar = a.intended_volume() == 1;
bool b_is_scalar = b.intended_volume() == 1;
bool c_is_scalar = c.intended_volume() == 1;

const auto dim = BcastOpDim::HW;
if (!a_is_scalar && !b_is_scalar && !c_is_scalar) {
@@ -405,9 +418,12 @@ Tensor mac(const Tensor& a, const Tensor& b, const Tensor& c, const MemoryConfig
}

Tensor _mac_overload(const Tensor& a, float b, float c, const MemoryConfig& output_mem_config) {
Tensor t_b = mk_scalar(b);
Tensor t_c = mk_scalar(c);
return mac(a, t_b, t_c, output_mem_config);
Tensor t_b = ttnn::operations::creation::create_scalar(b, a.get_dtype(), Layout::TILE, a.device());
Tensor t_c = ttnn::operations::creation::create_scalar(c, a.get_dtype(), Layout::TILE, a.device());
Tensor return_tensor = mac(a, t_b, t_c, output_mem_config);
t_b.deallocate();
t_c.deallocate();
return return_tensor;
}
Tensor mac(const Tensor& input_a, float b, float c, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _mac_overload)(input_a, b, c, output_mem_config);
@@ -451,7 +467,9 @@ Tensor _sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
Tensor nr_term = sub(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
e_pos_x.deallocate();
e_neg_x.deallocate();
return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor result = bcast(nr_term, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
scalar.deallocate();
return result;
}
Tensor sinh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _sinh)(input_a, output_mem_config);
Expand All @@ -464,7 +482,9 @@ Tensor _cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
Tensor nr_term = add(e_pos_x, e_neg_x, std::nullopt, output_mem_config);
e_pos_x.deallocate();
e_neg_x.deallocate();
return bcast(nr_term, mk_tiled_scalar(0.5f), BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
Tensor scalar = ttnn::operations::creation::create_scalar(0.5f, input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor result = bcast(nr_term, scalar, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
scalar.deallocate();
return result;
}
Tensor cosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
return operation::decorate_as_composite(__func__, _cosh)(input_a, output_mem_config);
@@ -509,12 +529,14 @@ Tensor _acosh(const Tensor& input_a, const MemoryConfig& output_mem_config) {
// To handle inputs <= 1
// input < 1, output is nan
// input > 1, output is acosh(input)
Tensor scalar = ttnn::operations::creation::create_scalar(std::nanf(""), input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor nan_res = bcast(
lte(input_a, t_one, std::nullopt, output_mem_config),
mk_tiled_scalar(std::nanf("")),
scalar,
BcastOpMath::MUL,
BcastOpDim::HW,
output_mem_config);
scalar.deallocate();
t_result = mul(gt(input_a, t_one, std::nullopt, output_mem_config), ln_res, std::nullopt, output_mem_config);
t_result = add(nan_res, t_result, std::nullopt, output_mem_config);
}
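For reference, the masking above implements the real-domain cases stated in the comment:

```latex
\operatorname{acosh}(x) =
\begin{cases}
\mathrm{NaN}, & x < 1 \\
\ln\!\bigl(x + \sqrt{x^{2} - 1}\bigr), & x > 1
\end{cases}
```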
@@ -553,7 +575,7 @@ Tensor atanh(const Tensor& input_a, const MemoryConfig& output_mem_config) {

// lerp(input, end, weight) = start + weight * (end - start)
Tensor _lerp(const Tensor& input_a, const Tensor& input_b, float value, const MemoryConfig& output_mem_config) {
Tensor t_value = mk_tiled_scalar(value);
Tensor t_value = ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor t_diff = sub(input_b, input_a, std::nullopt, output_mem_config);
Tensor t_mul = bcast(t_diff, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
Tensor result = add(input_a, t_mul, std::nullopt, output_mem_config);
@@ -826,7 +848,7 @@ Tensor _addcmul(
const Tensor& input_c,
float value,
const MemoryConfig& output_mem_config) {
Tensor t_value = mk_tiled_scalar(value);
Tensor t_value = ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor t_mul = mul(input_b, input_c, std::nullopt, output_mem_config);
Tensor t_factor = bcast(t_mul, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
t_mul.deallocate();
@@ -850,7 +872,7 @@ Tensor _addcdiv(
const Tensor& input_c,
float value,
const MemoryConfig& output_mem_config) {
Tensor t_value = mk_tiled_scalar(value);
Tensor t_value = ttnn::operations::creation::create_scalar(value, input_a.get_dtype(), Layout::TILE, input_a.device());
Tensor t_div = mul(input_b, recip(input_c, output_mem_config), std::nullopt, output_mem_config);
Tensor t_factor = bcast(t_div, t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
t_div.deallocate();
@@ -1294,14 +1316,14 @@ Tensor scatter(const Tensor& input_a, const Tensor& input_b, const MemoryConfig&
}

// threshold(a,t,v) = (a <= t)*v + (a > t)*a
Tensor _threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) {
Tensor t_threshold = mk_tiled_scalar(threshold, input_a.get_dtype());
Tensor t0 = bcast(input_a, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config);
Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const MemoryConfig& output_mem_config) {
Tensor t_threshold = ttnn::operations::creation::create_scalar(threshold, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
Tensor t0 = bcast(input_tensor, t_threshold, BcastOpMath::SUB, BcastOpDim::HW, output_mem_config);
t_threshold.deallocate();
Tensor t_value = mk_tiled_scalar(value, input_a.get_dtype());
Tensor t_value = ttnn::operations::creation::create_scalar(value, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
Tensor t1 = bcast(lez(t0), t_value, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
t_value.deallocate();
Tensor t2 = mul(gtz(t0, output_mem_config), input_a, std::nullopt, output_mem_config);
Tensor t2 = mul(gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config);
return add(t1, t2, std::nullopt, output_mem_config);
}
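A standalone scalar check of the identity in the comment above _threshold:

```cpp
// threshold(a, t, v) = (a <= t)*v + (a > t)*a, the same masking the
// tensor code builds from lez/gtz.
#include <cassert>

float threshold(float a, float t, float v) {
    return (a <= t) * v + (a > t) * a;
}

int main() {
    assert(threshold(0.5f, 1.0f, -7.0f) == -7.0f);  // a <= t: take v
    assert(threshold(2.0f, 1.0f, -7.0f) == 2.0f);   // a > t: keep a
}
```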
Tensor threshold(const Tensor& input_a, float threshold, float value, const MemoryConfig& output_mem_config) {
@@ -1352,16 +1374,16 @@ Tensor digamma(const Tensor& input_a, const MemoryConfig& output_mem_config) {

// cbrt(a) = pow(a,1/3) or (cbrt(a))**3 = a.
// = exp[ (1/3)*log[a] ]
Tensor _cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) {
Tensor _cbrt(const Tensor& input_tensor, const MemoryConfig& output_mem_config) {
constexpr float scale = (float)(1.0 / 3.0);
Tensor t_scale = mk_tiled_scalar(scale);
Tensor t_ln_input = log(abs(input_a, output_mem_config), output_mem_config); // negative log is not useful here
Tensor t_scale = ttnn::operations::creation::create_scalar(scale, input_tensor.get_dtype(), Layout::TILE, input_tensor.device());
Tensor t_ln_input = log(abs(input_tensor, output_mem_config), output_mem_config); // negative log is not useful here
Tensor t1 = bcast(t_ln_input, t_scale, BcastOpMath::MUL, BcastOpDim::HW, output_mem_config);
t_scale.deallocate();
t_ln_input.deallocate();
Tensor t2 = exp(t1, output_mem_config);
t1.deallocate();
Tensor t3 = mul(t2, sign(input_a, output_mem_config), std::nullopt, output_mem_config);
Tensor t3 = mul(t2, sign(input_tensor, output_mem_config), std::nullopt, output_mem_config);
return t3;
}
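A standalone check of the abs/sign trick in _cbrt: the log is taken on |a| and the sign is multiplied back at the end, which recovers the cube root for negative inputs too:

```cpp
// cbrt(a) = sign(a) * exp(log(|a|) / 3), verified against std::cbrt.
#include <cassert>
#include <cmath>

int main() {
    for (float a : {-27.0f, -0.125f, 0.125f, 27.0f}) {
        float sign = (a > 0.0f) - (a < 0.0f);
        float via_log = sign * std::exp(std::log(std::fabs(a)) / 3.0f);
        assert(std::fabs(via_log - std::cbrt(a)) < 1e-4f);
    }
}
```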
Tensor cbrt(const Tensor& input_a, const MemoryConfig& output_mem_config) {