From 2e33bb6fac0fe4a31452fc09e67e1d35ad3ead82 Mon Sep 17 00:00:00 2001 From: andrej Date: Tue, 2 Apr 2024 15:23:21 -0700 Subject: [PATCH] matmul: fix error reporting --- .../ipu-xrt/matrix_multiplication/common.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/reference_designs/ipu-xrt/matrix_multiplication/common.h b/reference_designs/ipu-xrt/matrix_multiplication/common.h index 4a303dc267..6ad8dbb876 100644 --- a/reference_designs/ipu-xrt/matrix_multiplication/common.h +++ b/reference_designs/ipu-xrt/matrix_multiplication/common.h @@ -268,7 +268,7 @@ void print_matrix(const std::vector matrix, int n_cols, #undef print_row } -constexpr int max_printable_errors = 10; +constexpr int max_printable_errors = 32; template struct error { @@ -282,7 +282,7 @@ template std::optional> verify_single(std::ostream &os, int row, int col, Tout expected, Tout actual) { const float absTol = 0.5; - const float relTol = 0.5; + const float relTol = 0.1; if (!nearly_equal(expected, actual, relTol, absTol)) { return (struct error){row, col, expected, actual}; } @@ -298,7 +298,7 @@ void print_error_summary(std::ostream &os, int n_errors, << (float)err.actual << " =!= " << std::setw(4) << std::setprecision(2) << std::fixed << (float)err.expected << std::endl; } - if (n_errors >= max_printable_errors) { + if (n_errors > max_printable_errors) { os << "...and " << std::setw(0) << n_errors - max_printable_errors << " further errors." << std::endl; } @@ -325,10 +325,10 @@ int verify(int M, int N, int K, std::vector A, std::vector B, std::optional> error = verify_single( std::cout, row, col, CRef[row * N + col], C[row * N + col]); if (error.has_value()) { - n_errors++; if (n_errors < max_printable_errors) { errors.push_back(*error); } + n_errors++; } } } @@ -375,10 +375,10 @@ int verify_stochastic(int M, int N, int K, std::vector A, std::optional> error = verify_single(std::cout, row, col, ref, C[row * N + col]); if (error.has_value()) { - n_errors++; if (n_errors < max_printable_errors) { errors.push_back(*error); } + n_errors++; } } std::cout << std::endl;