diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index ed91340..ac27747 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -42,3 +42,11 @@ jobs:
         if: success() || failure() # still run black if above checks fails
         run: |
           black --check --verbose .
+  coverage:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@v3
+      - name: Code coverage-reporter
+        uses: tj-actions/coverage-reporter@v5.1
+        with:
+          coverage-command: 'python -m coverage report'
diff --git a/pyproject.toml b/pyproject.toml
index b16a4fb..fc3e1ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,13 +46,22 @@ test = [
 dev = [
     "bandit",
     "black==22.12.0",
+    "coverage >=7.3.1",
     "isort==5.12.0",
     "pre-commit",
     "pylint",
     "pycodestyle"
 ]
 
+[tool.coverage.run]
+command_line="-m pytest"
+source=["./src/"]
+
+[tool.coverage.report]
+show_missing=true
+
 [tool.isort]
 profile = "black"
 src_paths = ["src", "tests"]
 skip_glob = [".aws_sam"]
+
diff --git a/scripts/migrations/migration.001.transaction_cleanup.py b/scripts/migrations/migration.001.transaction_cleanup.py
index 2f479f9..b40afbf 100644
--- a/scripts/migrations/migration.001.transaction_cleanup.py
+++ b/scripts/migrations/migration.001.transaction_cleanup.py
@@ -76,9 +76,9 @@ def transaction_cleanup(bucket: str):
                     "transaction_format_version"
                 ] = new_t[site][study][dp][version]["transacton_format_version"]
                 new_t[site][study][dp][version].pop("transacton_format_version")
-    print(json.dumps(new_t, indent=2))
-    # _put_s3_data("metadata/transactions.json", bucket, client, new_t)
-    print("transactions.json updated")
+    # print(json.dumps(new_t, indent=2))
+    _put_s3_data("metadata/transactions.json", bucket, client, new_t)
+    print("transactions.json updated")
 
 
 if __name__ == "__main__":
diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py
index 1ec7f62..485b138 100644
--- a/src/handlers/dashboard/get_chart_data.py
+++ b/src/handlers/dashboard/get_chart_data.py
@@ -10,21 +10,23 @@ from src.handlers.dashboard.filter_config import get_filter_string
 from src.handlers.shared.decorators import generic_error_handler
 from src.handlers.shared.enums import BucketPath
-from src.handlers.shared.functions import http_response
+from src.handlers.shared.functions import get_latest_data_package_version, http_response
 
 
-def _get_table_cols(table_name: str) -> list:
+def _get_table_cols(table_name: str, version: str = None) -> list:
     """Returns the columns associated with a table.
 
     Since running an athena query takes a decent amount of time due to queueing
     a query with the execution engine, and we already have this data at the top
     of a CSV, we're getting table cols directly from S3 for speed reasons.
     """
+    s3_bucket_name = os.environ.get("BUCKET_NAME")
-    s3_key = (
-        f"{BucketPath.CSVAGGREGATE.value}/{table_name.split('__')[0]}"
-        f"/{table_name}/{table_name}__aggregate.csv"
-    )
+    prefix = f"{BucketPath.CSVAGGREGATE.value}/{table_name.split('__')[0]}/{table_name}"
+    if version is None:
+        version = get_latest_data_package_version(s3_bucket_name, prefix)
+    print(f"{prefix}/{version}/{table_name}__aggregate.csv")
+    s3_key = f"{prefix}/{version}/{table_name}__aggregate.csv"
     s3_client = boto3.client("s3")
     s3_iter = s3_client.get_object(
         Bucket=s3_bucket_name, Key=s3_key  # type: ignore[arg-type]
diff --git a/src/handlers/shared/functions.py b/src/handlers/shared/functions.py
index be3fe11..4f06ba0 100644
--- a/src/handlers/shared/functions.py
+++ b/src/handlers/shared/functions.py
@@ -161,3 +161,21 @@ def get_s3_json_as_dict(bucket, key: str):
         Fileobj=bytes_buffer,
     )
     return json.loads(bytes_buffer.getvalue().decode())
+
+
+def get_latest_data_package_version(bucket, prefix):
+    """Returns the newest version in a data package folder"""
+    s3_client = boto3.client("s3")
+    if not prefix.endswith("/"):
+        prefix = prefix + "/"
+    s3_res = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
+    highest_ver = None
+    for item in s3_res["Contents"]:
+        ver_str = item["Key"].replace(prefix, "").split("/")[0]
+        if ver_str.isdigit():
+            if highest_ver is None:
+                highest_ver = ver_str
+            else:
+                if int(highest_ver) < int(ver_str):
+                    highest_ver = ver_str
+    return highest_ver
diff --git a/src/handlers/site_upload/api_gateway_authorizer.py b/src/handlers/site_upload/api_gateway_authorizer.py
index 0e113de..809dccd 100644
--- a/src/handlers/site_upload/api_gateway_authorizer.py
+++ b/src/handlers/site_upload/api_gateway_authorizer.py
@@ -66,7 +66,7 @@ class HttpVerb:
     ALL = "*"
 
 
-class AuthPolicy(object):  # pylint: disable=missing-class-docstring
+class AuthPolicy(object):  # pylint: disable=missing-class-docstring; # pragma: no cover
     awsAccountId = ""
     """The AWS account id the policy will be generated for. This
     is used to create the method ARNs."""
diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py
index 782729a..6ccb89b 100644
--- a/tests/dashboard/test_get_chart_data.py
+++ b/tests/dashboard/test_get_chart_data.py
@@ -2,11 +2,20 @@ import os
 from unittest import mock
 
+import boto3
 import pandas
 import pytest
 
 from src.handlers.dashboard import get_chart_data
-from tests.utils import MOCK_ENV, TEST_BUCKET, TEST_GLUE_DB, TEST_WORKGROUP
+from tests.utils import (
+    EXISTING_DATA_P,
+    EXISTING_STUDY,
+    EXISTING_VERSION,
+    MOCK_ENV,
+    TEST_BUCKET,
+    TEST_GLUE_DB,
+    TEST_WORKGROUP,
+)
 
 
 def mock_get_table_cols(name):
@@ -97,3 +106,12 @@ def test_format_payload(query_params, filters, expected_payload):
     df = mock_data_frame(filters)
     payload = get_chart_data._format_payload(df, query_params, filters)
     assert payload == expected_payload
+
+
+def test_get_data_cols(mock_bucket):
+    s3_client = boto3.client("s3", region_name="us-east-1")
+    s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET)
+    table_name = f"{EXISTING_STUDY}__{EXISTING_DATA_P}"
+    res = get_chart_data._get_table_cols(table_name)
+    cols = pandas.read_csv("./tests/test_data/count_synthea_patient_agg.csv").columns
+    assert res == list(cols)