Skip to content

Commit

Permalink
Merge pull request #37 from ghrcdaac/mlh0079-5539-unit-test-coverage
Browse files Browse the repository at this point in the history
Mlh0079 5539 unit test coverage
  • Loading branch information
camposeddie authored Dec 11, 2023
2 parents 11d8c28 + b9b04fd commit a68f782
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 88 deletions.
73 changes: 44 additions & 29 deletions pylot/plugins/opensearch/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import concurrent.futures
import pathlib

import boto3
from ..helpers.pylot_helpers import PyLOTHelpers
Expand All @@ -14,11 +15,14 @@ def read_json_file(filename, **kwargs):

return data

@staticmethod
def query_opensearch(query_data, record_type, results='query_results.json', terminate_after=100, **kwargs):
def invoke_opensearch_lambda(
self, query_data, record_type, terminate_after, lambda_client=None, **kwargs
):
if not lambda_client:
lambda_client = boto3.client('lambda')
lambda_arn = os.getenv('OPENSEARCH_LAMBDA_ARN')
if not lambda_arn:
raise Exception('The ARN for the OpenSearch lambda is not defined. Provide it as an environment variable.')
raise ValueError('The ARN for the OpenSearch lambda is not defined. Provide it as an environment variable.')

# Invoke OpenSearch lambda
payload = {
Expand All @@ -29,32 +33,48 @@ def query_opensearch(query_data, record_type, results='query_results.json', term
}

print('Invoking OpenSearch lambda...')
client = boto3.client('lambda')
rsp = client.invoke(
rsp = lambda_client.invoke(
FunctionName=lambda_arn,
Payload=json.dumps(payload).encode('utf-8')
)
print(rsp)
if rsp.get('StatusCode') != 200:
raise Exception(
f'The OpenSearch lambda failed. Check the Cloudwatch logs for {os.getenv("OPENSEARCH_LAMBDA_ARN")}'
)

# Download results from S3
return rsp

def download_file(self, bucket, key, results, s3_client=None):
if not s3_client:
s3_client = boto3.client('s3')
print('Downloading query results...')
ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8'))
# print(f'ret_dict: {ret_dict}')
s3_client = boto3.client('s3')
s3_client.download_file(
Bucket=ret_dict.get('bucket'),
Key=ret_dict.get('key'),
Bucket=bucket,
Key=key,
Filename=f'{os.getcwd()}/{results}'
)

file = f'{os.getcwd()}/{results}'
print(f'{ret_dict.get("record_count")} {record_type} records obtained: {os.getcwd()}/{results}')
return file


def query_opensearch(query_data, record_type, results='query_results.json', terminate_after=100, **kwargs):
    """Run an OpenSearch query via the configured lambda and download the results from S3.

    Args:
        query_data: the query as a dict, a path to a JSON file, or a raw JSON string.
        record_type: the OpenSearch record type to query.
        results: local filename for the downloaded results (written to the CWD).
        terminate_after: maximum number of records the lambda should scan.

    Returns:
        The path of the downloaded results file.
    """
    open_search = OpenSearch()
    if isinstance(query_data, str):
        # Accept either a filename or a raw JSON string.
        if os.path.isfile(query_data):
            query_data = open_search.read_json_file(query_data)
        else:
            query_data = json.loads(query_data)
    # else: query_data is assumed to already be a parsed dict. Previously a dict
    # fell through to json.loads(dict) and raised TypeError.

    rsp = open_search.invoke_opensearch_lambda(query_data, record_type, terminate_after)
    ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8'))

    # Download results from S3
    file = open_search.download_file(bucket=ret_dict.get('bucket'), key=ret_dict.get('key'), results=results)
    print(f'{ret_dict.get("record_count")} {record_type} records obtained: {os.getcwd()}/{results}')
    return file


def return_parser(subparsers):
subparser = subparsers.add_parser(
'opensearch',
Expand Down Expand Up @@ -85,14 +105,9 @@ def return_parser(subparsers):
group = subparser.add_mutually_exclusive_group(required=True)
group.add_argument(
'-q', '--query',
help='The name of a file containing an OpenSearch query: <filename>.json',
metavar=''
)
group.add_argument(
'-s', '--query-string',
help='A json query string using the OpenSearch DSL: '
'https://opensearch.org/docs/latest/opensearch/query-dsl/index/ '
'Example: \'{"query": {"term": {"collectionId": "goesimpacts___1"}}}\'',
help='The name of a file containing an OpenSearch query: <filename>.json or a json query string '
'using the OpenSearch DSL: \'{"query": {"term": {"collectionId": "goesimpacts___1"}}}\'. '
'See: https://opensearch.org/docs/latest/opensearch/query-dsl/index/ ',
metavar=''
)

Expand All @@ -118,9 +133,11 @@ def return_parser(subparsers):


def process_update_data(update_data, query_results):
update_file = f'{os.getcwd()}/{update_data}'
print(f'Attempting to update using data file: {update_file} ')
with open(update_file, 'r', encoding='utf-8') as json_file:
if not os.path.isfile(update_data):
update_data = f'{pathlib.Path(__file__).parent.resolve()}/{update_data}'

print(f'Attempting to update using data file: {update_data} ')
with open(update_data, 'r', encoding='utf-8') as json_file:
update_dict = json.load(json_file)

for record in query_results:
Expand All @@ -143,9 +160,9 @@ def update_dictionary(results_dict, update_dict):

def bulk_delete_cumulus(delete_file, query_results):
cml = PyLOTHelpers().get_cumulus_api_instance()

if not os.path.isfile(delete_file):
delete_file = f'{pathlib.Path(__file__).parent.resolve()}/{delete_file}'
print(f'Opening delete definition: {delete_file}')
delete_file = f'{os.getcwd()}/{delete_file}'
with open(delete_file, 'r+', encoding='utf-8') as delete_definition:
delete_config = json.load(delete_definition)

Expand Down Expand Up @@ -203,12 +220,10 @@ def update_cumulus(record_type, query_results):
print('Updating complete\n')


def main(record_type, bulk=False, results=None, query=None, query_string=None, update_data=None, delete=None, **kwargs):
query = OpenSearch.read_json_file(query) if query else json.loads(query_string)
results_file = OpenSearch.query_opensearch(query_data=query, record_type=record_type, results=results, **kwargs)
def main(record_type, bulk=False, results=None, query=None, update_data=None, delete=None, **kwargs):
results_file = query_opensearch(query_data=query, record_type=record_type, results=results, **kwargs)

if update_data or delete:
# results_file = f'{os.getcwd()}/{}'
with open(results_file, 'r', encoding='utf-8') as results_file:
query_results = json.load(results_file)

Expand Down
1 change: 1 addition & 0 deletions pylot/plugins/opensearch/tests/fake_delete.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{}
1 change: 1 addition & 0 deletions pylot/plugins/opensearch/tests/fake_update.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"testField": "testValueUpdated"}
13 changes: 13 additions & 0 deletions pylot/plugins/opensearch/tests/temp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json

import boto3


def simple():
    """Invoke a lambda and decode its JSON payload (scratch/debug helper).

    NOTE(review): FunctionName and Payload are empty placeholders — this looks
    like a temp file exercising the invoke/response-parsing flow; confirm it is
    not imported by production code and consider deleting it.

    Returns:
        The decoded JSON payload of the lambda response.
    """
    client = boto3.client('lambda')
    rsp = client.invoke(
        FunctionName='',
        Payload=''
    )

    # Previously the decoded payload was assigned and discarded; return it so
    # the helper produces something observable.
    return json.loads(rsp.get('Payload').read().decode('utf-8'))
68 changes: 66 additions & 2 deletions pylot/plugins/opensearch/tests/test_opensearch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import argparse
import os
import unittest
from unittest.mock import patch, MagicMock

from pylot.plugins.opensearch.main import return_parser, OpenSearch
from pylot.plugins.opensearch.main import return_parser, OpenSearch, update_dictionary, thread_function, \
bulk_delete_cumulus, process_update_data, delete_cumulus, update_cumulus, query_opensearch


class TestOpenSearch(unittest.TestCase):
def tearDown(self) -> None:
    """Remove the lambda ARN env var so one test cannot leak state into another."""
    if 'OPENSEARCH_LAMBDA_ARN' in os.environ:
        del os.environ['OPENSEARCH_LAMBDA_ARN']

def test_return_parser(self):
parser = argparse.ArgumentParser(
usage='<plugin> -h to access help for each plugin. \n',
Expand All @@ -21,5 +26,64 @@ def test_read_json_file(self):
data = opensearch.read_json_file(f'{os.path.dirname(os.path.realpath(__file__))}/test_file.json')
self.assertEqual(data, {"some": "json"})

@patch('json.loads')
@patch('pylot.plugins.opensearch.main.OpenSearch')
def test_query_opensearch(self, mock_opensearch, mock_json_loads):
    """query_opensearch should run end-to-end with its collaborators mocked out."""
    # Fixed: the return_value assignment was duplicated and followed by a dead `pass`.
    mock_opensearch.invoke_opensearch_lambda.return_value = ''
    query_opensearch(query_data={}, record_type='')
    # The function must instantiate the (patched) OpenSearch class exactly once.
    mock_opensearch.assert_called_once()

def test_invoke_opensearch_lambda(self):
    """A 200 response from the injected lambda client should be returned unchanged."""
    opensearch = OpenSearch()
    os.environ['OPENSEARCH_LAMBDA_ARN'] = 'FAKE_ARN'
    mock_client = MagicMock()
    mock_client.invoke.return_value = {'StatusCode': 200}
    rsp = opensearch.invoke_opensearch_lambda(
        query_data={}, record_type='', terminate_after=0, lambda_client=mock_client
    )
    # Previously this test made no assertions at all and passed vacuously.
    mock_client.invoke.assert_called_once()
    self.assertEqual(rsp, {'StatusCode': 200})

def test_invoke_opensearch_lambda_arn_exception(self):
    """A missing OPENSEARCH_LAMBDA_ARN env var must raise ValueError."""
    # Defensive: make sure a previous test did not leave the ARN set.
    os.environ.pop('OPENSEARCH_LAMBDA_ARN', None)
    opensearch = OpenSearch()
    mock_client = MagicMock()
    with self.assertRaises(ValueError) as context:
        opensearch.invoke_opensearch_lambda(
            query_data={}, record_type='', terminate_after=0, lambda_client=mock_client
        )
    # Bug fix: the membership test must run against the exception's message string;
    # `'x' in context.exception` raises TypeError because exceptions are not containers.
    self.assertIn('The ARN for the OpenSearch lambda is not defined', str(context.exception))

def test_download_file(self):
    """download_file should delegate the transfer to the injected S3 client."""
    opensearch = OpenSearch()
    mock_s3 = MagicMock()
    opensearch.download_file(bucket='', key='', results='', s3_client=mock_s3)
    # Previously this test made no assertions; verify the delegation happened.
    mock_s3.download_file.assert_called_once()

def test_process_update_data(self):
    """Applying the fake_update.json fixture should overwrite testField in the results."""
    original = [{'productVolume': 1, 'testField': 'testValue'}]
    # NOTE(review): productVolume comes back as the string '1' — presumably
    # process_update_data stringifies values; confirm this is intentional.
    expected = [{'productVolume': '1', 'testField': 'testValueUpdated'}]
    self.assertEqual(process_update_data('/tests/fake_update.json', original), expected)

def test_update_dictionary(self):
    """update_dictionary must deep-merge nested keys and add new top-level keys."""
    target = {'key_1': {'key_2': 'value_1'}, 'key_3': 'value_3'}
    update = {'key_1': {'key_2': 'value_1_updated'}, 'key_4': 'value_4'}
    # Removed a stray debug print(target) left over from development.
    target = update_dictionary(target, update)
    expected = {'key_1': {'key_2': 'value_1_updated'}, 'key_3': 'value_3', 'key_4': 'value_4'}
    self.assertEqual(target, expected)

@patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance')
def test_bulk_delete_cumulus(self, mock_get_api):
    """bulk_delete_cumulus should run against the packaged fake_delete.json fixture."""
    query_results = [{'granuleId': 'fake_granuleId'}]
    bulk_delete_cumulus('/tests/fake_delete.json', query_results)

def test_thread_function(self):
    """thread_function should fan a builtin callable out over each item without error."""
    items = [1, 2, 3]
    thread_function(print, items)

@patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance')
def test_delete_cumulus(self, mock_get_api):
    """delete_cumulus should process a single-record result set without error."""
    delete_cumulus([{'productVolume': 1}])

@patch('pylot.plugins.helpers.pylot_helpers.PyLOTHelpers.get_cumulus_api_instance')
def test_update_cumulus(self, mock_get_api):
    """update_cumulus should process a single-record result set without error."""
    records = [{'record': 'value'}]
    update_cumulus('test', records)

104 changes: 47 additions & 57 deletions pylot/plugins/rds_lambda/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,73 +12,80 @@ def read_json_file(filename, **kwargs):

return data

@staticmethod
def query_rds(payload, results='query_results.json', **kwargs):
def invoke_rds_lambda(self, query_data, lambda_client=None, **kwargs):
if not lambda_client:
lambda_client = boto3.client('lambda')
lambda_arn = os.getenv('RDS_LAMBDA_ARN')
if not lambda_arn:
raise Exception('The ARN for the RDS lambda is not defined. Provide it as an environment variable.')
raise ValueError('The ARN for the RDS lambda is not defined. Provide it as an environment variable.')

# Invoke RDS lambda
print('Invoking RDS lambda...')
client = boto3.client('lambda')
rsp = client.invoke(
rsp = lambda_client.invoke(
FunctionName=lambda_arn,
Payload=json.dumps(payload).encode('utf-8')
Payload=json.dumps(query_data).encode('utf-8')
)
if rsp.get('StatusCode') != 200:
raise Exception(
f'The RDS lambda failed. Check the Cloudwatch logs for {os.getenv("RDS_LAMBDA_ARN")}'
)

# Download results from S3
ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8'))
print(f'Query matched {ret_dict.get("count")} records.')
print(f'Downloading results to file: {os.getcwd()}/{results}')
s3_client = boto3.client('s3')
return rsp

def download_file(self, bucket, key, results, s3_client=None):
if not s3_client:
s3_client = boto3.client('s3')
print('Downloading query results...')
s3_client.download_file(
Bucket=ret_dict.get('bucket'),
Key=ret_dict.get('key'),
Bucket=bucket,
Key=key,
Filename=f'{os.getcwd()}/{results}'
)
print(f'Deleting remote S3 results file: {ret_dict.get("bucket")}/{ret_dict.get("key")}')
s3_client.delete_object(
Bucket=ret_dict.get('bucket'),
Key=ret_dict.get('key')
)

file = f'{os.getcwd()}/{results}'

return file


def query_rds(query_data, results='query_results.json', **kwargs):
    """Run an RDS lambda query and download the results file from S3.

    Args:
        query_data: the rds_config query as a dict, a path to a JSON file,
            or a raw JSON string.
        results: local filename for the downloaded results (written to the CWD).

    Returns:
        The path of the downloaded results file.
    """
    rds = QueryRDS()
    if isinstance(query_data, str):
        # Accept either a filename or a raw JSON string.
        if os.path.isfile(query_data):
            query_data = rds.read_json_file(query_data)
        else:
            query_data = json.loads(query_data)
    # else: query_data is assumed to already be a parsed dict. Previously a dict
    # fell through to json.loads(dict) and raised TypeError.

    # NOTE(review): 'is_test' is hard-coded to True here — confirm this is
    # intended for production invocations of the RDS lambda.
    query_data = {'rds_config': query_data, 'is_test': True}

    rsp = rds.invoke_rds_lambda(query_data)
    ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8'))

    # Download results from S3
    file = rds.download_file(bucket=ret_dict.get('bucket'), key=ret_dict.get('key'), results=results)
    print(
        f'{ret_dict.get("count")} {query_data.get("rds_config").get("records")} records obtained: '
        f'{os.getcwd()}/{results}'
    )
    return file


def return_parser(subparsers):
query = {
"rds_config": {
"records": "granules",
"where": "name LIKE nalma% ",
"columns": ["granule_id", "status"],
"limit": 10
}
"records": "granules",
"where": "name LIKE nalma% ",
"columns": ["granule_id", "status"],
"limit": 10
}
subparser = subparsers.add_parser(
'rds_lambda',
help='This plugin is used to submit queries directly to the cumulus RDS bypassing the cumulus API.',
description='Submit queries to the Cumulus RDS instance.\n'
f'Example query: {json.dumps(query)}'
)
choices = ['granules', 'collections', 'providers', 'pdrs', 'rules', 'logs', 'executions', 'reconciliationReport']
choice_str = str(choices).strip('[').strip(']').replace("'", '')
subparser.add_argument(
'record_type',
help=f'The RDS table to be queried: {choice_str}',
metavar='record_type',
choices=choices
)
subparser.add_argument(
'-l', '--limit',
help='Limit the number of records returned from RDS. Default is 100. Use 0 to retrieve all matches.',
metavar='',
default=100
'query',
help='A file containing an RDS Lambda query: <filename>.json or a json query string '
'using the RDS DSL syntax: https://github.com/ghrcdaac/ghrc_rds_lambda?tab=readme-ov-file#querying',
metavar='query'
)
subparser.add_argument(
'-r', '--results',
Expand All @@ -87,26 +94,9 @@ def return_parser(subparsers):
default='query_results.json'
)

group = subparser.add_mutually_exclusive_group(required=True)
group.add_argument(
'-q', '--query',
help='The name of a file containing an RDS lambda query: <filename>.json',
metavar=''
)
group.add_argument(
'-s', '--query-string',
help='A json query string using the RDS DSL: '
'https://github.com/ghrcdaac/ghrc_rds_lambda#querying '
'Example: '
'\'{"rds_config": {"records": "", "columns": ["granule_id"], "where": "name=nalmaraw", "limit": 0}}\'',
metavar=''
)


def main(record_type, results=None, query=None, query_string=None, **kwargs):
rds = QueryRDS()
query = rds.read_json_file(query) if query else json.loads(query_string)
print(f'Using query: {json.dumps(query)}')
rds.query_rds(payload=query)
def main(query=None, records=None, **kwargs):
    """CLI entry point: run the RDS query and report completion.

    Forwards the remaining CLI options (e.g. results=<filename> from the
    -r/--results flag) to query_rds; previously **kwargs was accepted but
    silently dropped, so the results filename was ignored.

    Returns:
        0 on success.
    """
    query_rds(query_data=query, record_type=records, **kwargs)
    print('Complete')

    return 0
Empty file.
Loading

0 comments on commit a68f782

Please sign in to comment.