diff --git a/panoptes_aggregation/batch_aggregation.py b/panoptes_aggregation/batch_aggregation.py index fe78edb0..c3d26139 100644 --- a/panoptes_aggregation/batch_aggregation.py +++ b/panoptes_aggregation/batch_aggregation.py @@ -2,6 +2,7 @@ import json import pandas as pd import os +import sys import urllib3 from shutil import make_archive import uuid @@ -12,9 +13,6 @@ from panoptes_aggregation.workflow_config import workflow_extractor_config from panoptes_aggregation.scripts import batch_utils -import logging -panoptes_client_logger = logging.getLogger('panoptes_client').setLevel(logging.ERROR) - celery = Celery(__name__) celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379") celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379") @@ -23,6 +21,12 @@ @celery.task(name="run_aggregation") def run_aggregation(project_id, workflow_id, user_id): ba = BatchAggregator(project_id, workflow_id, user_id) + + if not ba.check_permission(): + print(f'Batch Aggregation: Unauthorized attempt by user {user_id} to aggregate workflow {workflow_id}') + # Exit the task gracefully without retrying or erroring + sys.exit() + ba.save_exports() ba.process_wf_export(ba.wf_csv) @@ -48,11 +52,16 @@ def run_aggregation(project_id, workflow_id, user_id): reduced_data[reducer] = batch_utils.batch_reduce(extract_df, reducer_config) filename = f'{ba.output_path}/{ba.workflow_id}_reductions.csv' reduced_data[reducer].to_csv(filename, mode='a') + + # Upload zip & reduction files to blob storage ba.upload_files() # This could catch PanoptesAPIException, but what to do if it fails? - ba.update_panoptes() + success_attrs = {'uuid': ba.id, 'status': 'completed'} + ba.update_panoptes(success_attrs) + # STDOUT messages get printed to kubernetes logs + print(f'Batch Aggregation: Run successful for workflow {workflow_id} by user {user_id}') class BatchAggregator: """ @@ -116,23 +125,27 @@ def upload_files(self): zipfile = make_archive(f'tmp/{self.id}', 'zip', self.output_path) self.upload_file_to_storage(self.id, zipfile) - def update_panoptes(self): + def update_panoptes(self, body_attributes): # An Aggregation class can be added to the python client to avoid doing this manually - params = {'workflow_id': self.workflow_id, 'user_id': self.user_id} - response = Panoptes.client().get('/aggregations/', params=params) + params = {'workflow_id': self.workflow_id} + response = Panoptes.client().get('/aggregations', params=params) + agg_id = response[0]['aggregations'][0]['id'] fresh_etag = response[1] Panoptes.client().put( - '/aggregations/', + f'/aggregations/{agg_id}', etag=fresh_etag, - json={ - 'aggregations': { - 'uuid': self.id, - 'status': 'completed' - } - } + json={'aggregations': body_attributes} ) + def check_permission(self): + project = Project.find(self.project_id) + permission = False + for user in project.collaborators(): + if user.id == self.user_id: + permission = True + return permission + def _generate_uuid(self): self.id = uuid.uuid4().hex diff --git a/panoptes_aggregation/tests/batch_aggregation/test_batch_aggregation.py b/panoptes_aggregation/tests/batch_aggregation/test_batch_aggregation.py index 5e4d23de..d2bd8b42 100644 --- a/panoptes_aggregation/tests/batch_aggregation/test_batch_aggregation.py +++ b/panoptes_aggregation/tests/batch_aggregation/test_batch_aggregation.py @@ -10,9 +10,21 @@ @patch("panoptes_aggregation.batch_aggregation.BatchAggregator._connect_api_client", new=MagicMock()) class TestBatchAggregation(unittest.TestCase): + @patch("panoptes_aggregation.batch_aggregation.BatchAggregator") + def test_run_aggregation_permission_failure(self, mock_aggregator): + mock_aggregator_instance = mock_aggregator.return_value + mock_aggregator_instance.check_permission.return_value = False + + with self.assertRaises(SystemExit) as leaver: + run_aggregation(1, 10, 100) + mock_aggregator_instance.update_panoptes.assert_not_called() + @patch("panoptes_aggregation.batch_aggregation.workflow_extractor_config") @patch("panoptes_aggregation.batch_aggregation.BatchAggregator") - def test_run_aggregation(self, mock_aggregator, mock_wf_ext_conf): + def test_run_aggregation_success(self, mock_aggregator, mock_wf_ext_conf): + mock_aggregator_instance = mock_aggregator.return_value + mock_aggregator_instance.check_permission.return_value = True + mock_df = MagicMock() test_extracts = {'question_extractor': mock_df} batch_utils.batch_extract = MagicMock(return_value=test_extracts) @@ -20,26 +32,15 @@ def test_run_aggregation(self, mock_aggregator, mock_wf_ext_conf): batch_utils.batch_reduce = mock_reducer run_aggregation(1, 10, 100) + mock_aggregator_instance.check_permission.assert_called_once() mock_aggregator.assert_called_once_with(1, 10, 100) mock_wf_ext_conf.assert_called_once() batch_utils.batch_extract.assert_called_once() mock_df.to_csv.assert_called() batch_utils.batch_reduce.assert_called() self.assertEqual(mock_reducer.call_count, 2) - - # The reducer's call list includes subsequent calls to to_csv, but the args are methods called on the mock - # rather than use the set values i.e. "" - # mock_aggregator.workflow_id = '10' - # mock_aggregator.output_path = 'tmp/10' - # mock_reducer.assert_has_calls([ - # call(mock_df, {'reducer_config': {'question_reducer': {}}}), - # call().to_csv('tmp/10/10_reducers.csv', mode='a'), - # call(mock_df, {'reducer_config': {'question_consensus_reducer': {}}}), - # call().to_csv('tmp/10/10_reducers.csv', mode='a'), - # ]) - - # How do I test the specific instance of BatchAggregator rather than the mocked class? - # mock_aggregator.upload_files.assert_called_once() + mock_aggregator_instance.upload_files.assert_called_once() + mock_aggregator_instance.update_panoptes.assert_called_once() @patch("panoptes_aggregation.batch_aggregation.os.mkdir") @patch("panoptes_aggregation.batch_aggregation.Workflow") @@ -98,14 +99,53 @@ def test_upload_file_to_storage(self): ba.upload_file_to_storage('container', cls_export) mock_client.upload_blob.assert_called_once + @patch("panoptes_aggregation.batch_aggregation.Project") + def test_check_permission_success(self, mock_project): + mock_user = MagicMock() + mock_user.id = 100 + mock_project.find().collaborators.return_value = [mock_user] + + ba = batch_agg.BatchAggregator(1, 10, 100) + ba.check_permission() + mock_project.find.assert_called_with(1) + mock_project.find().collaborators.assert_called() + self.assertEqual(ba.check_permission(), True) + + @patch("panoptes_aggregation.batch_aggregation.Project") + def test_check_permission_failure(self, mock_project): + mock_user = MagicMock() + + # List of collaborators does not include initiating user + mock_user.id = 999 + mock_project.find().collaborators.return_value = [mock_user] + + ba = batch_agg.BatchAggregator(1, 10, 100) + ba.update_panoptes = MagicMock() + ba.check_permission() + mock_project.find.assert_called_with(1) + mock_project.find().collaborators.assert_called() + self.assertEqual(ba.check_permission(), False) + ba.update_panoptes.assert_not_called() + + @patch("panoptes_aggregation.batch_aggregation.Panoptes.put") + @patch("panoptes_aggregation.batch_aggregation.Panoptes.get") + def test_update_panoptes_success(self, mock_get, mock_put): + ba = batch_agg.BatchAggregator(1, 10, 100) + mock_get.return_value = ({'aggregations': [{'id': 5555}]}, 'thisisanetag') + body = {'uuid': ba.id, 'status': 'completed'} + ba.update_panoptes(body) + mock_get.assert_called_with('/aggregations', params={'workflow_id': 10}) + mock_put.assert_called_with('/aggregations/5555', etag='thisisanetag', json={'aggregations': body }) + @patch("panoptes_aggregation.batch_aggregation.Panoptes.put") @patch("panoptes_aggregation.batch_aggregation.Panoptes.get") - def test_update_panoptes(self, mock_get, mock_put): + def test_update_panoptes_failure(self, mock_get, mock_put): ba = batch_agg.BatchAggregator(1, 10, 100) - mock_get.return_value = ({}, 'thisisanetag') - ba.update_panoptes() - mock_get.assert_called_with('/aggregations/', params={'workflow_id': 10, 'user_id': 100}) - mock_put.assert_called_with('/aggregations/', etag='thisisanetag', json={'aggregations': {'uuid': ba.id, 'status': 'completed'}}) + mock_get.return_value = ({'aggregations': [{'id': 5555}]}, 'thisisanetag') + body = {'status': 'failure'} + ba.update_panoptes(body) + mock_get.assert_called_with('/aggregations', params={'workflow_id': 10}) + mock_put.assert_called_with('/aggregations/5555', etag='thisisanetag', json={'aggregations': body }) @patch("panoptes_aggregation.batch_aggregation.BlobServiceClient") def test_connect_blob_storage(self, mock_client):