Skip to content

Commit

Permalink
Add permissions checking, fix some specs, refactor Panoptes update
Browse files Browse the repository at this point in the history
  • Loading branch information
zwolf committed Jun 4, 2024
1 parent 8b3db20 commit c2ec0ce
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 34 deletions.
41 changes: 27 additions & 14 deletions panoptes_aggregation/batch_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import pandas as pd
import os
import sys
import urllib3
from shutil import make_archive
import uuid
Expand All @@ -12,9 +13,6 @@
from panoptes_aggregation.workflow_config import workflow_extractor_config
from panoptes_aggregation.scripts import batch_utils

import logging
panoptes_client_logger = logging.getLogger('panoptes_client').setLevel(logging.ERROR)

celery = Celery(__name__)
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
Expand All @@ -23,6 +21,12 @@
@celery.task(name="run_aggregation")
def run_aggregation(project_id, workflow_id, user_id):
ba = BatchAggregator(project_id, workflow_id, user_id)

if not ba.check_permission():
print(f'Batch Aggregation: Unauthorized attempt by user {user_id} to aggregate workflow {workflow_id}')
# Exit the task gracefully without retrying or erroring
sys.exit()

ba.save_exports()

ba.process_wf_export(ba.wf_csv)
Expand All @@ -48,11 +52,16 @@ def run_aggregation(project_id, workflow_id, user_id):
reduced_data[reducer] = batch_utils.batch_reduce(extract_df, reducer_config)
filename = f'{ba.output_path}/{ba.workflow_id}_reductions.csv'
reduced_data[reducer].to_csv(filename, mode='a')

# Upload zip & reduction files to blob storage
ba.upload_files()

# This could catch PanoptesAPIException, but what to do if it fails?
ba.update_panoptes()
success_attrs = {'uuid': ba.id, 'status': 'completed'}
ba.update_panoptes(success_attrs)

# STDOUT messages get printed to kubernetes logs
print(f'Batch Aggregation: Run successful for workflow {workflow_id} by user {user_id}')

class BatchAggregator:
"""
Expand Down Expand Up @@ -116,23 +125,27 @@ def upload_files(self):
zipfile = make_archive(f'tmp/{self.id}', 'zip', self.output_path)
self.upload_file_to_storage(self.id, zipfile)

def update_panoptes(self):
def update_panoptes(self, body_attributes):
# An Aggregation class can be added to the python client to avoid doing this manually
params = {'workflow_id': self.workflow_id, 'user_id': self.user_id}
response = Panoptes.client().get('/aggregations/', params=params)
params = {'workflow_id': self.workflow_id}
response = Panoptes.client().get('/aggregations', params=params)
agg_id = response[0]['aggregations'][0]['id']
fresh_etag = response[1]

Panoptes.client().put(
'/aggregations/',
f'/aggregations/{agg_id}',
etag=fresh_etag,
json={
'aggregations': {
'uuid': self.id,
'status': 'completed'
}
}
json={'aggregations': body_attributes}
)

def check_permission(self):
project = Project.find(self.project_id)
permission = False
for user in project.collaborators():
if user.id == self.user_id:
permission = True
return permission

def _generate_uuid(self):
self.id = uuid.uuid4().hex

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,37 @@

@patch("panoptes_aggregation.batch_aggregation.BatchAggregator._connect_api_client", new=MagicMock())
class TestBatchAggregation(unittest.TestCase):
@patch("panoptes_aggregation.batch_aggregation.BatchAggregator")
def test_run_aggregation_permission_failure(self, mock_aggregator):
mock_aggregator_instance = mock_aggregator.return_value
mock_aggregator_instance.check_permission.return_value = False

with self.assertRaises(SystemExit) as leaver:
run_aggregation(1, 10, 100)
mock_aggregator_instance.update_panoptes.assert_not_called()

@patch("panoptes_aggregation.batch_aggregation.workflow_extractor_config")
@patch("panoptes_aggregation.batch_aggregation.BatchAggregator")
def test_run_aggregation(self, mock_aggregator, mock_wf_ext_conf):
def test_run_aggregation_success(self, mock_aggregator, mock_wf_ext_conf):
mock_aggregator_instance = mock_aggregator.return_value
mock_aggregator_instance.check_permission.return_value = True

mock_df = MagicMock()
test_extracts = {'question_extractor': mock_df}
batch_utils.batch_extract = MagicMock(return_value=test_extracts)
mock_reducer = MagicMock()
batch_utils.batch_reduce = mock_reducer

run_aggregation(1, 10, 100)
mock_aggregator_instance.check_permission.assert_called_once()
mock_aggregator.assert_called_once_with(1, 10, 100)
mock_wf_ext_conf.assert_called_once()
batch_utils.batch_extract.assert_called_once()
mock_df.to_csv.assert_called()
batch_utils.batch_reduce.assert_called()
self.assertEqual(mock_reducer.call_count, 2)

# The reducer's call list includes subsequent calls to to_csv, but the args are methods called on the mock
# rather than use the set values i.e. "<MagicMock name='BatchAggregator().output_path' id='140281634764400'>"
# mock_aggregator.workflow_id = '10'
# mock_aggregator.output_path = 'tmp/10'
# mock_reducer.assert_has_calls([
# call(mock_df, {'reducer_config': {'question_reducer': {}}}),
# call().to_csv('tmp/10/10_reducers.csv', mode='a'),
# call(mock_df, {'reducer_config': {'question_consensus_reducer': {}}}),
# call().to_csv('tmp/10/10_reducers.csv', mode='a'),
# ])

# How do I test the specific instance of BatchAggregator rather than the mocked class?
# mock_aggregator.upload_files.assert_called_once()
mock_aggregator_instance.upload_files.assert_called_once()
mock_aggregator_instance.update_panoptes.assert_called_once()

@patch("panoptes_aggregation.batch_aggregation.os.mkdir")
@patch("panoptes_aggregation.batch_aggregation.Workflow")
Expand Down Expand Up @@ -98,14 +99,53 @@ def test_upload_file_to_storage(self):
ba.upload_file_to_storage('container', cls_export)
mock_client.upload_blob.assert_called_once

@patch("panoptes_aggregation.batch_aggregation.Project")
def test_check_permission_success(self, mock_project):
mock_user = MagicMock()
mock_user.id = 100
mock_project.find().collaborators.return_value = [mock_user]

ba = batch_agg.BatchAggregator(1, 10, 100)
ba.check_permission()
mock_project.find.assert_called_with(1)
mock_project.find().collaborators.assert_called()
self.assertEqual(ba.check_permission(), True)

@patch("panoptes_aggregation.batch_aggregation.Project")
def test_check_permission_failure(self, mock_project):
mock_user = MagicMock()

# List of collaborators does not include initiating user
mock_user.id = 999
mock_project.find().collaborators.return_value = [mock_user]

ba = batch_agg.BatchAggregator(1, 10, 100)
ba.update_panoptes = MagicMock()
ba.check_permission()
mock_project.find.assert_called_with(1)
mock_project.find().collaborators.assert_called()
self.assertEqual(ba.check_permission(), False)
ba.update_panoptes.assert_not_called()

@patch("panoptes_aggregation.batch_aggregation.Panoptes.put")
@patch("panoptes_aggregation.batch_aggregation.Panoptes.get")
def test_update_panoptes_success(self, mock_get, mock_put):
ba = batch_agg.BatchAggregator(1, 10, 100)
mock_get.return_value = ({'aggregations': [{'id': 5555}]}, 'thisisanetag')
body = {'uuid': ba.id, 'status': 'completed'}
ba.update_panoptes(body)
mock_get.assert_called_with('/aggregations', params={'workflow_id': 10})
mock_put.assert_called_with('/aggregations/5555', etag='thisisanetag', json={'aggregations': body })

@patch("panoptes_aggregation.batch_aggregation.Panoptes.put")
@patch("panoptes_aggregation.batch_aggregation.Panoptes.get")
def test_update_panoptes(self, mock_get, mock_put):
def test_update_panoptes_failure(self, mock_get, mock_put):
ba = batch_agg.BatchAggregator(1, 10, 100)
mock_get.return_value = ({}, 'thisisanetag')
ba.update_panoptes()
mock_get.assert_called_with('/aggregations/', params={'workflow_id': 10, 'user_id': 100})
mock_put.assert_called_with('/aggregations/', etag='thisisanetag', json={'aggregations': {'uuid': ba.id, 'status': 'completed'}})
mock_get.return_value = ({'aggregations': [{'id': 5555}]}, 'thisisanetag')
body = {'status': 'failure'}
ba.update_panoptes(body)
mock_get.assert_called_with('/aggregations', params={'workflow_id': 10})
mock_put.assert_called_with('/aggregations/5555', etag='thisisanetag', json={'aggregations': body })

@patch("panoptes_aggregation.batch_aggregation.BlobServiceClient")
def test_connect_blob_storage(self, mock_client):
Expand Down

0 comments on commit c2ec0ce

Please sign in to comment.