From 07653e1c6daf5e553fd1cf161d8f1cd50c5aa121 Mon Sep 17 00:00:00 2001
From: michelletran-codecov
 <167130096+michelletran-codecov@users.noreply.github.com>
Date: Thu, 16 May 2024 10:05:15 -0400
Subject: [PATCH] Round the coverage numbers before saving to database (#447)

---
 database/models/reports.py        | 12 +++++++++--
 helpers/number.py                 | 19 ++++++++++++++++++
 helpers/tests/unit/test_number.py | 32 ++++++++++++++++++++++++++++++
 services/report/__init__.py       | 33 +++++++++++++++++++++++++++----
 services/yaml/reader.py           | 13 +++++-------
 5 files changed, 95 insertions(+), 14 deletions(-)
 create mode 100644 helpers/number.py
 create mode 100644 helpers/tests/unit/test_number.py

diff --git a/database/models/reports.py b/database/models/reports.py
index 5fa7a6037..809d4465a 100644
--- a/database/models/reports.py
+++ b/database/models/reports.py
@@ -1,5 +1,6 @@
 import logging
 import uuid
+from decimal import Decimal
 from functools import cached_property
 
 from shared.reports.types import ReportTotals, SessionTotalsArray
@@ -13,6 +14,7 @@
 from database.utils import ArchiveField
 from helpers.clock import get_utc_now
 from helpers.config import should_write_data_to_storage_config_check
+from helpers.number import precise_round
 
 log = logging.getLogger(__name__)
 
@@ -196,10 +198,16 @@ class AbstractTotals(MixinBaseClass):
     partials = Column(types.Integer)
     files = Column(types.Integer)
 
-    def update_from_totals(self, totals):
+    def update_from_totals(self, totals, precision=2, rounding="down"):
         self.branches = totals.branches
+        if totals.coverage is not None:
+            coverage: Decimal = Decimal(totals.coverage)
+            self.coverage = precise_round(
+                coverage, precision=precision, rounding=rounding
+            )
         # Temporary until the table starts accepting NULLs
-        self.coverage = totals.coverage if totals.coverage is not None else 0
+        else:
+            self.coverage = 0
         self.hits = totals.hits
         self.lines = totals.lines
         self.methods = totals.methods
diff --git a/helpers/number.py b/helpers/number.py
new file mode 100644
index 000000000..36ce841bc
--- /dev/null
+++ b/helpers/number.py
@@ -0,0 +1,19 @@
+from decimal import ROUND_CEILING, ROUND_FLOOR, ROUND_HALF_EVEN, Decimal
+
+
+def precise_round(
+    number: Decimal, precision: int = 2, rounding: str = "down"
+) -> Decimal:
+    """
+    Helper function to do more precise rounding given a precision and rounding strategy.
+    :param number: Number to round
+    :param precision: The number of decimal places to round to
+    :param rounding: Rounding strategy to use, which can be "down", "up" or "nearest"
+    :return: The rounded number as a Decimal object
+    """
+    quantizer = Decimal("0.1") ** precision
+    if rounding == "up":
+        return number.quantize(quantizer, rounding=ROUND_CEILING)
+    if rounding == "down":
+        return number.quantize(quantizer, rounding=ROUND_FLOOR)
+    return number.quantize(quantizer, rounding=ROUND_HALF_EVEN)
diff --git a/helpers/tests/unit/test_number.py b/helpers/tests/unit/test_number.py
new file mode 100644
index 000000000..91da08a83
--- /dev/null
+++ b/helpers/tests/unit/test_number.py
@@ -0,0 +1,32 @@
+from decimal import Decimal
+
+import pytest
+
+from helpers.number import precise_round
+
+
+@pytest.mark.parametrize(
+    "number,precision,rounding,expected_rounding",
+    [
+        (Decimal("1.129"), 2, "down", Decimal("1.12")),
+        (Decimal("1.121"), 2, "up", Decimal("1.13")),
+        (Decimal("1.125"), 1, "nearest", Decimal("1.1")),
+        (Decimal("1.18"), 1, "nearest", Decimal("1.2")),
+        (Decimal("1.15"), 1, "nearest", Decimal("1.2")),
+        (Decimal("1.25"), 1, "nearest", Decimal("1.2")),
+    ],
+    ids=[
+        "number rounds down",
+        "number rounds up",
+        "number rounds nearest (down)",
+        "number rounds nearest (up)",
+        "number rounds half-even (up)",
+        "number rounds half-even (down)",
+    ],
+)
+def test_precise_round(
+    number: Decimal, precision: int, rounding: str, expected_rounding: Decimal
+):
+    assert expected_rounding == precise_round(
+        number, precision=precision, rounding=rounding
+    )
diff --git a/services/report/__init__.py b/services/report/__init__.py
index 05c14373f..d974cacab 100644
--- a/services/report/__init__.py
+++ b/services/report/__init__.py
@@ -58,7 +58,7 @@
 from services.report.parser.types import ParsedRawReport
 from services.report.raw_upload_processor import process_raw_upload
 from services.repository import get_repo_provider_service
-from services.yaml.reader import get_paths_from_flags
+from services.yaml.reader import get_paths_from_flags, read_yaml_field
 
 
 @dataclass
@@ -1050,6 +1050,12 @@ def build_report_from_raw_content(
     def update_upload_with_processing_result(
         self, upload_obj: Upload, processing_result: ProcessingResult
     ):
+        rounding: str = read_yaml_field(
+            self.current_yaml, ("coverage", "round"), "nearest"
+        )
+        precision: int = read_yaml_field(
+            self.current_yaml, ("coverage", "precision"), 2
+        )
         db_session = upload_obj.get_db_session()
         session = processing_result.session
         if processing_result.error is None:
@@ -1078,7 +1084,9 @@ def update_upload_with_processing_result(
                 )
                 db_session.add(upload_totals)
             if session.totals is not None:
-                upload_totals.update_from_totals(session.totals)
+                upload_totals.update_from_totals(
+                    session.totals, precision=precision, rounding=rounding
+                )
         else:
             error = processing_result.error
             upload_obj.state = "error"
@@ -1092,6 +1100,12 @@ def update_upload_with_processing_result(
             db_session.flush()
 
     def save_report(self, commit: Commit, report: Report, report_code=None):
+        rounding: str = read_yaml_field(
+            self.current_yaml, ("coverage", "round"), "nearest"
+        )
+        precision: int = read_yaml_field(
+            self.current_yaml, ("coverage", "precision"), 2
+        )
         if len(report._chunks) > 2 * len(report._files) and len(report._files) > 0:
             report.repack()
         archive_service = self.get_archive_service(commit.repository)
@@ -1144,7 +1158,10 @@ def save_report(self, commit: Commit, report: Report, report_code=None):
             if report_totals is None:
                 report_totals = ReportLevelTotals(report_id=commit.report.id)
                 db_session.add(report_totals)
-            report_totals.update_from_totals(report.totals)
+
+            report_totals.update_from_totals(
+                report.totals, precision=precision, rounding=rounding
+            )
             db_session.flush()
         log.info(
             "Archived report",
@@ -1172,6 +1189,12 @@ def save_full_report(self, commit: Commit, report: Report, report_code=None):
         Returns:
             TYPE: Description
         """
+        rounding: str = read_yaml_field(
+            self.current_yaml, ("coverage", "round"), "nearest"
+        )
+        precision: int = read_yaml_field(
+            self.current_yaml, ("coverage", "precision"), 2
+        )
         res = self.save_report(commit, report, report_code)
         db_session = commit.get_db_session()
         for sess_id, session in report.sessions.items():
@@ -1200,7 +1223,9 @@ def save_full_report(self, commit: Commit, report: Report, report_code=None):
             if session.totals is not None:
                 upload_totals = UploadLevelTotals(upload_id=upload.id_)
                 db_session.add(upload_totals)
-                upload_totals.update_from_totals(session.totals)
+                upload_totals.update_from_totals(
+                    session.totals, precision=precision, rounding=rounding
+                )
         return res
 
     async def save_parallel_report_to_archive(
diff --git a/services/yaml/reader.py b/services/yaml/reader.py
index c456fd2e0..2c00784ae 100644
--- a/services/yaml/reader.py
+++ b/services/yaml/reader.py
@@ -1,10 +1,11 @@
 import logging
-from decimal import ROUND_CEILING, ROUND_FLOOR, ROUND_HALF_EVEN, Decimal
+from decimal import Decimal
 from typing import Any, List, Mapping
 
 from shared.yaml.user_yaml import UserYaml
 
 from helpers.components import Component
+from helpers.number import precise_round
 
 log = logging.getLogger(__name__)
 
@@ -32,14 +33,10 @@ def get_minimum_precision(yaml_dict: Mapping[str, Any]) -> Decimal:
     return Decimal("0.1") ** precision
 
 
-def round_number(yaml_dict: UserYaml, number: Decimal):
+def round_number(yaml_dict: UserYaml, number: Decimal) -> Decimal:
     rounding = read_yaml_field(yaml_dict, ("coverage", "round"), "nearest")
-    quantizer = get_minimum_precision(yaml_dict)
-    if rounding == "up":
-        return number.quantize(quantizer, rounding=ROUND_CEILING)
-    if rounding == "down":
-        return number.quantize(quantizer, rounding=ROUND_FLOOR)
-    return number.quantize(quantizer, rounding=ROUND_HALF_EVEN)
+    precision = read_yaml_field(yaml_dict, ("coverage", "precision"), 2)
+    return precise_round(number, precision=precision, rounding=rounding)
 
 
 def get_paths_from_flags(yaml_dict: UserYaml, flags):