Skip to content

Commit

Permalink
#7543: Fix moreh_nll_loss test
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjin-na committed Jun 14, 2024
1 parent 4950081 commit d4321a8
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ Tensor moreh_nll_loss(
output_mem_config,
compute_kernel_config);

moreh_sum(step1_result, std::nullopt, true, divisor_tensor.value());
moreh_sum(step1_result, std::nullopt, false, divisor_tensor.value());

const Tensor& step2_result = moreh_nll_loss_step2(
input_tensor,
Expand All @@ -331,7 +331,7 @@ Tensor moreh_nll_loss(
reduction_mean,
output_mem_config,
compute_kernel_config);
return moreh_sum(step2_result, std::nullopt, true, output_tensor);
return moreh_sum(step2_result, std::nullopt, false, output_tensor);
} else {
const Tensor& step2_result = moreh_nll_loss_step2(
input_tensor,
Expand All @@ -343,7 +343,7 @@ Tensor moreh_nll_loss(
output_mem_config,
compute_kernel_config);

return moreh_sum(step2_result, std::nullopt, true, output_tensor);
return moreh_sum(step2_result, std::nullopt, false, output_tensor);
}
}

Expand Down
6 changes: 0 additions & 6 deletions tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,6 @@ Tensor moreh_sum(
std::vector<int64_t> dims = get_dim(dim, input.get_legacy_shape().rank());
std::sort(dims.begin(), dims.end());

std::cout << input.get_legacy_shape().rank() << " +_+1 \n";
if (output.has_value()) {
std::cout << output.value().get_legacy_shape().rank() << " +_+2 \n";

}

auto temp_input = input;
for (uint32_t i = dims.size() - 1; i > 0; i--) {
log_debug(LogOp, "{}:{} dim {} keep_batch_dim {}", __func__, __LINE__, dims[i], keep_batch_dim);
Expand Down

0 comments on commit d4321a8

Please sign in to comment.