diff --git a/tasks/tests/unit/test_upload_task.py b/tasks/tests/unit/test_upload_task.py
index ce05cdf85..fed9dbdcd 100644
--- a/tasks/tests/unit/test_upload_task.py
+++ b/tasks/tests/unit/test_upload_task.py
@@ -723,6 +723,7 @@ def test_upload_task_no_bot(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
         assert not mocked_fetch_yaml.called
 
@@ -777,6 +778,7 @@ def test_upload_task_bot_no_permissions(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
         assert not mocked_fetch_yaml.called
 
@@ -852,6 +854,7 @@ def test_upload_task_bot_unauthorized(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
 
     def test_upload_task_upload_already_created(
@@ -943,6 +946,7 @@ def fail_if_try_to_create_upload(*args, **kwargs):
             report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
 
 
@@ -1056,11 +1060,7 @@ def test_schedule_task_with_no_tasks(self, dbsession):
         dbsession.add(commit)
         dbsession.flush()
         result = UploadTask().schedule_task(
-            commit,
-            commit_yaml,
-            argument_list,
-            ReportFactory.create(),
-            None,
+            commit, commit_yaml, argument_list, ReportFactory.create(), None, dbsession
         )
         assert result is None
 
@@ -1074,11 +1074,7 @@ def test_schedule_task_with_one_task(self, dbsession, mocker):
         dbsession.add(commit)
         dbsession.flush()
         result = UploadTask().schedule_task(
-            commit,
-            commit_yaml,
-            argument_list,
-            ReportFactory.create(),
-            None,
+            commit, commit_yaml, argument_list, ReportFactory.create(), None, dbsession
         )
         assert result == mocked_chain.return_value.apply_async.return_value
         t1 = upload_processor_task.signature(
diff --git a/tasks/upload.py b/tasks/upload.py
index 71a27b29e..616627772 100644
--- a/tasks/upload.py
+++ b/tasks/upload.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 import re
 import uuid
@@ -16,6 +17,7 @@
     TorngitObjectNotFoundError,
     TorngitRepoNotFoundError,
 )
+from shared.utils.sessions import SessionType
 from shared.validation.exceptions import InvalidYamlException
 from shared.yaml import UserYaml
 
@@ -23,6 +25,7 @@
 from database.enums import CommitErrorTypes, ReportType
 from database.models import Commit, CommitReport
 from database.models.core import GITHUB_APP_INSTALLATION_DEFAULT_NAME
+from database.models.reports import Upload
 from helpers.checkpoint_logger import _kwargs_key
 from helpers.checkpoint_logger import from_kwargs as checkpoints_from_kwargs
 from helpers.checkpoint_logger.flows import UploadFlow
@@ -513,6 +516,7 @@ def run_impl_within_lock(
                     argument_list,
                     commit_report,
                     upload_context,
+                    db_session,
                     checkpoints,
                 )
             else:
@@ -580,6 +584,7 @@ def schedule_task(
         argument_list,
         commit_report: CommitReport,
         upload_context: UploadContext,
+        db_session,
         checkpoints=None,
     ):
         commit_yaml = commit_yaml.to_dict()
@@ -595,6 +600,7 @@ def schedule_task(
                 argument_list,
                 commit_report,
                 upload_context,
+                db_session,
                 checkpoints=checkpoints,
             )
         elif commit_report.report_type == ReportType.BUNDLE_ANALYSIS.value:
@@ -628,6 +634,7 @@ def _schedule_coverage_processing_task(
         argument_list,
         commit_report,
         upload_context: UploadContext,
+        db_session,
         checkpoints=None,
     ):
         chunk_size = CHUNK_SIZE
@@ -662,33 +669,90 @@ def _schedule_coverage_processing_task(
             redis_key = get_parallel_upload_processing_session_counter_redis_key(
                 repoid=commit.repository.repoid, commitid=commit.commitid
             )
-
+            report_service = ReportService(commit_yaml)
+            sessions = report_service.build_sessions(commit=commit)
+            commit_yaml = UserYaml(commit_yaml)
             # if session count expired due to TTL (which is unlikely for most cases), recalculate the
             # session ids used and set it in redis.
             if self.parallel_session_count_key_expired(
                 redis_key, upload_context.redis_connection
             ):
-                report_service = ReportService(commit_yaml)
-                sessions = report_service.build_sessions(commit=commit)
                 upload_context.redis_connection.set(
                     redis_key,
                     max(sessions.keys()) + 1 if sessions.keys() else 0,
                 )
 
-            # increment redis to claim session ids
-            parallel_session_id = (
-                upload_context.redis_connection.incrby(
-                    name=redis_key,
-                    amount=num_sessions,
-                )
-                - num_sessions
-            )
-            upload_context.redis_connection.expire(
-                name=redis_key,
-                time=PARALLEL_UPLOAD_PROCESSING_SESSION_COUNTER_TTL,
-            )
+            # try to scrap the redis counter idea to fully mimic how session ids are allocated in the
+            # serial flow. This change is technically less performant, and would not allow for concurrent
+            # chords to be running at the same time. For now this is just a temporary change, just for
+            # verifying correctness.
+            #
+            # # increment redis to claim session ids
+            # parallel_session_id = (
+            #     upload_context.redis_connection.incrby(
+            #         name=redis_key,
+            #         amount=num_sessions,
+            #     )
+            #     - num_sessions
+            # )
+            # upload_context.redis_connection.expire(
+            #     name=redis_key,
+            #     time=PARALLEL_UPLOAD_PROCESSING_SESSION_COUNTER_TTL,
+            # )
+
+            # copied from shared/reports/resources.py Report.next_session_number()
+            def next_session_number(session_dict):
+                start_number = len(session_dict)
+                while start_number in session_dict or str(start_number) in session_dict:
+                    start_number += 1
+                return start_number
+
+            # copied and cut down from worker/services/report/raw_upload_processor.py
+            # this version stripped out all the ATS label stuff
+            def _adjust_sessions(
+                original_sessions: dict,
+                to_merge_flags,
+                current_yaml,
+            ):
+                session_ids_to_fully_delete = []
+                flags_under_carryforward_rules = [
+                    f for f in to_merge_flags if current_yaml.flag_has_carryfoward(f)
+                ]
+                if flags_under_carryforward_rules:
+                    for sess_id, curr_sess in original_sessions.items():
+                        if curr_sess.session_type == SessionType.carriedforward:
+                            if curr_sess.flags:
+                                if any(
+                                    f in flags_under_carryforward_rules
+                                    for f in curr_sess.flags
+                                ):
+                                    session_ids_to_fully_delete.append(sess_id)
+                if session_ids_to_fully_delete:
+                    # delete sessions from dict
+                    for id in session_ids_to_fully_delete:
+                        original_sessions.pop(id, None)
+                return
+
+            mock_sessions = copy.deepcopy(sessions)
+            session_ids_for_parallel_idx = []
+
+            # iterate over all uploads, get the next session id, and adjust sessions (remove CFF logic)
+            for i in range(num_sessions):
+                next_session_id = next_session_number(mock_sessions)
+
+                upload_pk = argument_list[i]["upload_pk"]
+                upload = db_session.query(Upload).filter_by(id_=upload_pk).first()
+                to_merge_session = report_service.build_session(upload)
+                flags = upload.flag_names
+
+                mock_sessions[next_session_id] = to_merge_session
+                _adjust_sessions(mock_sessions, flags, commit_yaml)
+
+                session_ids_for_parallel_idx.append(next_session_id)
 
             parallel_processing_tasks = []
+            commit_yaml = commit_yaml.to_dict()
+
             for i in range(0, num_sessions, parallel_chunk_size):
                 chunk = argument_list[i : i + parallel_chunk_size]
                 if chunk:
@@ -700,7 +764,9 @@ def _schedule_coverage_processing_task(
                             commit_yaml=commit_yaml,
                             arguments_list=chunk,
                             report_code=commit_report.code,
-                            parallel_idx=i + parallel_session_id,
+                            parallel_idx=session_ids_for_parallel_idx[
+                                i
+                            ],  # i + parallel_session_id,
                             in_parallel=True,
                         ),
                     )
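Why the redis counter had to go: when an incoming upload replaces carried-forward sessions, those sessions are deleted from the report and their ids become reusable, so the serial flow can hand out ids that a monotonically increasing counter never would. The sketch below is not part of the patch; it is a minimal, standalone illustration of that divergence, assuming simplified stand-ins (FakeSession, adjust_sessions, and a plain carryforward_flags set replace the real Session objects, the raw_upload_processor logic, and the UserYaml.flag_has_carryfoward lookups).

# Standalone sketch, not part of the patch. FakeSession and adjust_sessions
# are simplified stand-ins for the real shared Session class and the
# carryforward handling in raw_upload_processor.
from dataclasses import dataclass, field


@dataclass
class FakeSession:
    session_type: str  # "uploaded" or "carriedforward"
    flags: list = field(default_factory=list)


def next_session_number(session_dict):
    # same gap-filling rule as Report.next_session_number()
    start_number = len(session_dict)
    while start_number in session_dict or str(start_number) in session_dict:
        start_number += 1
    return start_number


def adjust_sessions(sessions, incoming_flags, carryforward_flags):
    # drop carried-forward sessions whose flags are being re-uploaded
    replaced = [f for f in incoming_flags if f in carryforward_flags]
    to_delete = [
        sid
        for sid, sess in sessions.items()
        if sess.session_type == "carriedforward"
        and any(f in replaced for f in sess.flags)
    ]
    for sid in to_delete:
        sessions.pop(sid, None)


# existing report: one real upload plus two carried-forward sessions
sessions = {
    0: FakeSession("uploaded", ["integration"]),
    1: FakeSession("carriedforward", ["unit"]),
    2: FakeSession("carriedforward", ["smoke"]),
}
carryforward_flags = {"unit", "smoke"}

assigned = []
for flags in (["unit", "smoke"], ["integration"]):  # two new uploads in this batch
    sid = next_session_number(sessions)
    sessions[sid] = FakeSession("uploaded", flags)
    adjust_sessions(sessions, flags, carryforward_flags)
    assigned.append(sid)

print(assigned)  # [3, 2] -- the removed redis counter, seeded at max(id) + 1, would hand out [3, 4]

Under these assumptions the first upload evicts both carried-forward sessions, so the second upload reuses the freed id 2 instead of advancing to 4; that is the serial behaviour the chunked parallel tasks now receive through session_ids_for_parallel_idx.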