From a019cda8ac186cded73f69c0ac58f489df156e5a Mon Sep 17 00:00:00 2001 From: Michael Hall Date: Thu, 7 Dec 2023 11:55:30 -0600 Subject: [PATCH 1/3] - Achieved over 80% unit test coverage. Some refactoring could be used to reduce duplicate lines between the RDS and OpenSearch plugin. --- pylot/plugins/opensearch/main.py | 65 ++++++++++-------- .../plugins/opensearch/tests/fake_delete.json | 1 + .../plugins/opensearch/tests/fake_update.json | 1 + pylot/plugins/opensearch/tests/temp.py | 13 ++++ .../opensearch/tests/test_opensearch.py | 68 ++++++++++++++++++- pylot/plugins/rds_lambda/main.py | 51 +++++++------- pylot/plugins/rds_lambda/tests/__init__.py | 0 pylot/plugins/rds_lambda/tests/test_file.json | 3 + pylot/plugins/rds_lambda/tests/test_rds.py | 52 ++++++++++++++ 9 files changed, 201 insertions(+), 53 deletions(-) create mode 100644 pylot/plugins/opensearch/tests/fake_delete.json create mode 100644 pylot/plugins/opensearch/tests/fake_update.json create mode 100644 pylot/plugins/opensearch/tests/temp.py create mode 100644 pylot/plugins/rds_lambda/tests/__init__.py create mode 100644 pylot/plugins/rds_lambda/tests/test_file.json create mode 100644 pylot/plugins/rds_lambda/tests/test_rds.py diff --git a/pylot/plugins/opensearch/main.py b/pylot/plugins/opensearch/main.py index f450f57..ec27251 100644 --- a/pylot/plugins/opensearch/main.py +++ b/pylot/plugins/opensearch/main.py @@ -1,6 +1,7 @@ import json import os import concurrent.futures +import pathlib import boto3 from ..helpers.pylot_helpers import PyLOTHelpers @@ -15,10 +16,10 @@ def read_json_file(filename, **kwargs): return data @staticmethod - def query_opensearch(query_data, record_type, results='query_results.json', terminate_after=100, **kwargs): + def invoke_opensearch_lambda(query_data, record_type, terminate_after, lambda_client=boto3.client('lambda'), **kwargs): lambda_arn = os.getenv('OPENSEARCH_LAMBDA_ARN') if not lambda_arn: - raise Exception('The ARN for the OpenSearch lambda is not defined. Provide it as an environment variable.') + raise ValueError('The ARN for the OpenSearch lambda is not defined. Provide it as an environment variable.') # Invoke OpenSearch lambda payload = { @@ -29,32 +30,45 @@ def query_opensearch(query_data, record_type, results='query_results.json', term } print('Invoking OpenSearch lambda...') - client = boto3.client('lambda') - rsp = client.invoke( + rsp = lambda_client.invoke( FunctionName=lambda_arn, Payload=json.dumps(payload).encode('utf-8') ) + print(rsp) if rsp.get('StatusCode') != 200: raise Exception( f'The OpenSearch lambda failed. Check the Cloudwatch logs for {os.getenv("OPENSEARCH_LAMBDA_ARN")}' ) - # Download results from S3 + return rsp + + @staticmethod + def download_file(bucket, key, results, s3_client=boto3.client('s3')): print('Downloading query results...') - ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) - # print(f'ret_dict: {ret_dict}') - s3_client = boto3.client('s3') s3_client.download_file( - Bucket=ret_dict.get('bucket'), - Key=ret_dict.get('key'), + Bucket=bucket, + Key=key, Filename=f'{os.getcwd()}/{results}' ) file = f'{os.getcwd()}/{results}' - print(f'{ret_dict.get("record_count")} {record_type} records obtained: {os.getcwd()}/{results}') return file +def query_opensearch(query_data, record_type, results='query_results.json', terminate_after=100, **kwargs): + open_search = OpenSearch() + if isinstance(query_data, str) and os.path.isfile(query_data): + query_data = open_search.read_json_file(query_data) + + rsp = open_search.invoke_opensearch_lambda(query_data, record_type, terminate_after) + ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) + + # Download results from S3 + file = open_search.download_file(ret_dict.get('Bucket'), ret_dict.get('Key'), results) + print(f'{ret_dict.get("record_count")} {record_type} records obtained: {os.getcwd()}/{results}') + return file + + def return_parser(subparsers): subparser = subparsers.add_parser( 'opensearch', @@ -85,14 +99,9 @@ def return_parser(subparsers): group = subparser.add_mutually_exclusive_group(required=True) group.add_argument( '-q', '--query', - help='The name of a file containing an OpenSearch query: .json', - metavar='' - ) - group.add_argument( - '-s', '--query-string', - help='A json query string using the OpenSearch DSL: ' - 'https://opensearch.org/docs/latest/opensearch/query-dsl/index/ ' - 'Example: \'{"query": {"term": {"collectionId": "goesimpacts___1"}}}\'', + help='The name of a file containing an OpenSearch query: .json or a json query string ' + 'using the OpenSearch DSL: \'{"query": {"term": {"collectionId": "goesimpacts___1"}}}\'. ' + 'See: https://opensearch.org/docs/latest/opensearch/query-dsl/index/ ', metavar='' ) @@ -118,9 +127,11 @@ def return_parser(subparsers): def process_update_data(update_data, query_results): - update_file = f'{os.getcwd()}/{update_data}' - print(f'Attempting to update using data file: {update_file} ') - with open(update_file, 'r', encoding='utf-8') as json_file: + if not os.path.isfile(update_data): + update_data = f'{pathlib.Path(__file__).parent.resolve()}/{update_data}' + + print(f'Attempting to update using data file: {update_data} ') + with open(update_data, 'r', encoding='utf-8') as json_file: update_dict = json.load(json_file) for record in query_results: @@ -143,9 +154,9 @@ def update_dictionary(results_dict, update_dict): def bulk_delete_cumulus(delete_file, query_results): cml = PyLOTHelpers().get_cumulus_api_instance() - + if not os.path.isfile(delete_file): + delete_file = f'{pathlib.Path(__file__).parent.resolve()}/{delete_file}' print(f'Opening delete definition: {delete_file}') - delete_file = f'{os.getcwd()}/{delete_file}' with open(delete_file, 'r+', encoding='utf-8') as delete_definition: delete_config = json.load(delete_definition) @@ -203,12 +214,10 @@ def update_cumulus(record_type, query_results): print('Updating complete\n') -def main(record_type, bulk=False, results=None, query=None, query_string=None, update_data=None, delete=None, **kwargs): - query = OpenSearch.read_json_file(query) if query else json.loads(query_string) - results_file = OpenSearch.query_opensearch(query_data=query, record_type=record_type, results=results, **kwargs) +def main(record_type, bulk=False, results=None, query=None, update_data=None, delete=None, **kwargs): + results_file = query_opensearch(query_data=query, record_type=record_type, results=results, **kwargs) if update_data or delete: - # results_file = f'{os.getcwd()}/{}' with open(results_file, 'r', encoding='utf-8') as results_file: query_results = json.load(results_file) diff --git a/pylot/plugins/opensearch/tests/fake_delete.json b/pylot/plugins/opensearch/tests/fake_delete.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/pylot/plugins/opensearch/tests/fake_delete.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/pylot/plugins/opensearch/tests/fake_update.json b/pylot/plugins/opensearch/tests/fake_update.json new file mode 100644 index 0000000..fdd8f52 --- /dev/null +++ b/pylot/plugins/opensearch/tests/fake_update.json @@ -0,0 +1 @@ +{"testField": "testValueUpdated"} \ No newline at end of file diff --git a/pylot/plugins/opensearch/tests/temp.py b/pylot/plugins/opensearch/tests/temp.py new file mode 100644 index 0000000..ffe8925 --- /dev/null +++ b/pylot/plugins/opensearch/tests/temp.py @@ -0,0 +1,13 @@ +import json + +import boto3 + + +def simple(): + client = boto3.client('lambda') + rsp = client.invoke( + FunctionName='', + Payload='' + ) + + ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) diff --git a/pylot/plugins/opensearch/tests/test_opensearch.py b/pylot/plugins/opensearch/tests/test_opensearch.py index 7c56c81..fa3af29 100644 --- a/pylot/plugins/opensearch/tests/test_opensearch.py +++ b/pylot/plugins/opensearch/tests/test_opensearch.py @@ -1,11 +1,16 @@ import argparse import os import unittest +from unittest.mock import patch, MagicMock -from pylot.plugins.opensearch.main import return_parser, OpenSearch +from pylot.plugins.opensearch.main import return_parser, OpenSearch, update_dictionary, thread_function, \ + bulk_delete_cumulus, process_update_data, delete_cumulus, update_cumulus, query_opensearch class TestOpenSearch(unittest.TestCase): + def tearDown(self) -> None: + os.environ.pop('OPENSEARCH_LAMBDA_ARN', '') + def test_return_parser(self): parser = argparse.ArgumentParser( usage=' -h to access help for each plugin. \n', @@ -21,5 +26,64 @@ def test_read_json_file(self): data = opensearch.read_json_file(f'{os.path.dirname(os.path.realpath(__file__))}/test_file.json') self.assertEqual(data, {"some": "json"}) - def test_query_opensearch(self): + @patch('json.loads') + @patch('pylot.plugins.opensearch.main.OpenSearch') + def test_query_opensearch(self, mock_opensearch, mock_json_loads): + mock_opensearch.invoke_opensearch_lambda.return_value = '' + mock_opensearch.invoke_opensearch_lambda.return_value = '' + query_opensearch(query_data={}, record_type='') pass + + def test_invoke_opensearch_lambda(self): + opensearch = OpenSearch() + os.environ['OPENSEARCH_LAMBDA_ARN'] = 'FAKE_ARN' + mock_client = MagicMock() + mock_client.invoke.return_value = {'StatusCode': 200} + opensearch.invoke_opensearch_lambda( + query_data={}, record_type='', terminate_after=0, lambda_client=mock_client + ) + + def test_invoke_opensearch_lambda_arn_exception(self): + opensearch = OpenSearch() + mock_client = MagicMock() + with self.assertRaises(ValueError) as context: + opensearch.invoke_opensearch_lambda( + query_data={}, record_type='', terminate_after=0, lambda_client=mock_client + ) + self.assertTrue('The ARN for the OpenSearch lambda is not defined' in context.exception) + + def test_download_file(self): + opensearch = OpenSearch() + opensearch.download_file(bucket='', key='', results='', s3_client=MagicMock()) + + def test_process_update_data(self): + query_res = [{'productVolume': 1, 'testField': 'testValue'}] + updated_res = process_update_data('/tests/fake_update.json', query_res) + expected = [{'productVolume': '1', 'testField': 'testValueUpdated'}] + self.assertEqual(updated_res, expected) + + def test_update_dictionary(self): + target = {'key_1': {'key_2': 'value_1'}, 'key_3': 'value_3'} + update = {'key_1': {'key_2': 'value_1_updated'}, 'key_4': 'value_4'} + target = update_dictionary(target, update) + print(target) + expected = {'key_1': {'key_2': 'value_1_updated'}, 'key_3': 'value_3', 'key_4': 'value_4'} + self.assertEqual(target, expected) + + @patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance') + def test_bulk_delete_cumulus(self, gcapi): + bulk_delete_cumulus('/tests/fake_delete.json', [{'granuleId': 'fake_granuleId'}]) + + def test_thread_function(self): + thread_function(print, [1, 2, 3]) + + @patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance') + def test_delete_cumulus(self, gcapi): + query_results = [{'productVolume': 1}] + delete_cumulus(query_results) + + @patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance') + def test_update_cumulus(self, gcapi): + query_res = [{'record': 'value'}] + update_cumulus('test', query_res) + diff --git a/pylot/plugins/rds_lambda/main.py b/pylot/plugins/rds_lambda/main.py index c67c365..a58f79f 100644 --- a/pylot/plugins/rds_lambda/main.py +++ b/pylot/plugins/rds_lambda/main.py @@ -13,44 +13,51 @@ def read_json_file(filename, **kwargs): return data @staticmethod - def query_rds(payload, results='query_results.json', **kwargs): + def invoke_rds_lambda(query_data, lambda_client=boto3.client('lambda'), **kwargs): lambda_arn = os.getenv('RDS_LAMBDA_ARN') if not lambda_arn: - raise Exception('The ARN for the RDS lambda is not defined. Provide it as an environment variable.') + raise ValueError('The ARN for the RDS lambda is not defined. Provide it as an environment variable.') # Invoke RDS lambda print('Invoking RDS lambda...') - client = boto3.client('lambda') - rsp = client.invoke( + rsp = lambda_client.invoke( FunctionName=lambda_arn, - Payload=json.dumps(payload).encode('utf-8') + Payload=json.dumps(query_data).encode('utf-8') ) if rsp.get('StatusCode') != 200: raise Exception( f'The RDS lambda failed. Check the Cloudwatch logs for {os.getenv("RDS_LAMBDA_ARN")}' ) - # Download results from S3 - ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) - print(f'Query matched {ret_dict.get("count")} records.') - print(f'Downloading results to file: {os.getcwd()}/{results}') - s3_client = boto3.client('s3') + return rsp + + @staticmethod + def download_file(bucket, key, results, s3_client=boto3.client('s3')): + print('Downloading query results...') s3_client.download_file( - Bucket=ret_dict.get('bucket'), - Key=ret_dict.get('key'), + Bucket=bucket, + Key=key, Filename=f'{os.getcwd()}/{results}' ) - print(f'Deleting remote S3 results file: {ret_dict.get("bucket")}/{ret_dict.get("key")}') - s3_client.delete_object( - Bucket=ret_dict.get('bucket'), - Key=ret_dict.get('key') - ) file = f'{os.getcwd()}/{results}' - return file +def query_rds(query_data, results='query_results.json', **kwargs): + rds = QueryRDS() + if isinstance(query_data, str) and os.path.isfile(query_data): + query_data = rds.read_json_file(query_data) + + rsp = rds.invoke_rds_lambda(query_data) + ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) + + # Download results from S3 + file = rds.download_file(ret_dict.get('Bucket'), ret_dict.get('Key'), results) + print(f'{ret_dict.get("record_count")} records obtained: {os.getcwd()}/{results}') + return file + + def return_parser(subparsers): query = { "rds_config": { @@ -103,10 +110,8 @@ def return_parser(subparsers): ) -def main(record_type, results=None, query=None, query_string=None, **kwargs): - rds = QueryRDS() - query = rds.read_json_file(query) if query else json.loads(query_string) - print(f'Using query: {json.dumps(query)}') - rds.query_rds(payload=query) +def main(query=None, **kwargs): + query_rds(query_data=query) print('Complete') + return 0 diff --git a/pylot/plugins/rds_lambda/tests/__init__.py b/pylot/plugins/rds_lambda/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pylot/plugins/rds_lambda/tests/test_file.json b/pylot/plugins/rds_lambda/tests/test_file.json new file mode 100644 index 0000000..0294a88 --- /dev/null +++ b/pylot/plugins/rds_lambda/tests/test_file.json @@ -0,0 +1,3 @@ +{ + "some": "json" +} \ No newline at end of file diff --git a/pylot/plugins/rds_lambda/tests/test_rds.py b/pylot/plugins/rds_lambda/tests/test_rds.py new file mode 100644 index 0000000..fd0b62d --- /dev/null +++ b/pylot/plugins/rds_lambda/tests/test_rds.py @@ -0,0 +1,52 @@ +import argparse +import os +import unittest +from unittest.mock import patch, MagicMock + +from pylot.plugins.rds_lambda.main import return_parser, QueryRDS, query_rds + + +class TestRDS(unittest.TestCase): + def tearDown(self) -> None: + os.environ.pop('RDS_LAMBDA_ARN', '') + + def test_return_parser(self): + parser = argparse.ArgumentParser( + usage=' -h to access help for each plugin. \n', + description='PyLOT command line utility.' + ) + + # load plugin parsers + subparsers = parser.add_subparsers(title='plugins', dest='command', required=True) + return_parser(subparsers) + + def test_read_json_file(self): + rds = QueryRDS() + data = rds.read_json_file(f'{os.path.dirname(os.path.realpath(__file__))}/test_file.json') + self.assertEqual(data, {"some": "json"}) + + @patch('json.loads') + @patch('pylot.plugins.rds_lambda.main.QueryRDS') + def test_query_rds(self, mock_opensearch, mock_json_loads): + mock_opensearch.invoke_rds_lambda.return_value = '' + mock_opensearch.invoke_rds_lambda.return_value = '' + query_rds(query_data={}, record_type='') + pass + + def test_invoke_rds_lambda(self): + rds = QueryRDS() + os.environ['RDS_LAMBDA_ARN'] = 'FAKE_ARN' + mock_client = MagicMock() + mock_client.invoke.return_value = {'StatusCode': 200} + rds.invoke_rds_lambda(query_data={}, lambda_client=mock_client) + + def test_invoke_rds_lambda_arn_exception(self): + rds = QueryRDS() + mock_client = MagicMock() + with self.assertRaises(ValueError) as context: + rds.invoke_rds_lambda(query_data={}, lambda_client=mock_client) + self.assertTrue('The ARN for the RDS lambda is not defined' in context.exception) + + def test_download_file(self): + rds = QueryRDS() + rds.download_file(bucket='', key='', results='', s3_client=MagicMock()) From a4076939102921bd7f7409d8fb64acc15f16bd03 Mon Sep 17 00:00:00 2001 From: Michael Hall Date: Thu, 7 Dec 2023 13:18:35 -0600 Subject: [PATCH 2/3] - Removed uneeded CLI variables. --- pylot/plugins/opensearch/main.py | 4 ++- pylot/plugins/rds_lambda/main.py | 57 +++++++++++--------------------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/pylot/plugins/opensearch/main.py b/pylot/plugins/opensearch/main.py index ec27251..f657a4b 100644 --- a/pylot/plugins/opensearch/main.py +++ b/pylot/plugins/opensearch/main.py @@ -59,12 +59,14 @@ def query_opensearch(query_data, record_type, results='query_results.json', term open_search = OpenSearch() if isinstance(query_data, str) and os.path.isfile(query_data): query_data = open_search.read_json_file(query_data) + else: + query_data = json.loads(query_data) rsp = open_search.invoke_opensearch_lambda(query_data, record_type, terminate_after) ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) # Download results from S3 - file = open_search.download_file(ret_dict.get('Bucket'), ret_dict.get('Key'), results) + file = open_search.download_file(bucket=ret_dict.get('bucket'), key=ret_dict.get('key'), results=results) print(f'{ret_dict.get("record_count")} {record_type} records obtained: {os.getcwd()}/{results}') return file diff --git a/pylot/plugins/rds_lambda/main.py b/pylot/plugins/rds_lambda/main.py index a58f79f..e76cc8a 100644 --- a/pylot/plugins/rds_lambda/main.py +++ b/pylot/plugins/rds_lambda/main.py @@ -48,24 +48,30 @@ def query_rds(query_data, results='query_results.json', **kwargs): rds = QueryRDS() if isinstance(query_data, str) and os.path.isfile(query_data): query_data = rds.read_json_file(query_data) + else: + query_data = json.loads(query_data) + + query_data = {'rds_config': query_data, 'is_test': True} + rsp = rds.invoke_rds_lambda(query_data) ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8')) # Download results from S3 - file = rds.download_file(ret_dict.get('Bucket'), ret_dict.get('Key'), results) - print(f'{ret_dict.get("record_count")} records obtained: {os.getcwd()}/{results}') + file = rds.download_file(bucket=ret_dict.get('bucket'), key=ret_dict.get('key'), results=results) + print( + f'{ret_dict.get("count")} {query_data.get("rds_config").get("records")} records obtained: ' + f'{os.getcwd()}/{results}' + ) return file def return_parser(subparsers): query = { - "rds_config": { - "records": "granules", - "where": "name LIKE nalma% ", - "columns": ["granule_id", "status"], - "limit": 10 - } + "records": "granules", + "where": "name LIKE nalma% ", + "columns": ["granule_id", "status"], + "limit": 10 } subparser = subparsers.add_parser( 'rds_lambda', @@ -73,19 +79,11 @@ def return_parser(subparsers): description='Submit queries to the Cumulus RDS instance.\n' f'Example query: {json.dumps(query)}' ) - choices = ['granules', 'collections', 'providers', 'pdrs', 'rules', 'logs', 'executions', 'reconciliationReport'] - choice_str = str(choices).strip('[').strip(']').replace("'", '') subparser.add_argument( - 'record_type', - help=f'The RDS table to be queried: {choice_str}', - metavar='record_type', - choices=choices - ) - subparser.add_argument( - '-l', '--limit', - help='Limit the number of records returned from RDS. Default is 100. Use 0 to retrieve all matches.', - metavar='', - default=100 + 'query', + help='A file containing an RDS Lambda query: .json or a json query string ' + 'using the RDS DSL syntax: https://github.com/ghrcdaac/ghrc_rds_lambda?tab=readme-ov-file#querying', + metavar='query' ) subparser.add_argument( '-r', '--results', @@ -94,24 +92,9 @@ def return_parser(subparsers): default='query_results.json' ) - group = subparser.add_mutually_exclusive_group(required=True) - group.add_argument( - '-q', '--query', - help='The name of a file containing an RDS lambda query: .json', - metavar='' - ) - group.add_argument( - '-s', '--query-string', - help='A json query string using the RDS DSL: ' - 'https://github.com/ghrcdaac/ghrc_rds_lambda#querying ' - 'Example: ' - '\'{"rds_config": {"records": "", "columns": ["granule_id"], "where": "name=nalmaraw", "limit": 0}}\'', - metavar='' - ) - -def main(query=None, **kwargs): - query_rds(query_data=query) +def main(query=None, records=None, **kwargs): + query_rds(query_data=query, record_type=records) print('Complete') return 0 From b9b04fd695313672076805b5d087532defbe14a9 Mon Sep 17 00:00:00 2001 From: Michael Hall Date: Thu, 7 Dec 2023 13:23:00 -0600 Subject: [PATCH 3/3] - Fix should resolve error where a non-mock s3 client is initialzed. --- pylot/plugins/opensearch/main.py | 12 ++++++++---- pylot/plugins/rds_lambda/main.py | 10 ++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pylot/plugins/opensearch/main.py b/pylot/plugins/opensearch/main.py index f657a4b..de41f46 100644 --- a/pylot/plugins/opensearch/main.py +++ b/pylot/plugins/opensearch/main.py @@ -15,8 +15,11 @@ def read_json_file(filename, **kwargs): return data - @staticmethod - def invoke_opensearch_lambda(query_data, record_type, terminate_after, lambda_client=boto3.client('lambda'), **kwargs): + def invoke_opensearch_lambda( + self, query_data, record_type, terminate_after, lambda_client=None, **kwargs + ): + if not lambda_client: + lambda_client = boto3.client('lambda') lambda_arn = os.getenv('OPENSEARCH_LAMBDA_ARN') if not lambda_arn: raise ValueError('The ARN for the OpenSearch lambda is not defined. Provide it as an environment variable.') @@ -42,8 +45,9 @@ def invoke_opensearch_lambda(query_data, record_type, terminate_after, lambda_cl return rsp - @staticmethod - def download_file(bucket, key, results, s3_client=boto3.client('s3')): + def download_file(self, bucket, key, results, s3_client=None): + if not s3_client: + s3_client = boto3.client('s3') print('Downloading query results...') s3_client.download_file( Bucket=bucket, diff --git a/pylot/plugins/rds_lambda/main.py b/pylot/plugins/rds_lambda/main.py index e76cc8a..5d3700a 100644 --- a/pylot/plugins/rds_lambda/main.py +++ b/pylot/plugins/rds_lambda/main.py @@ -12,8 +12,9 @@ def read_json_file(filename, **kwargs): return data - @staticmethod - def invoke_rds_lambda(query_data, lambda_client=boto3.client('lambda'), **kwargs): + def invoke_rds_lambda(self, query_data, lambda_client=None, **kwargs): + if not lambda_client: + lambda_client = boto3.client('lambda') lambda_arn = os.getenv('RDS_LAMBDA_ARN') if not lambda_arn: raise ValueError('The ARN for the RDS lambda is not defined. Provide it as an environment variable.') @@ -31,8 +32,9 @@ def invoke_rds_lambda(query_data, lambda_client=boto3.client('lambda'), **kwargs return rsp - @staticmethod - def download_file(bucket, key, results, s3_client=boto3.client('s3')): + def download_file(self, bucket, key, results, s3_client=None): + if not s3_client: + s3_client = boto3.client('s3') print('Downloading query results...') s3_client.download_file( Bucket=bucket,