fix: use serial session id allocation for parallel experiment (#378)
* use serial session id allocation

* add adjust_session to session id emulation

* fix tests

* remove redundant var
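
For context on the first two bullets: the serial flow hands out session ids by scanning the report's existing session dict for the lowest free slot (the `next_session_number` helper this commit copies from shared), while the parallel experiment previously reserved a contiguous block of ids with a Redis INCRBY. Below is a minimal, self-contained Python sketch of the serial allocation being emulated; `allocate_serially` and the placeholder values are illustrative, not code from this repository.

# Sketch of serial session id allocation, assuming the semantics of
# shared's Report.next_session_number() as quoted in the diff below.
def next_session_number(session_dict):
    # lowest unused id, starting the scan at len(session_dict);
    # string keys can appear after JSON round-trips, hence both checks
    start_number = len(session_dict)
    while start_number in session_dict or str(start_number) in session_dict:
        start_number += 1
    return start_number

def allocate_serially(existing_sessions, num_uploads):
    # hypothetical driver: claim one id per upload, inserting a placeholder
    # so the next scan sees the id as taken (mirrors the mock_sessions loop)
    sessions = dict(existing_sessions)
    claimed = []
    for _ in range(num_uploads):
        sid = next_session_number(sessions)
        sessions[sid] = "placeholder"
        claimed.append(sid)
    return claimed

print(allocate_serially({0: "cf", 1: "cf", 2: "cf"}, 3))  # [3, 4, 5]
print(allocate_serially({0: "cf", 2: "cf"}, 2))  # [3, 4] -- scan starts at len(), so id 1 is not reused

Unlike the Redis counter, which hands later uploads ids from a block reserved up front, this scan reacts to sessions being added or evicted between allocations, which is exactly what the serial flow does.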
daniel-codecov authored Apr 11, 2024
1 parent c7d4f45 commit 7d9c198
Showing 2 changed files with 88 additions and 26 deletions.
16 changes: 6 additions & 10 deletions tasks/tests/unit/test_upload_task.py
@@ -723,6 +723,7 @@ def test_upload_task_no_bot(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
         assert not mocked_fetch_yaml.called

@@ -777,6 +778,7 @@ def test_upload_task_bot_no_permissions(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )
         assert not mocked_fetch_yaml.called

@@ -852,6 +854,7 @@ def test_upload_task_bot_unauthorized(
             commit.report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )

     def test_upload_task_upload_already_created(
@@ -943,6 +946,7 @@ def fail_if_try_to_create_upload(*args, **kwargs):
             report,
             mocker.ANY,
             mocker.ANY,
+            mocker.ANY,
         )


@@ -1056,11 +1060,7 @@ def test_schedule_task_with_no_tasks(self, dbsession):
         dbsession.add(commit)
         dbsession.flush()
         result = UploadTask().schedule_task(
-            commit,
-            commit_yaml,
-            argument_list,
-            ReportFactory.create(),
-            None,
+            commit, commit_yaml, argument_list, ReportFactory.create(), None, dbsession
         )
         assert result is None

@@ -1074,11 +1074,7 @@ def test_schedule_task_with_one_task(self, dbsession, mocker):
         dbsession.add(commit)
         dbsession.flush()
         result = UploadTask().schedule_task(
-            commit,
-            commit_yaml,
-            argument_list,
-            ReportFactory.create(),
-            None,
+            commit, commit_yaml, argument_list, ReportFactory.create(), None, dbsession
         )
         assert result == mocked_chain.return_value.apply_async.return_value
         t1 = upload_processor_task.signature(
98 changes: 82 additions & 16 deletions tasks/upload.py
@@ -1,3 +1,4 @@
+import copy
 import logging
 import re
 import uuid
@@ -16,13 +17,15 @@
     TorngitObjectNotFoundError,
     TorngitRepoNotFoundError,
 )
+from shared.utils.sessions import SessionType
 from shared.validation.exceptions import InvalidYamlException
 from shared.yaml import UserYaml

 from app import celery_app
 from database.enums import CommitErrorTypes, ReportType
 from database.models import Commit, CommitReport
 from database.models.core import GITHUB_APP_INSTALLATION_DEFAULT_NAME
+from database.models.reports import Upload
 from helpers.checkpoint_logger import _kwargs_key
 from helpers.checkpoint_logger import from_kwargs as checkpoints_from_kwargs
 from helpers.checkpoint_logger.flows import UploadFlow
@@ -513,6 +516,7 @@ def run_impl_within_lock(
                 argument_list,
                 commit_report,
                 upload_context,
+                db_session,
                 checkpoints,
             )
         else:
@@ -580,6 +584,7 @@ def schedule_task(
         argument_list,
         commit_report: CommitReport,
         upload_context: UploadContext,
+        db_session,
         checkpoints=None,
     ):
         commit_yaml = commit_yaml.to_dict()
@@ -595,6 +600,7 @@
                 argument_list,
                 commit_report,
                 upload_context,
+                db_session,
                 checkpoints=checkpoints,
             )
         elif commit_report.report_type == ReportType.BUNDLE_ANALYSIS.value:
@@ -628,6 +634,7 @@ def _schedule_coverage_processing_task(
         argument_list,
         commit_report,
         upload_context: UploadContext,
+        db_session,
         checkpoints=None,
     ):
         chunk_size = CHUNK_SIZE
@@ -662,33 +669,90 @@ def _schedule_coverage_processing_task(
         redis_key = get_parallel_upload_processing_session_counter_redis_key(
             repoid=commit.repository.repoid, commitid=commit.commitid
         )

+        report_service = ReportService(commit_yaml)
+        sessions = report_service.build_sessions(commit=commit)
+        commit_yaml = UserYaml(commit_yaml)
         # if session count expired due to TTL (which is unlikely for most cases), recalculate the
         # session ids used and set it in redis.
         if self.parallel_session_count_key_expired(
             redis_key, upload_context.redis_connection
         ):
-            report_service = ReportService(commit_yaml)
-            sessions = report_service.build_sessions(commit=commit)
             upload_context.redis_connection.set(
                 redis_key,
                 max(sessions.keys()) + 1 if sessions.keys() else 0,
             )
-
-        # increment redis to claim session ids
-        parallel_session_id = (
-            upload_context.redis_connection.incrby(
-                name=redis_key,
-                amount=num_sessions,
-            )
-            - num_sessions
-        )
-        upload_context.redis_connection.expire(
-            name=redis_key,
-            time=PARALLEL_UPLOAD_PROCESSING_SESSION_COUNTER_TTL,
-        )
+        # try to scrap the redis counter idea to fully mimic how session ids are allocated in the
+        # serial flow. This change is technically less performant, and would not allow for concurrent
+        # chords to be running at the same time. For now this is just a temporary change, just for
+        # verifying correctness.
+        #
+        # # increment redis to claim session ids
+        # parallel_session_id = (
+        #     upload_context.redis_connection.incrby(
+        #         name=redis_key,
+        #         amount=num_sessions,
+        #     )
+        #     - num_sessions
+        # )
+        # upload_context.redis_connection.expire(
+        #     name=redis_key,
+        #     time=PARALLEL_UPLOAD_PROCESSING_SESSION_COUNTER_TTL,
+        # )
+
+        # copied from shared/reports/resources.py Report.next_session_number()
+        def next_session_number(session_dict):
+            start_number = len(session_dict)
+            while start_number in session_dict or str(start_number) in session_dict:
+                start_number += 1
+            return start_number
+
+        # copied and cut down from worker/services/report/raw_upload_processor.py
+        # this version stripped out all the ATS label stuff
+        def _adjust_sessions(
+            original_sessions: dict,
+            to_merge_flags,
+            current_yaml,
+        ):
+            session_ids_to_fully_delete = []
+            flags_under_carryforward_rules = [
+                f for f in to_merge_flags if current_yaml.flag_has_carryfoward(f)
+            ]
+            if flags_under_carryforward_rules:
+                for sess_id, curr_sess in original_sessions.items():
+                    if curr_sess.session_type == SessionType.carriedforward:
+                        if curr_sess.flags:
+                            if any(
+                                f in flags_under_carryforward_rules
+                                for f in curr_sess.flags
+                            ):
+                                session_ids_to_fully_delete.append(sess_id)
+            if session_ids_to_fully_delete:
+                # delete sessions from dict
+                for id in session_ids_to_fully_delete:
+                    original_sessions.pop(id, None)
+            return
+
+        mock_sessions = copy.deepcopy(sessions)
+        session_ids_for_parallel_idx = []
+
+        # iterate over all uploads, get the next session id, and adjust sessions (remove CFF logic)
+        for i in range(num_sessions):
+            next_session_id = next_session_number(mock_sessions)
+
+            upload_pk = argument_list[i]["upload_pk"]
+            upload = db_session.query(Upload).filter_by(id_=upload_pk).first()
+            to_merge_session = report_service.build_session(upload)
+            flags = upload.flag_names
+
+            mock_sessions[next_session_id] = to_merge_session
+            _adjust_sessions(mock_sessions, flags, commit_yaml)
+
+            session_ids_for_parallel_idx.append(next_session_id)
+
         parallel_processing_tasks = []
+        commit_yaml = commit_yaml.to_dict()

         for i in range(0, num_sessions, parallel_chunk_size):
             chunk = argument_list[i : i + parallel_chunk_size]
             if chunk:
@@ -700,7 +764,9 @@ def _schedule_coverage_processing_task(
                     commit_yaml=commit_yaml,
                     arguments_list=chunk,
                     report_code=commit_report.code,
-                    parallel_idx=i + parallel_session_id,
+                    parallel_idx=session_ids_for_parallel_idx[
+                        i
+                    ],  # i + parallel_session_id,
                     in_parallel=True,
                 ),
             )
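One subtlety worth spelling out from the emulation loop above: each upload's flags are run through `_adjust_sessions` before the next id is picked, because merging an upload whose flag has carryforward enabled evicts the matching carried-forward sessions, which changes what the allocator sees on the next iteration. Here is a simplified, runnable sketch of that eviction rule; `FakeSession`, `cf_flags`, and the other names are illustrative stand-ins, not the actual shared/worker APIs.

from dataclasses import dataclass, field

@dataclass
class FakeSession:
    # stand-in for shared's Session: only the fields the rule looks at
    session_type: str  # "uploaded" or "carriedforward"
    flags: list = field(default_factory=list)

def adjust_sessions(sessions, to_merge_flags, cf_flags):
    # cf_flags stands in for current_yaml.flag_has_carryfoward(f)
    relevant = [f for f in to_merge_flags if f in cf_flags]
    evict = [
        sid
        for sid, sess in sessions.items()
        if sess.session_type == "carriedforward"
        and sess.flags
        and any(f in relevant for f in sess.flags)
    ]
    for sid in evict:
        sessions.pop(sid, None)

sessions = {
    0: FakeSession("carriedforward", ["unit"]),
    1: FakeSession("carriedforward", ["integration"]),
}
adjust_sessions(sessions, to_merge_flags=["unit"], cf_flags={"unit", "integration"})
print(sorted(sessions))  # [1] -- the carried-forward "unit" session was evicted

This is also why the commit threads `db_session` through `schedule_task`: the emulation needs each `Upload` row so it can read its flags before any parallel task runs.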
