Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added study results download script #135

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ auth.json
metadata.json
dashboard_api.yaml
.DS_Store
*.zip

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies= [
"awswrangler >=3.5, <4",
"boto3",
"pandas >=2, <3",
"requests", # scripts only
"rich",
]
authors = [
Expand Down
65 changes: 65 additions & 0 deletions scripts/bulk_csv_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import argparse
import io
import os
import pathlib
import sys
import zipfile

import requests
from rich.progress import track


def bulk_csv_download(args):
if args.api_key is None:
args.api_key = os.getenv("CUMULUS_AGGREGATOR_API_KEY")
args.type = args.type.replace("_", "-")
if args.type not in ["last-valid", "aggregates"]:
sys.exit('Invalid type. Expected "last-valid" or "aggregates"')
dp_url = f"https://{args.domain}/{args.type}"
try:
res = requests.get(dp_url, headers={"x-api-key": args.api_key}, timeout=300)
except requests.exceptions.ConnectionError:
sys.exit("Invalid domain name")
if res.status_code == 403:
sys.exit("Invalid API key")
file_urls = res.json()
urls = []
version = 0
for file_url in file_urls:
file_array = file_url.split("/")
dp_version = int(file_array[4 if args.type == "last-valid" else 3])
if file_array[1] == args.study:
if dp_version > version:
version = int(dp_version)
urls = []
elif int(dp_version) == version:
if (
args.type == "last-valid" and args.site == file_array[3]
) or args.type == "aggregates":
urls.append(file_url)
if len(urls) == 0:
sys.exit(f"No aggregates matching {args.study} found")
archive = io.BytesIO()
with zipfile.ZipFile(archive, "w") as zip_archive:
for file in track(urls, description=f"Downloading {args.study} aggregates"):
csv_url = f"https://{args.domain}/{file}"
res = requests.get(
csv_url, headers={"x-api-key": args.api_key}, allow_redirects=True, timeout=300
)
with zip_archive.open(file.split("/")[-1], "w") as f:
f.write(bytes(res.text, "UTF-8"))
with open(pathlib.Path.cwd() / f"{args.study}.zip", "wb") as output:
output.write(archive.getbuffer())


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="""Fetches all data for a given study""")
parser.add_argument("-s", "--study", help="Name of study to download", required=True)
parser.add_argument("-i", "--site", help="Name of site to download (last-valid only)")
parser.add_argument(
"-d", "--domain", help="Domain of aggregator", default="api.smartcumulus.org"
)
parser.add_argument("-t", "--type", help="type of aggregate", default="last-valid")
parser.add_argument("-a", "--apikey", dest="api_key", help="API key of aggregator")
args = parser.parse_args()
bulk_csv_download(args)
6 changes: 3 additions & 3 deletions src/handlers/dashboard/get_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_csv_list_handler(event, context):
s3_client = boto3.client("s3")
if event["path"].startswith("/last-valid"):
key_prefix = "last_valid"
url_prefix = "last_valid"
url_prefix = "last-valid"
elif event["path"].startswith("/aggregates"):
key_prefix = "csv_aggregates"
url_prefix = "aggregates"
Expand All @@ -104,9 +104,9 @@ def get_csv_list_handler(event, context):
data_package = key_parts[2].split("__")[1]
version = key_parts[-2]
filename = key_parts[-1]
site = key_parts[3] if url_prefix == "last_valid" else None
site = key_parts[3] if url_prefix == "last-valid" else None
url_parts = [url_prefix, study, data_package, version, filename]
if url_prefix == "last_valid":
if url_prefix == "last-valid":
url_parts.insert(3, site)
urls.append("/".join(url_parts))
if not s3_objs["IsTruncated"]:
Expand Down
2 changes: 1 addition & 1 deletion tests/dashboard/test_get_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_get_csv(mock_bucket, params, status, expected):
"/last-valid",
200,
[
"last_valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv"
"last-valid/study/encounter/princeton_plainsboro_teaching_hospital/099/study__encounter__aggregate.csv"
],
does_not_raise(),
),
Expand Down
Loading