Skip to content

Commit

Permalink
#7543: Fix matmul and linear backward test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjin-na committed Jun 14, 2024
1 parent d4321a8 commit a03ed07
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ inline void check_tensor(
check_tensor(tensor.value(), op_name, data_type, layout);
}

// Copies the dimensions of `shape` into `dim` in reversed order, so that
// dim[0] holds the innermost (last) dimension of `shape`. Entries of `dim`
// beyond `shape.rank()` are left untouched (callers pre-fill them with 1 to
// pad up to MAX_NUM_DIMENSIONS). Assumes dim.size() >= shape.rank().
inline void expand_to_max_dim(std::vector<uint32_t> &dim, const Shape& shape) {
    const auto rank = shape.rank();
    // Walk the shape from its last axis down to its first, filling `dim`
    // from the front: dim[rank - k] = shape[k - 1].
    for (auto k = rank; k > 0; --k) {
        dim[rank - k] = shape[k - 1];
    }
}

inline void validate_input_tensor_with_dim(const Tensor& input, const int64_t &dim) {
auto input_shape = input.get_legacy_shape();
auto input_shape_wo_padding = input.get_legacy_shape().without_padding();
Expand All @@ -62,21 +70,41 @@ inline void validate_output_tensor_with_keep_batch_dim(const Tensor& input, cons

const auto& output_shape = output.get_legacy_shape();
const auto& output_shape_wo_padding = output_shape.without_padding();
const auto output_rank = output_shape.rank();

const bool is_tile_dim = (dim == input_rank - 1 || dim == input_rank - 2);

log_debug(LogOp, "{}:{} input_shape {}", __func__, __LINE__, input_shape);
log_debug(LogOp, "{}:{} output_shape {}", __func__, __LINE__, output_shape);
log_debug(LogOp, "{}:{} input_shape_wo_padding {}", __func__, __LINE__, input_shape_wo_padding);
log_debug(LogOp, "{}:{} output_shape_wo_padding {}", __func__, __LINE__, output_shape_wo_padding);
log_debug(LogOp, "{}:{} keep_batch_dim {} dim {}", __func__, __LINE__, keep_batch_dim, dim);
log_debug(LogOp, "{}:{} input_shape {} wo_padding {}", __func__, __LINE__, input_shape, input_shape_wo_padding);
log_debug(LogOp, "{}:{} output_shape {} wo_paddoutg {}", __func__, __LINE__, output_shape, output_shape_wo_padding);

if (keep_batch_dim) {
bool ranks_are_equal = (input_rank == output_rank);
input_shape[dim] = (is_tile_dim) ? (TILE_HEIGHT) : (1);
input_shape_wo_padding[dim] = 1;

if (!ranks_are_equal) {
log_warning(
LogOp,
"{}:{} input_rank {} and output_rank {} are not the same in keep_batch_dim mode",
__func__,
__LINE__,
input_rank,
output_rank);
}

std::vector<uint32_t> input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
std::vector<uint32_t> output_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
std::vector<uint32_t> input_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
std::vector<uint32_t> output_dim_wo_padding(tt::tt_metal::MAX_NUM_DIMENSIONS, 1);
expand_to_max_dim(input_dim, input_shape);
expand_to_max_dim(output_dim, output_shape);
expand_to_max_dim(input_dim_wo_padding, input_shape_wo_padding);
expand_to_max_dim(output_dim_wo_padding, output_shape_wo_padding);

for (int i = 0; i < input_rank; ++i) {
TT_FATAL(input_shape[i] == output_shape[i]);
TT_FATAL(input_shape_wo_padding[i] == output_shape_wo_padding[i]);
TT_FATAL(input_dim[i] == output_dim[i]);
TT_FATAL(input_dim_wo_padding[i] == output_dim_wo_padding[i]);
}
} else {
std::vector<uint32_t> expected_output_shape;
Expand Down

0 comments on commit a03ed07

Please sign in to comment.