From 408539d4f64ea06f72560bf051b21b16b2223458 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Wed, 16 Oct 2024 15:22:32 -0400 Subject: [PATCH 1/2] Updated stratifier behavior --- src/handlers/dashboard/get_chart_data.py | 40 ++-- template.yaml | 2 + tests/dashboard/test_get_chart_data.py | 25 ++- .../cube_response_filtered_stratified.json | 3 + tests/test_data/cube_response_stratified.json | 175 +----------------- 5 files changed, 53 insertions(+), 192 deletions(-) diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py index 1c835e2..c98ea07 100644 --- a/src/handlers/dashboard/get_chart_data.py +++ b/src/handlers/dashboard/get_chart_data.py @@ -48,10 +48,11 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str: columns = _get_table_cols(dp_id) filter_str = filter_config.get_filter_string(filters) if filter_str != "": - filter_str = f"AND {filter_str}" + filter_str = f"AND {filter_str} " count_col = next(c for c in columns if c.startswith("cnt")) columns.remove(count_col) select_str = f"{query_params['column']}, sum({count_col}) as {count_col}" + strat_str = "" group_str = f"{query_params['column']}" # the 'if in' check is meant to handle the case where the selected column is also # present in the filter logic and has already been removed @@ -61,29 +62,33 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str: select_str = f"{query_params['stratifier']}, {select_str}" group_str = f"{query_params['stratifier']}, {group_str}" columns.remove(query_params["stratifier"]) + strat_str = f'AND {query_params["stratifier"]} IS NOT NULL ' if len(columns) > 0: coalesce_str = ( f"WHERE COALESCE (cast({' AS VARCHAR), cast('.join(columns)} AS VARCHAR)) " - "IS NOT NULL AND" + "IS NOT NULL AND " ) else: - coalesce_str = "WHERE" + coalesce_str = "WHERE " query_str = ( f"SELECT {select_str} " # nosec # noqa: S608 f"FROM \"{os.environ.get('GLUE_DB_NAME')}\".\"{dp_id}\" " - f"{coalesce_str} " - f"{query_params['column']} IS NOT NULL {filter_str} " + f"{coalesce_str}" + f"{query_params['column']} IS NOT NULL " + f"{filter_str}" + f"{strat_str}" f"GROUP BY {group_str} " ) if "stratifier" in query_params.keys(): query_str += f"ORDER BY {query_params['stratifier']}, {query_params['column']}" else: query_str += f"ORDER BY {query_params['column']}" - logging.debug(query_str) - return query_str + return query_str, count_col -def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) -> dict: +def _format_payload( + df: pandas.DataFrame, query_params: dict, filters: list, count_col: str +) -> dict: """Coerces query results into the return format defined by the dashboard""" payload = {} payload["column"] = query_params["column"] @@ -92,13 +97,20 @@ def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) -> payload["totalCount"] = int(df["cnt"].sum()) if "stratifier" in query_params.keys(): payload["stratifier"] = query_params["stratifier"] + counts = {} + for unique_val in df[query_params["column"]]: + df_slice = df[df[query_params["column"]] == unique_val] + df_slice = df_slice.drop(columns=[query_params["stratifier"], query_params["column"]]) + counts[unique_val] = int(df_slice[count_col].sum()) + payload["counts"] = counts data = [] - for unique_val in df[query_params["stratifier"]]: - df_slice = df[df[query_params["stratifier"]] == unique_val] + for unique_strat in df[query_params["stratifier"]].unique(): + df_slice = df[df[query_params["stratifier"]] == unique_strat] df_slice = df_slice.drop(columns=[query_params["stratifier"]]) rows = df_slice.values.tolist() - data.append({"stratifier": unique_val, "rows": rows}) + data.append({"stratifier": unique_strat, "rows": rows}) payload["data"] = data + else: rows = df.values.tolist() payload["data"] = [{"rows": rows}] @@ -112,17 +124,19 @@ def chart_data_handler(event, context): del context query_params = event["queryStringParameters"] filters = event["multiValueQueryStringParameters"].get("filter", []) + if "filter" in query_params and filters == []: + filters = [query_params["filter"]] path_params = event["pathParameters"] boto3.setup_default_session(region_name="us-east-1") try: - query = _build_query(query_params, filters, path_params) + query, count_col = _build_query(query_params, filters, path_params) df = awswrangler.athena.read_sql_query( query, database=os.environ.get("GLUE_DB_NAME"), s3_output=f"s3://{os.environ.get('BUCKET_NAME')}/awswrangler", workgroup=os.environ.get("WORKGROUP_NAME"), ) - res = _format_payload(df, query_params, filters) + res = _format_payload(df, query_params, filters, count_col) res = functions.http_response(200, res) except errors.AggregatorS3Error: # while the API is publicly accessible, we've been asked to not pass diff --git a/template.yaml b/template.yaml index 3d8c21e..2feeb5b 100644 --- a/template.yaml +++ b/template.yaml @@ -357,6 +357,8 @@ Resources: Required: true - method.request.querystring.filters: Required: false + - method.request.querystring.stratifier: + Required: false Policies: - S3CrudPolicy: BucketName: !Ref AggregatorBucket diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py index 6517120..c69314c 100644 --- a/tests/dashboard/test_get_chart_data.py +++ b/tests/dashboard/test_get_chart_data.py @@ -36,7 +36,7 @@ def mock_data_frame(filter_param): [], {"data_package_id": "test_study"}, f'SELECT gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" ' - "WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL " + "WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL " "GROUP BY gender ORDER BY gender", ), ( @@ -44,7 +44,8 @@ def mock_data_frame(filter_param): [], {"data_package_id": "test_study"}, f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" ' - "WHERE gender IS NOT NULL " + "WHERE gender IS NOT NULL " + "AND race IS NOT NULL " "GROUP BY race, gender ORDER BY race, gender", ), ( @@ -63,12 +64,13 @@ def mock_data_frame(filter_param): f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" ' "WHERE gender IS NOT NULL " "AND gender LIKE 'female' " + "AND race IS NOT NULL " "GROUP BY race, gender ORDER BY race, gender", ), ], ) def test_build_query(query_params, filters, path_params, query_str): - query = get_chart_data._build_query(query_params, filters, path_params) + query, _ = get_chart_data._build_query(query_params, filters, path_params) assert query == query_str @@ -99,7 +101,9 @@ def test_build_query(query_params, filters, path_params, query_str): ) def test_format_payload(query_params, filters, expected_payload): df = mock_data_frame(filters) - payload = get_chart_data._format_payload(df, query_params, filters) + payload = get_chart_data._format_payload(df, query_params, filters, "cnt") + print(query_params) + print(payload) assert payload == expected_payload @@ -113,11 +117,14 @@ def test_get_data_cols(mock_bucket): @mock.patch( "src.handlers.dashboard.get_chart_data._build_query", lambda query_params, filters, path_params: ( - "SELECT gender, sum(cnt) as cnt" - f'FROM "{TEST_GLUE_DB}"."test_study" ' - "WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null " - "AND gender LIKE 'female' " - "GROUP BY gender", + ( + "SELECT gender, sum(cnt) as cnt" + f'FROM "{TEST_GLUE_DB}"."test_study" ' + "WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null " + "AND gender LIKE 'female' " + "GROUP BY gender", + "cnt", + ) ), ) @mock.patch( diff --git a/tests/test_data/cube_response_filtered_stratified.json b/tests/test_data/cube_response_filtered_stratified.json index 1693d17..bf1d105 100644 --- a/tests/test_data/cube_response_filtered_stratified.json +++ b/tests/test_data/cube_response_filtered_stratified.json @@ -6,6 +6,9 @@ "rowCount": 5, "totalCount": 33839, "stratifier": "race", + "counts": { + "female": 33839 + }, "data": [ { "stratifier": "", diff --git a/tests/test_data/cube_response_stratified.json b/tests/test_data/cube_response_stratified.json index ea46771..93eabd8 100644 --- a/tests/test_data/cube_response_stratified.json +++ b/tests/test_data/cube_response_stratified.json @@ -4,6 +4,11 @@ "rowCount": 15, "totalCount": 151974, "stratifier": "race", + "counts": { + "": 75987, + "male": 42148, + "female": 33839 + }, "data": [ { "stratifier": "", @@ -39,108 +44,6 @@ ] ] }, - { - "stratifier": "", - "rows": [ - [ - 37990, - "" - ], - [ - 21073, - "male" - ], - [ - 16917, - "female" - ] - ] - }, - { - "stratifier": "White", - "rows": [ - [ - 32392, - "" - ], - [ - 18287, - "male" - ], - [ - 14105, - "female" - ] - ] - }, - { - "stratifier": "", - "rows": [ - [ - 37990, - "" - ], - [ - 21073, - "male" - ], - [ - 16917, - "female" - ] - ] - }, - { - "stratifier": "White", - "rows": [ - [ - 32392, - "" - ], - [ - 18287, - "male" - ], - [ - 14105, - "female" - ] - ] - }, - { - "stratifier": "Black or African American", - "rows": [ - [ - 3870, - "" - ], - [ - 2055, - "female" - ], - [ - 1815, - "male" - ] - ] - }, - { - "stratifier": "Black or African American", - "rows": [ - [ - 3870, - "" - ], - [ - 2055, - "female" - ], - [ - 1815, - "male" - ] - ] - }, { "stratifier": "Black or African American", "rows": [ @@ -175,74 +78,6 @@ ] ] }, - { - "stratifier": "Asian", - "rows": [ - [ - 1675, - "" - ], - [ - 940, - "male" - ], - [ - 735, - "female" - ] - ] - }, - { - "stratifier": "Asian", - "rows": [ - [ - 1675, - "" - ], - [ - 940, - "male" - ], - [ - 735, - "female" - ] - ] - }, - { - "stratifier": "American Indian or Alaska Native", - "rows": [ - [ - 60, - "" - ], - [ - 33, - "male" - ], - [ - 27, - "female" - ] - ] - }, - { - "stratifier": "American Indian or Alaska Native", - "rows": [ - [ - 60, - "" - ], - [ - 33, - "male" - ], - [ - 27, - "female" - ] - ] - }, { "stratifier": "American Indian or Alaska Native", "rows": [ From 39e236ed86755bb7da775f8c8f042f7d3baac194 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Thu, 17 Oct 2024 14:33:05 -0400 Subject: [PATCH 2/2] PR feedback --- src/handlers/dashboard/get_chart_data.py | 16 +++++++++------- tests/dashboard/test_get_chart_data.py | 4 +--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/handlers/dashboard/get_chart_data.py b/src/handlers/dashboard/get_chart_data.py index c98ea07..a68ba2b 100644 --- a/src/handlers/dashboard/get_chart_data.py +++ b/src/handlers/dashboard/get_chart_data.py @@ -17,7 +17,7 @@ logger.setLevel(log_level) -def _get_table_cols(dp_id: str, version: str | None = None) -> list: +def _get_table_cols(dp_id: str) -> list: """Returns the columns associated with a table. Since running an athena query takes a decent amount of time due to queueing @@ -26,11 +26,11 @@ def _get_table_cols(dp_id: str, version: str | None = None) -> list: """ s3_bucket_name = os.environ.get("BUCKET_NAME") - dp_name = dp_id.rsplit("__", 1)[0] - prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{dp_id.split('__')[0]}/{dp_name}" + study, name, version = dp_id.split("__") + prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/{study}__{name}" if version is None: version = functions.get_latest_data_package_version(s3_bucket_name, prefix) - s3_key = f"{prefix}/{version}/{dp_name}__aggregate.csv" + s3_key = f"{prefix}/{version}/{study}__{name}__aggregate.csv" s3_client = boto3.client("s3") try: s3_iter = s3_client.get_object( @@ -98,14 +98,16 @@ def _format_payload( if "stratifier" in query_params.keys(): payload["stratifier"] = query_params["stratifier"] counts = {} - for unique_val in df[query_params["column"]]: - df_slice = df[df[query_params["column"]] == unique_val] + for unique_val in df[query_params["column"]].unique(): + column_mask = df[query_params["column"]] == unique_val + df_slice = df[column_mask] df_slice = df_slice.drop(columns=[query_params["stratifier"], query_params["column"]]) counts[unique_val] = int(df_slice[count_col].sum()) payload["counts"] = counts data = [] for unique_strat in df[query_params["stratifier"]].unique(): - df_slice = df[df[query_params["stratifier"]] == unique_strat] + strat_mask = df[query_params["stratifier"]] == unique_strat + df_slice = df[strat_mask] df_slice = df_slice.drop(columns=[query_params["stratifier"]]) rows = df_slice.values.tolist() data.append({"stratifier": unique_strat, "rows": rows}) diff --git a/tests/dashboard/test_get_chart_data.py b/tests/dashboard/test_get_chart_data.py index c69314c..0f6c9c5 100644 --- a/tests/dashboard/test_get_chart_data.py +++ b/tests/dashboard/test_get_chart_data.py @@ -102,8 +102,6 @@ def test_build_query(query_params, filters, path_params, query_str): def test_format_payload(query_params, filters, expected_payload): df = mock_data_frame(filters) payload = get_chart_data._format_payload(df, query_params, filters, "cnt") - print(query_params) - print(payload) assert payload == expected_payload @@ -120,7 +118,7 @@ def test_get_data_cols(mock_bucket): ( "SELECT gender, sum(cnt) as cnt" f'FROM "{TEST_GLUE_DB}"."test_study" ' - "WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null " + "WHERE COALESCE (race) IS NOT NULL AND gender IS NOT NULL " "AND gender LIKE 'female' " "GROUP BY gender", "cnt",