Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using pydrive with user credentials for authenticated download #3

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions download_ffhq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Download Flickr-Faces-HQ (FFHQ) dataset to current working directory."""

import os
import re
import sys
import requests
import html
Expand All @@ -27,6 +28,8 @@
import itertools
import shutil
from collections import OrderedDict, defaultdict
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error

Expand Down Expand Up @@ -130,6 +133,50 @@ def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10):
except:
pass

def pydrive_create_drive_manager(cmd_auth):
    """Authenticate against Google Drive and return a ``GoogleDrive`` handle.

    Args:
        cmd_auth: When truthy, run PyDrive's command-line authentication
            flow (``CommandLineAuth``); otherwise run the local-webserver
            flow (``LocalWebserverAuth``).

    Returns:
        A ``GoogleDrive`` instance built from the authorized ``GoogleAuth``.
    """
    authenticator = GoogleAuth()

    # Select the interactive flow requested by the caller, then run it.
    if cmd_auth:
        run_auth_flow = authenticator.CommandLineAuth
    else:
        run_auth_flow = authenticator.LocalWebserverAuth
    run_auth_flow()

    authenticator.Authorize()
    print("authorized access to google drive API!")

    return GoogleDrive(authenticator)


def pydrive_extract_files_id(drive, link):
    """Extract the Google Drive file ID from a share/download URL.

    The ID is the run of characters that follows ``/d/``, ``id=`` or ``rs/``
    (as in ``.../folders/<id>``) and ends at the next ``/``, ``?``, ``&`` or
    ``#``, so trailing query parameters such as ``&export=download`` or
    ``?usp=sharing`` are not mistaken for part of the ID.

    Args:
        drive: Unused; kept so the signature matches existing callers.
        link: URL string to parse.

    Returns:
        The file ID as a string, or ``None`` (after printing a diagnostic)
        when ``link`` is not a string or contains no recognizable ID.
    """
    # All look-behind alternatives are three characters wide, which `re`
    # requires for a look-behind assertion.
    match = None
    if isinstance(link, str):
        match = re.search(r"(?<=/d/|id=|rs/)[^/?&#]+", link)

    if match is None:
        print("error : no Google Drive file ID found")
        print("Link is probably invalid")
        print(link)
        return None

    return match[0]


def pydrive_download_file(drive, spec, stats, chunk_size=128, num_attempts=10):
    """Download one file through an authenticated PyDrive session.

    The content is first written to a temporary sibling path and only moved
    to its final location once the download succeeded. This matters because
    ``download_files`` treats any existing ``spec['file_path']`` as already
    downloaded, so a partial file left at the final path would be skipped on
    the next run.

    Args:
        drive: Object exposing ``CreateFile({'id': ...})`` whose result has
            ``GetContentFile(path)`` (the handle returned by
            ``pydrive_create_drive_manager``).
        spec: Mapping with at least ``'file_url'`` (Google Drive link) and
            ``'file_path'`` (destination path).
        stats: Mutable mapping; ``'files_done'`` is incremented and
            ``'bytes_done'`` is increased by the size of the saved file.
        chunk_size: Unused here; accepted so the same ``download_kwargs``
            can be forwarded to this function and to ``download_file``.
        num_attempts: Maximum number of download attempts before the last
            error is re-raised.

    Raises:
        ValueError: If no file ID can be extracted from ``spec['file_url']``.
        Exception: The error from the final failed attempt.
    """
    link = spec['file_url']
    save_path = spec['file_path']

    file_id = pydrive_extract_files_id(drive, link)
    if not file_id:
        # Fail fast instead of retrying API calls with an empty ID.
        raise ValueError('Could not extract a Google Drive file ID from: %s' % link)

    file_dir = os.path.dirname(save_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    # Each spec has its own destination, so the temporary name is unique
    # per worker as long as destinations are unique.
    tmp_path = save_path + '.pydrive.tmp'
    pydrive_file = drive.CreateFile({'id': file_id})
    try:
        for attempts_left in reversed(range(num_attempts)):
            try:
                pydrive_file.GetContentFile(tmp_path)
                break
            except Exception:
                # Retry until the attempts are exhausted, then propagate.
                if not attempts_left:
                    raise
        os.replace(tmp_path, save_path)
    finally:
        # Remove a leftover partial download; after a successful replace the
        # temporary path no longer exists and this is a no-op.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # NOTE(review): this runs on several worker threads; if the stats mapping
    # carries a lock for its counters, these updates should be made under it.
    stats['files_done'] += 1
    stats['bytes_done'] += os.stat(save_path).st_size

#----------------------------------------------------------------------------

def choose_bytes_unit(num_bytes):
Expand All @@ -152,7 +199,7 @@ def format_time(seconds):

#----------------------------------------------------------------------------

def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):
def download_files(file_specs, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):

# Determine which files to download.
done_specs = {spec['file_path']: spec for spec in file_specs if os.path.isfile(spec['file_path'])}
Expand All @@ -169,7 +216,7 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
exception_queue = queue.Queue()
for spec in missing_specs:
spec_queue.put(spec)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, download_kwargs=download_kwargs)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, drive=drive, download_kwargs=download_kwargs)
for _thread_idx in range(min(num_threads, len(missing_specs))):
threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start()

Expand Down Expand Up @@ -206,12 +253,15 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
except queue.Empty:
pass

def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
def _download_thread(spec_queue, exception_queue, stats, drive, download_kwargs):
with requests.Session() as session:
while not spec_queue.empty():
spec = spec_queue.get()
try:
download_file(session, spec, stats, **download_kwargs)
if drive is not None:
pydrive_download_file(drive, spec, stats, **download_kwargs)
else:
download_file(session, spec, stats, **download_kwargs)
except:
exception_queue.put(sys.exc_info())

Expand Down Expand Up @@ -350,10 +400,15 @@ def recreate_aligned_images(json_data, dst_dir='realign1024x1024', output_size=1

#----------------------------------------------------------------------------

def run(tasks, **download_kwargs):
def run(tasks, pydrive, cmd_auth, **download_kwargs):
if pydrive:
drive = pydrive_create_drive_manager(cmd_auth)
else:
drive = None

if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
print('Downloading JSON metadata...')
download_files([json_spec, license_specs['json']], **download_kwargs)
download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)

print('Parsing JSON metadata...')
with open(json_spec['file_path'], 'rb') as f:
Expand All @@ -375,7 +430,7 @@ def run(tasks, **download_kwargs):
if len(specs):
print('Downloading %d files...' % len(specs))
np.random.shuffle(specs) # to make the workload more homogeneous
download_files(specs, **download_kwargs)
download_files(specs, drive=drive, **download_kwargs)

if 'align' in tasks:
recreate_aligned_images(json_data)
Expand All @@ -390,6 +445,8 @@ def run_cmdline(argv):
parser.add_argument('-t', '--thumbs', help='download 128x128 thumbnails as PNG (1.95 GB)', dest='tasks', action='append_const', const='thumbs')
parser.add_argument('-w', '--wilds', help='download in-the-wild images as PNG (955 GB)', dest='tasks', action='append_const', const='wilds')
parser.add_argument('-r', '--tfrecords', help='download multi-resolution TFRecords (273 GB)', dest='tasks', action='append_const', const='tfrecords')
parser.add_argument('--pydrive', help='use pydrive interface to download files. it overrides google drive quota limitation this requires google credentials (default: False)', action='store_true')
parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
parser.add_argument('-a', '--align', help='recreate 1024x1024 images from in-the-wild images', dest='tasks', action='append_const', const='align')
parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')
Expand Down