#15642: Replace shapes in eltwise
mouliraj-mcw committed Dec 11, 2024
1 parent 8e49222 commit a813ef3
Showing 4 changed files with 12 additions and 12 deletions.
@@ -339,7 +339,7 @@ Tensor ExecutePrelu::invoke(
     if (s_a.rank() > 2) {
         SmallVector<uint32_t> reshape(s_a.rank(), 1);
         reshape[1] = s_a[1];
-        b = ttnn::reshape(input_b, ttnn::Shape(reshape));
+        b = ttnn::reshape(input_b, ttnn::SimpleShape(reshape));
     }

     Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
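Context for the hunk above: the PRelu weight input_b is reshaped to a rank-matching shape that is 1 in every dimension except the channel dimension, so it broadcasts against the activation input_a. A minimal standalone sketch of that shape construction (plain C++, not ttnn code; function and variable names are illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

// Build a broadcastable shape for a per-channel parameter: every dimension is 1
// except dim 1, which keeps the channel count of the activation shape.
std::vector<uint32_t> per_channel_broadcast_shape(const std::vector<uint32_t>& activation_shape) {
    std::vector<uint32_t> reshape(activation_shape.size(), 1);
    reshape[1] = activation_shape[1];
    return reshape;
}

int main() {
    std::vector<uint32_t> activation = {2, 16, 32, 32};  // N, C, H, W
    for (uint32_t d : per_channel_broadcast_shape(activation)) std::cout << d << ' ';
    std::cout << '\n';  // prints: 1 16 1 1
    return 0;
}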
@@ -81,10 +81,10 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
     }
     if (height_b == 1) {
         if (tensor_args.input_tensor_a.is_sharded()) {
-            if (tensor_args.input_tensor_a.get_padded_shape()[0] ==
-                    tensor_args.input_tensor_b->get_padded_shape()[0] ||
-                tensor_args.input_tensor_a.get_padded_shape()[0] > 1 and
-                    tensor_args.input_tensor_b->get_padded_shape()[0] == 1) {
+            if (tensor_args.input_tensor_a.get_logical_shape()[0] ==
+                    tensor_args.input_tensor_b->get_logical_shape()[0] ||
+                tensor_args.input_tensor_a.get_logical_shape()[0] > 1 and
+                    tensor_args.input_tensor_b->get_logical_shape()[0] == 1) {
                 return BroadcastHeightMultiCoreShardedOptimized{};
             }
             return BroadcastHeightMultiCoreSharded{};
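Context for the hunk above: the factory choice hinges on dimension 0 of the two input shapes. Since `and` is an alternative token for `&&` and binds tighter than `||`, the condition reads as a0 == b0 || (a0 > 1 && b0 == 1). A standalone restatement of that predicate (plain C++; the function name is illustrative only):

#include <cstdint>
#include <iostream>

// Selection predicate: take the optimized sharded broadcast when the leading dims
// match, or when A is batched while B has a single batch to broadcast.
bool use_optimized_sharded_bcast(uint32_t a0, uint32_t b0) {
    return a0 == b0 || (a0 > 1 && b0 == 1);
}

int main() {
    std::cout << use_optimized_sharded_bcast(8, 8) << ' '   // 1: equal leading dims
              << use_optimized_sharded_bcast(8, 1) << ' '   // 1: broadcast B over A
              << use_optimized_sharded_bcast(1, 8) << '\n'; // 0: generic sharded path
    return 0;
}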
@@ -44,8 +44,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     const auto& b = tensor_args.input_tensor_b;
     auto& output = tensor_return_value;
     auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);
-    const auto ashape = a.get_padded_shape();
-    const auto bshape = b.has_value() ? b->get_padded_shape() : Shape{1, 1};
+    const auto ashape = a.get_logical_shape();
+    const auto bshape = b.has_value() ? b->get_logical_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];

@@ -298,8 +298,8 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a

     auto dst_buffer = output_tensor.buffer();

-    const auto ashape = input_tensor_a.get_padded_shape();
-    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_padded_shape() : Shape{1, 1};
+    const auto ashape = input_tensor_a.get_logical_shape();
+    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_logical_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
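Context for the hunks above: N, C, H, and W are read from the back of the shape so that rank-2 and rank-3 tensors still get defaults of 1 for the missing leading dims; the diff's ashape[-4] and ashape[-3] rely on indexing from the innermost dimension. A standalone sketch of that convention (plain C++ over a vector, illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

// Index a shape from the back: -1 is the innermost dim, -2 the next, and so on,
// mirroring the rank-guarded ashape[-4] / ashape[-3] reads in the diff.
uint32_t dim_from_back(const std::vector<uint32_t>& shape, int back_index) {
    int idx = static_cast<int>(shape.size()) + back_index;  // back_index is negative
    return shape[idx];
}

int main() {
    std::vector<uint32_t> ashape = {2, 3, 64, 128};  // rank-4 example
    uint32_t N = ashape.size() >= 4 ? dim_from_back(ashape, -4) : 1;
    uint32_t C = ashape.size() >= 3 ? dim_from_back(ashape, -3) : 1;
    uint32_t H = dim_from_back(ashape, -2);
    uint32_t W = dim_from_back(ashape, -1);
    std::cout << N << ' ' << C << ' ' << H << ' ' << W << '\n';  // 2 3 64 128
    return 0;
}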
@@ -1794,7 +1794,7 @@ std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
         ttnn::SmallVector<int64_t> dim = {0};
         TT_FATAL(shape[1] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[1], [2], [3] should be 1");
         std::array<std::uint32_t, 4> intended_shape_array = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
-        const ttnn::Shape required = ttnn::Shape(intended_shape_array);
+        const auto required = ttnn::SimpleShape(intended_shape_array);
         Tensor result = ttnn::moreh_sum(
             grad,
             dim,

@@ -1813,7 +1813,7 @@ std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
         ttnn::SmallVector<int64_t> dim = {1};
         TT_FATAL(shape[0] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[0], [2], [3] should be 1");
         std::array<std::uint32_t, 4> intended_shape_array = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
-        const ttnn::Shape required = ttnn::Shape(intended_shape_array);
+        const auto required = ttnn::SimpleShape(intended_shape_array);
         Tensor result = ttnn::moreh_sum(
             grad,
             dim,

@@ -1874,7 +1874,7 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     // all_dimensions = False
     Tensor updated_grad = prod_result;
     auto step = ttnn::SmallVector<uint32_t>({1, 1, 1, 1});
-    if (prod_result.get_logical_shape() != grad.get_padded_shape()) {
+    if (prod_result.get_logical_shape() != grad.get_logical_shape()) {
         if (dim == 3 || dim == -1) {
             ttnn::SmallVector<int64_t> after_permute_dims = {0, 3, 1, 2};
             Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config);
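Context for the first two hunks above: when the forward op repeats a tensor along a dimension, the backward pass sums the incoming gradient over that dimension back to the original shape, which is what the moreh_sum call with dim = {0} (or {1}) and the reconstructed target shape expresses. A small standalone sketch of that reduction (plain C++, illustrative only):

#include <cstddef>
#include <iostream>
#include <vector>

// Sum a gradient of shape [repeats, n] over dim 0, producing shape [1, n].
// This is the backward of repeating a [1, n] tensor `repeats` times along dim 0.
std::vector<float> sum_over_dim0(const std::vector<std::vector<float>>& grad) {
    std::vector<float> reduced(grad.front().size(), 0.0f);
    for (const auto& row : grad)
        for (std::size_t i = 0; i < row.size(); ++i)
            reduced[i] += row[i];
    return reduced;
}

int main() {
    // Gradient flowing into a tensor that was repeated 3 times along dim 0.
    std::vector<std::vector<float>> grad = {{1, 2}, {3, 4}, {5, 6}};
    for (float v : sum_over_dim0(grad)) std::cout << v << ' ';  // prints: 9 12
    std::cout << '\n';
    return 0;
}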
