#15642: Update files
mouliraj-mcw committed Dec 11, 2024
1 parent a813ef3 commit de2b60b
Showing 4 changed files with 10 additions and 11 deletions.
@@ -327,9 +327,8 @@ Tensor ExecutePrelu::invoke(

 Tensor ExecutePrelu::invoke(
     const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
-    const auto s_a = input_a.get_shape();
+    const auto s_a = input_a.get_logical_shape();
     const auto volume = input_b.get_logical_volume();
-
     TT_FATAL(
         s_a[1] == volume,
         "Mismatch of parameter numbers and input channel size. Found parameter numbers = {} and channel size = {}.",
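A note on the distinction driving this commit: a tile-laid-out tensor carries both a logical shape (the real element extent) and a padded shape (the extent rounded up to tile boundaries), and get_shape() is the older accessor this line migrates away from. Below is a minimal standalone sketch, assuming the usual 32x32 tile and hypothetical names (pad_to_tiles is not a TTNN API), of why the two shapes can disagree:

    // Standalone illustration, not TTNN code. Models how tile padding
    // rounds the last two dims up to the 32-element tile boundary.
    #include <array>
    #include <cstdint>
    #include <iostream>

    constexpr uint32_t TILE = 32;

    std::array<uint32_t, 4> pad_to_tiles(std::array<uint32_t, 4> shape) {
        shape[2] = (shape[2] + TILE - 1) / TILE * TILE;  // height
        shape[3] = (shape[3] + TILE - 1) / TILE * TILE;  // width
        return shape;
    }

    int main() {
        std::array<uint32_t, 4> logical{1, 3, 5, 5};  // NCHW, 3 channels
        auto padded = pad_to_tiles(logical);
        // The channel count a per-channel parameter check cares about is
        // logical[1] == 3; the padded H and W are 32, not 5.
        std::cout << "padded H=" << padded[2] << " W=" << padded[3] << "\n";
    }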
@@ -81,10 +81,10 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
     }
     if (height_b == 1) {
         if (tensor_args.input_tensor_a.is_sharded()) {
-            if (tensor_args.input_tensor_a.get_logical_shape()[0] ==
-                    tensor_args.input_tensor_b->get_logical_shape()[0] ||
-                tensor_args.input_tensor_a.get_logical_shape()[0] > 1 and
-                    tensor_args.input_tensor_b->get_logical_shape()[0] == 1) {
+            if (tensor_args.input_tensor_a.get_padded_shape()[0] ==
+                    tensor_args.input_tensor_b->get_padded_shape()[0] ||
+                tensor_args.input_tensor_a.get_padded_shape()[0] > 1 and
+                    tensor_args.input_tensor_b->get_padded_shape()[0] == 1) {
                 return BroadcastHeightMultiCoreShardedOptimized{};
             }
             return BroadcastHeightMultiCoreSharded{};
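This hunk only swaps which shape's dim 0 (the batch extent) feeds the sharded-broadcast dispatch; the predicate itself is unchanged. Since "and" is the C++ alternative token for &&, which binds tighter than ||, the condition groups as (a0 == b0) || (a0 > 1 && b0 == 1). A standalone sketch of that dispatch rule (the function name is hypothetical):

    // Standalone illustration of the dispatch predicate above.
    #include <cstdint>
    #include <iostream>

    bool use_optimized_sharded_bcast(uint32_t a_batch, uint32_t b_batch) {
        // Same grouping as the diff: equal batches, or a broadcasts over b.
        return a_batch == b_batch || (a_batch > 1 && b_batch == 1);
    }

    int main() {
        std::cout << use_optimized_sharded_bcast(8, 8) << "\n";  // 1
        std::cout << use_optimized_sharded_bcast(8, 1) << "\n";  // 1
        std::cout << use_optimized_sharded_bcast(1, 8) << "\n";  // 0: falls back
    }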
@@ -44,8 +44,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     const auto& b = tensor_args.input_tensor_b;
     auto& output = tensor_return_value;
     auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);
-    const auto ashape = a.get_logical_shape();
-    const auto bshape = b.has_value() ? b->get_logical_shape() : Shape{1, 1};
+    const auto ashape = a.get_padded_shape();
+    const auto bshape = b.has_value() ? b->get_padded_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];

@@ -298,8 +298,8 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a

     auto dst_buffer = output_tensor.buffer();

-    const auto ashape = input_tensor_a.get_logical_shape();
-    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_logical_shape() : Shape{1, 1};
+    const auto ashape = input_tensor_a.get_padded_shape();
+    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_padded_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
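Both hunks in this file guard the N and C reads on rank, with negative indices counting from the back of the shape. A standalone sketch of that pattern, using a plain vector in place of the TTNN shape type and a hypothetical dim_from_back helper:

    // Standalone illustration of rank-guarded NCHW extraction.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    uint32_t dim_from_back(const std::vector<uint32_t>& shape, int back) {
        // back = 1 is the last dim, 2 the second-to-last, and so on.
        return shape[shape.size() - back];
    }

    int main() {
        std::vector<uint32_t> ashape{3, 64, 128};  // rank 3: no batch dim
        uint32_t N = ashape.size() >= 4 ? dim_from_back(ashape, 4) : 1;
        uint32_t C = ashape.size() >= 3 ? dim_from_back(ashape, 3) : 1;
        uint32_t H = dim_from_back(ashape, 2);
        uint32_t W = dim_from_back(ashape, 1);
        std::cout << N << " " << C << " " << H << " " << W << "\n";  // 1 3 64 128
    }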
@@ -1874,7 +1874,7 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     // all_dimensions = False
     Tensor updated_grad = prod_result;
     auto step = ttnn::SmallVector<uint32_t>({1, 1, 1, 1});
-    if (prod_result.get_logical_shape() != grad.get_logical_shape()) {
+    if (prod_result.get_logical_shape() != grad.get_padded_shape()) {
         if (dim == 3 || dim == -1) {
             ttnn::SmallVector<int64_t> after_permute_dims = {0, 3, 1, 2};
             Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config);
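For context, the after_permute_dims = {0, 3, 1, 2} a few lines below the changed comparison moves the last axis into position 1. A standalone sketch (hypothetical permute helper, not ttnn::permute) of its effect on a 4-D shape:

    // Standalone illustration of applying a dim permutation to a shape.
    #include <array>
    #include <cstdint>
    #include <iostream>

    std::array<uint32_t, 4> permute(const std::array<uint32_t, 4>& s,
                                    const std::array<int, 4>& dims) {
        return {s[dims[0]], s[dims[1]], s[dims[2]], s[dims[3]]};
    }

    int main() {
        std::array<uint32_t, 4> grad_shape{2, 4, 8, 16};
        auto p = permute(grad_shape, {0, 3, 1, 2});  // -> {2, 16, 4, 8}
        std::cout << p[0] << " " << p[1] << " " << p[2] << " " << p[3] << "\n";
    }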
