Skip to content

Commit

Permalink
Merge pull request #39 from ghrcdaac/mlh0079-rds-results-file-fix
Browse files Browse the repository at this point in the history
 - Result filename is now settable.
  • Loading branch information
camposeddie authored Dec 12, 2023
2 parents 052c12e + f58381f commit 731a6d2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
19 changes: 9 additions & 10 deletions pylot/plugins/rds_lambda/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,22 @@ def download_file(self, bucket, key, results, s3_client=None):
return file


def query_rds(query_data, results='query_results.json', **kwargs):
def query_rds(query, results='query_results.json', **kwargs):
rds = QueryRDS()
if isinstance(query_data, str) and os.path.isfile(query_data):
query_data = rds.read_json_file(query_data)
if isinstance(query, str) and os.path.isfile(query):
query = rds.read_json_file(query)
else:
query_data = json.loads(query_data)
query = json.loads(query)

query_data = {'rds_config': query_data, 'is_test': True}
query = {'rds_config': query, 'is_test': True}


rsp = rds.invoke_rds_lambda(query_data)
rsp = rds.invoke_rds_lambda(query)
ret_dict = json.loads(rsp.get('Payload').read().decode('utf-8'))

# Download results from S3
file = rds.download_file(bucket=ret_dict.get('bucket'), key=ret_dict.get('key'), results=results)
print(
f'{ret_dict.get("count")} {query_data.get("rds_config").get("records")} records obtained: '
f'{ret_dict.get("count")} {query.get("rds_config").get("records")} records obtained: '
f'{os.getcwd()}/{results}'
)
return file
Expand Down Expand Up @@ -95,8 +94,8 @@ def return_parser(subparsers):
)


def main(query=None, records=None, **kwargs):
query_rds(query_data=query, record_type=records)
def main(**kwargs):
query_rds(**kwargs)
print('Complete')

return 0
2 changes: 1 addition & 1 deletion pylot/plugins/rds_lambda/tests/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_read_json_file(self):
def test_query_rds(self, mock_opensearch, mock_json_loads):
mock_opensearch.invoke_rds_lambda.return_value = ''
mock_opensearch.invoke_rds_lambda.return_value = ''
query_rds(query_data={}, record_type='')
query_rds(query={}, record_type='')
pass

def test_invoke_rds_lambda(self):
Expand Down

0 comments on commit 731a6d2

Please sign in to comment.