From 5e32cbebac793dddb75b1211e13b7d9baa783f58 Mon Sep 17 00:00:00 2001 From: Mike Heddes Date: Tue, 5 Mar 2024 20:56:41 -0800 Subject: [PATCH] Simplify google drive download --- torchhd/datasets/utils.py | 85 ++++----------------------------------- 1 file changed, 7 insertions(+), 78 deletions(-) diff --git a/torchhd/datasets/utils.py b/torchhd/datasets/utils.py index 05898e84..71cd97f9 100644 --- a/torchhd/datasets/utils.py +++ b/torchhd/datasets/utils.py @@ -23,13 +23,8 @@ # import zipfile import requests -import re import tqdm -# Code adapted from: -# https://github.com/wkentaro/gdown/blob/941200a9a1f4fd7ab903fb595baa5cad34a30a45/gdown/download.py -# https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url - def download_file(url, destination): response = requests.get(url, allow_redirects=True, stream=True) @@ -37,79 +32,13 @@ def download_file(url, destination): def download_file_from_google_drive(file_id, destination): - URL = "https://docs.google.com/uc" - params = dict(id=file_id, export="download") - - with requests.Session() as session: - response = session.get(URL, params=params, stream=True) - - # downloads right away - if "Content-Disposition" in response.headers: - write_response_to_disk(response, destination) - return - - # try to find a confirmation token - token = get_google_drive_confirm_token(response) - - if token: - params = dict(id=id, confirm=token) - response = session.get(URL, params=params, stream=True) - - # download if confirmation token worked - if "Content-Disposition" in response.headers: - write_response_to_disk(response, destination) - return - - # extract download url from confirmation page - url = get_url_from_gdrive_confirmation(response.text) - response = session.get(url, stream=True) - - write_response_to_disk(response, destination) - - -def get_google_drive_confirm_token(response): - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - return value - - 
return None - - -def get_url_from_gdrive_confirmation(contents): - url = "" - for line in contents.splitlines(): - m = re.search(r'href="(\/uc\?export=download[^"]+)', line) - if m: - url = "https://docs.google.com" + m.groups()[0] - url = url.replace("&amp;", "&") - break - m = re.search('id="downloadForm" action="(.+?)"', line) - if m: - url = m.groups()[0] - url = url.replace("&amp;", "&") - break - m = re.search('id="download-form" action="(.+?)"', line) - if m: - url = m.groups()[0] - url = url.replace("&amp;", "&") - break - m = re.search('"downloadUrl":"([^"]+)', line) - if m: - url = m.groups()[0] - url = url.replace("\\u003d", "=") - url = url.replace("\\u0026", "&") - break - m = re.search('<p class="uc-error-subcaption">(.*)</p>', line) - if m: - error = m.groups()[0] - raise RuntimeError(error) - if not url: - raise RuntimeError( - "Cannot retrieve the public link of the file. " - "You may need to change the permission to " - "'Anyone with the link', or have had many accesses." - ) - return url + try: + import gdown + except ImportError: + raise ImportError("Downloading files from Google drive requires gdown to be installed, see: https://github.com/wkentaro/gdown") + + url = f"https://drive.google.com/uc?id={file_id}" + gdown.download(url, destination) def get_download_progress_bar(response):