Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated stratifier behavior #129

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/handlers/dashboard/get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger.setLevel(log_level)


def _get_table_cols(dp_id: str, version: str | None = None) -> list:
def _get_table_cols(dp_id: str) -> list:
"""Returns the columns associated with a table.

Since running an athena query takes a decent amount of time due to queueing
Expand All @@ -26,11 +26,11 @@ def _get_table_cols(dp_id: str, version: str | None = None) -> list:
"""

s3_bucket_name = os.environ.get("BUCKET_NAME")
dp_name = dp_id.rsplit("__", 1)[0]
prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{dp_id.split('__')[0]}/{dp_name}"
study, name, version = dp_id.split("__")
prefix = f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/{study}__{name}"
if version is None:
version = functions.get_latest_data_package_version(s3_bucket_name, prefix)
s3_key = f"{prefix}/{version}/{dp_name}__aggregate.csv"
s3_key = f"{prefix}/{version}/{study}__{name}__aggregate.csv"
s3_client = boto3.client("s3")
try:
s3_iter = s3_client.get_object(
Expand All @@ -48,10 +48,11 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str:
columns = _get_table_cols(dp_id)
filter_str = filter_config.get_filter_string(filters)
if filter_str != "":
filter_str = f"AND {filter_str}"
filter_str = f"AND {filter_str} "
count_col = next(c for c in columns if c.startswith("cnt"))
columns.remove(count_col)
select_str = f"{query_params['column']}, sum({count_col}) as {count_col}"
strat_str = ""
group_str = f"{query_params['column']}"
# the 'if in' check is meant to handle the case where the selected column is also
# present in the filter logic and has already been removed
Expand All @@ -61,29 +62,33 @@ def _build_query(query_params: dict, filters: list, path_params: dict) -> str:
select_str = f"{query_params['stratifier']}, {select_str}"
group_str = f"{query_params['stratifier']}, {group_str}"
columns.remove(query_params["stratifier"])
strat_str = f'AND {query_params["stratifier"]} IS NOT NULL '
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not seeing any quoting of these query params? We should probably be escaping them, then adding double quotes (or verifying that they match [a-zA-Z]+), yeah? That's what the ruff warning S608 is about, which is quieted below, yeah.

Does not need to block this PR, but feels important.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I was just thinking about how gross this approach to SQL construction was - I may make a ticket for this rather than deal with it as part of the PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if len(columns) > 0:
coalesce_str = (
f"WHERE COALESCE (cast({' AS VARCHAR), cast('.join(columns)} AS VARCHAR)) "
"IS NOT NULL AND"
"IS NOT NULL AND "
)
else:
coalesce_str = "WHERE"
coalesce_str = "WHERE "
query_str = (
f"SELECT {select_str} " # nosec # noqa: S608
f"FROM \"{os.environ.get('GLUE_DB_NAME')}\".\"{dp_id}\" "
f"{coalesce_str} "
f"{query_params['column']} IS NOT NULL {filter_str} "
f"{coalesce_str}"
f"{query_params['column']} IS NOT NULL "
f"{filter_str}"
f"{strat_str}"
f"GROUP BY {group_str} "
)
if "stratifier" in query_params.keys():
query_str += f"ORDER BY {query_params['stratifier']}, {query_params['column']}"
else:
query_str += f"ORDER BY {query_params['column']}"
logging.debug(query_str)
return query_str
return query_str, count_col


def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) -> dict:
def _format_payload(
df: pandas.DataFrame, query_params: dict, filters: list, count_col: str
) -> dict:
"""Coerces query results into the return format defined by the dashboard"""
payload = {}
payload["column"] = query_params["column"]
Expand All @@ -92,13 +97,22 @@ def _format_payload(df: pandas.DataFrame, query_params: dict, filters: list) ->
payload["totalCount"] = int(df["cnt"].sum())
if "stratifier" in query_params.keys():
payload["stratifier"] = query_params["stratifier"]
counts = {}
for unique_val in df[query_params["column"]].unique():
column_mask = df[query_params["column"]] == unique_val
df_slice = df[column_mask]
df_slice = df_slice.drop(columns=[query_params["stratifier"], query_params["column"]])
counts[unique_val] = int(df_slice[count_col].sum())
payload["counts"] = counts
data = []
for unique_val in df[query_params["stratifier"]]:
df_slice = df[df[query_params["stratifier"]] == unique_val]
for unique_strat in df[query_params["stratifier"]].unique():
strat_mask = df[query_params["stratifier"]] == unique_strat
df_slice = df[strat_mask]
df_slice = df_slice.drop(columns=[query_params["stratifier"]])
rows = df_slice.values.tolist()
data.append({"stratifier": unique_val, "rows": rows})
data.append({"stratifier": unique_strat, "rows": rows})
payload["data"] = data

else:
rows = df.values.tolist()
payload["data"] = [{"rows": rows}]
Expand All @@ -112,17 +126,19 @@ def chart_data_handler(event, context):
del context
query_params = event["queryStringParameters"]
filters = event["multiValueQueryStringParameters"].get("filter", [])
if "filter" in query_params and filters == []:
filters = [query_params["filter"]]
Comment on lines +129 to +130
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the filters == [] check (which could be not filters I think)? Do we want to ignore a filter query entirely if multiValueQueryStringParameters is set?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in experimenting with this - I can't explicitly say in cloudformation 'this parameter can be an array', so I'm not sure if it's going to be in queryStringParameters or multiValueQueryStringParameters in the AWS inbound event to the lambda. so I need to check both.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm not 100% sure of the expected usage here, but maybe combine the two arrays? It just felt odd to me that you'd ignore the filter query param in the case you get both.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we got both (which I don't think should happen), queryStringParameters would have to be a subset of multiValueQueryStringParameters

path_params = event["pathParameters"]
boto3.setup_default_session(region_name="us-east-1")
try:
query = _build_query(query_params, filters, path_params)
query, count_col = _build_query(query_params, filters, path_params)
df = awswrangler.athena.read_sql_query(
query,
database=os.environ.get("GLUE_DB_NAME"),
s3_output=f"s3://{os.environ.get('BUCKET_NAME')}/awswrangler",
workgroup=os.environ.get("WORKGROUP_NAME"),
)
res = _format_payload(df, query_params, filters)
res = _format_payload(df, query_params, filters, count_col)
res = functions.http_response(200, res)
except errors.AggregatorS3Error:
# while the API is publicly accessible, we've been asked to not pass
Expand Down
2 changes: 2 additions & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ Resources:
Required: true
- method.request.querystring.filters:
Required: false
- method.request.querystring.stratifier:
Required: false
Policies:
- S3CrudPolicy:
BucketName: !Ref AggregatorBucket
Expand Down
23 changes: 14 additions & 9 deletions tests/dashboard/test_get_chart_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ def mock_data_frame(filter_param):
[],
{"data_package_id": "test_study"},
f'SELECT gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL "
"WHERE COALESCE (cast(race AS VARCHAR)) IS NOT NULL AND gender IS NOT NULL "
"GROUP BY gender ORDER BY gender",
),
(
{"column": "gender", "stratifier": "race"},
[],
{"data_package_id": "test_study"},
f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE gender IS NOT NULL "
"WHERE gender IS NOT NULL "
"AND race IS NOT NULL "
"GROUP BY race, gender ORDER BY race, gender",
),
(
Expand All @@ -63,12 +64,13 @@ def mock_data_frame(filter_param):
f'SELECT race, gender, sum(cnt) as cnt FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE gender IS NOT NULL "
"AND gender LIKE 'female' "
"AND race IS NOT NULL "
"GROUP BY race, gender ORDER BY race, gender",
),
],
)
def test_build_query(query_params, filters, path_params, query_str):
query = get_chart_data._build_query(query_params, filters, path_params)
query, _ = get_chart_data._build_query(query_params, filters, path_params)
assert query == query_str


Expand Down Expand Up @@ -99,7 +101,7 @@ def test_build_query(query_params, filters, path_params, query_str):
)
def test_format_payload(query_params, filters, expected_payload):
df = mock_data_frame(filters)
payload = get_chart_data._format_payload(df, query_params, filters)
payload = get_chart_data._format_payload(df, query_params, filters, "cnt")
assert payload == expected_payload


Expand All @@ -113,11 +115,14 @@ def test_get_data_cols(mock_bucket):
@mock.patch(
"src.handlers.dashboard.get_chart_data._build_query",
lambda query_params, filters, path_params: (
"SELECT gender, sum(cnt) as cnt"
f'FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (race) IS NOT Null AND gender IS NOT Null "
"AND gender LIKE 'female' "
"GROUP BY gender",
(
"SELECT gender, sum(cnt) as cnt"
f'FROM "{TEST_GLUE_DB}"."test_study" '
"WHERE COALESCE (race) IS NOT NULL AND gender IS NOT NULL "
"AND gender LIKE 'female' "
"GROUP BY gender",
"cnt",
)
),
)
@mock.patch(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_data/cube_response_filtered_stratified.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"rowCount": 5,
"totalCount": 33839,
"stratifier": "race",
"counts": {
"female": 33839
},
"data": [
{
"stratifier": "",
Expand Down
Loading
Loading