From 07653e1c6daf5e553fd1cf161d8f1cd50c5aa121 Mon Sep 17 00:00:00 2001 From: michelletran-codecov <167130096+michelletran-codecov@users.noreply.github.com> Date: Thu, 16 May 2024 10:05:15 -0400 Subject: [PATCH] Round the coverage numbers before saving to database (#447) --- database/models/reports.py | 12 +++++++++-- helpers/number.py | 19 ++++++++++++++++++ helpers/tests/unit/test_number.py | 32 ++++++++++++++++++++++++++++++ services/report/__init__.py | 33 +++++++++++++++++++++++++++---- services/yaml/reader.py | 13 +++++------- 5 files changed, 95 insertions(+), 14 deletions(-) create mode 100644 helpers/number.py create mode 100644 helpers/tests/unit/test_number.py diff --git a/database/models/reports.py b/database/models/reports.py index 5fa7a6037..809d4465a 100644 --- a/database/models/reports.py +++ b/database/models/reports.py @@ -1,5 +1,6 @@ import logging import uuid +from decimal import Decimal from functools import cached_property from shared.reports.types import ReportTotals, SessionTotalsArray @@ -13,6 +14,7 @@ from database.utils import ArchiveField from helpers.clock import get_utc_now from helpers.config import should_write_data_to_storage_config_check +from helpers.number import precise_round log = logging.getLogger(__name__) @@ -196,10 +198,16 @@ class AbstractTotals(MixinBaseClass): partials = Column(types.Integer) files = Column(types.Integer) - def update_from_totals(self, totals): + def update_from_totals(self, totals, precision=2, rounding="down"): self.branches = totals.branches + if totals.coverage is not None: + coverage: Decimal = Decimal(totals.coverage) + self.coverage = precise_round( + coverage, precision=precision, rounding=rounding + ) # Temporary until the table starts accepting NULLs - self.coverage = totals.coverage if totals.coverage is not None else 0 + else: + self.coverage = 0 self.hits = totals.hits self.lines = totals.lines self.methods = totals.methods diff --git a/helpers/number.py b/helpers/number.py new file mode 100644 index 000000000..36ce841bc --- /dev/null +++ b/helpers/number.py @@ -0,0 +1,19 @@ +from decimal import ROUND_CEILING, ROUND_FLOOR, ROUND_HALF_EVEN, Decimal + + +def precise_round( + number: Decimal, precision: int = 2, rounding: str = "down" +) -> Decimal: + """ + Helper function to do more precise rounding given a precision and rounding strategy. + :param number: Number to round + :param precision: The number of decimal places to round to + :param rounding: Rounding strategy to use, which can be "down", "up" or "nearest" + :return: The rounded number as a Decimal object + """ + quantizer = Decimal("0.1") ** precision + if rounding == "up": + return number.quantize(quantizer, rounding=ROUND_CEILING) + if rounding == "down": + return number.quantize(quantizer, rounding=ROUND_FLOOR) + return number.quantize(quantizer, rounding=ROUND_HALF_EVEN) diff --git a/helpers/tests/unit/test_number.py b/helpers/tests/unit/test_number.py new file mode 100644 index 000000000..91da08a83 --- /dev/null +++ b/helpers/tests/unit/test_number.py @@ -0,0 +1,32 @@ +from decimal import Decimal + +import pytest + +from helpers.number import precise_round + + +@pytest.mark.parametrize( + "number,precision,rounding,expected_rounding", + [ + (Decimal("1.129"), 2, "down", Decimal("1.12")), + (Decimal("1.121"), 2, "up", Decimal("1.13")), + (Decimal("1.125"), 1, "nearest", Decimal("1.1")), + (Decimal("1.18"), 1, "nearest", Decimal("1.2")), + (Decimal("1.15"), 1, "nearest", Decimal("1.2")), + (Decimal("1.25"), 1, "nearest", Decimal("1.2")), + ], + ids=[ + "number rounds down", + "number rounds up", + "number rounds nearest (down)", + "number rounds nearest (up)", + "number rounds half-even (up)", + "number rounds half-even (down)", + ], +) +def test_precise_round( + number: Decimal, precision: int, rounding: str, expected_rounding: Decimal +): + assert expected_rounding == precise_round( + number, precision=precision, rounding=rounding + ) diff --git a/services/report/__init__.py b/services/report/__init__.py index 05c14373f..d974cacab 100644 --- a/services/report/__init__.py +++ b/services/report/__init__.py @@ -58,7 +58,7 @@ from services.report.parser.types import ParsedRawReport from services.report.raw_upload_processor import process_raw_upload from services.repository import get_repo_provider_service -from services.yaml.reader import get_paths_from_flags +from services.yaml.reader import get_paths_from_flags, read_yaml_field @dataclass @@ -1050,6 +1050,12 @@ def build_report_from_raw_content( def update_upload_with_processing_result( self, upload_obj: Upload, processing_result: ProcessingResult ): + rounding: str = read_yaml_field( + self.current_yaml, ("coverage", "round"), "nearest" + ) + precision: int = read_yaml_field( + self.current_yaml, ("coverage", "precision"), 2 + ) db_session = upload_obj.get_db_session() session = processing_result.session if processing_result.error is None: @@ -1078,7 +1084,9 @@ def update_upload_with_processing_result( ) db_session.add(upload_totals) if session.totals is not None: - upload_totals.update_from_totals(session.totals) + upload_totals.update_from_totals( + session.totals, precision=precision, rounding=rounding + ) else: error = processing_result.error upload_obj.state = "error" @@ -1092,6 +1100,12 @@ def update_upload_with_processing_result( db_session.flush() def save_report(self, commit: Commit, report: Report, report_code=None): + rounding: str = read_yaml_field( + self.current_yaml, ("coverage", "round"), "nearest" + ) + precision: int = read_yaml_field( + self.current_yaml, ("coverage", "precision"), 2 + ) if len(report._chunks) > 2 * len(report._files) and len(report._files) > 0: report.repack() archive_service = self.get_archive_service(commit.repository) @@ -1144,7 +1158,10 @@ def save_report(self, commit: Commit, report: Report, report_code=None): if report_totals is None: report_totals = ReportLevelTotals(report_id=commit.report.id) db_session.add(report_totals) - report_totals.update_from_totals(report.totals) + + report_totals.update_from_totals( + report.totals, precision=precision, rounding=rounding + ) db_session.flush() log.info( "Archived report", @@ -1172,6 +1189,12 @@ def save_full_report(self, commit: Commit, report: Report, report_code=None): Returns: TYPE: Description """ + rounding: str = read_yaml_field( + self.current_yaml, ("coverage", "round"), "nearest" + ) + precision: int = read_yaml_field( + self.current_yaml, ("coverage", "precision"), 2 + ) res = self.save_report(commit, report, report_code) db_session = commit.get_db_session() for sess_id, session in report.sessions.items(): @@ -1200,7 +1223,9 @@ def save_full_report(self, commit: Commit, report: Report, report_code=None): if session.totals is not None: upload_totals = UploadLevelTotals(upload_id=upload.id_) db_session.add(upload_totals) - upload_totals.update_from_totals(session.totals) + upload_totals.update_from_totals( + session.totals, precision=precision, rounding=rounding + ) return res async def save_parallel_report_to_archive( diff --git a/services/yaml/reader.py b/services/yaml/reader.py index c456fd2e0..2c00784ae 100644 --- a/services/yaml/reader.py +++ b/services/yaml/reader.py @@ -1,10 +1,11 @@ import logging -from decimal import ROUND_CEILING, ROUND_FLOOR, ROUND_HALF_EVEN, Decimal +from decimal import Decimal from typing import Any, List, Mapping from shared.yaml.user_yaml import UserYaml from helpers.components import Component +from helpers.number import precise_round log = logging.getLogger(__name__) @@ -32,14 +33,10 @@ def get_minimum_precision(yaml_dict: Mapping[str, Any]) -> Decimal: return Decimal("0.1") ** precision -def round_number(yaml_dict: UserYaml, number: Decimal): +def round_number(yaml_dict: UserYaml, number: Decimal) -> Decimal: rounding = read_yaml_field(yaml_dict, ("coverage", "round"), "nearest") - quantizer = get_minimum_precision(yaml_dict) - if rounding == "up": - return number.quantize(quantizer, rounding=ROUND_CEILING) - if rounding == "down": - return number.quantize(quantizer, rounding=ROUND_FLOOR) - return number.quantize(quantizer, rounding=ROUND_HALF_EVEN) + precision = read_yaml_field(yaml_dict, ("coverage", "precision"), 2) + return precise_round(number, precision=precision, rounding=rounding) def get_paths_from_flags(yaml_dict: UserYaml, flags):