Skip to content

Commit

Permalink
GHA-18 Cloud deployment
Browse files Browse the repository at this point in the history
Storage refactor (variable names, exception messages...)
  • Loading branch information
pvanliefland committed Feb 1, 2021
1 parent ac7d53d commit d408560
Showing 1 changed file with 65 additions and 75 deletions.
140 changes: 65 additions & 75 deletions geohealthaccess/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@


class Location:
def __init__(self, loc):
def __init__(self, path):
"""Parsed location string (file path or S3/GCS URL)."""
self.loc = loc
self._raw_path = path

@property
def protocol(self):
"""Return protocol of a location."""
parsed_url = self.loc.split("://")
parsed_url = self._raw_path.split("://")
if len(parsed_url) == 1:
return "local"
else:
Expand All @@ -40,9 +40,12 @@ def protocol(self):
def path(self):
"""Return location path (without scheme/protocol)."""
if self.protocol == "local":
return self.loc
return self._raw_path
else:
return self.loc.split("://")[-1]
return self._raw_path.split("://")[-1]

def __str__(self):
return self._raw_path


def get_s3fs():
Expand All @@ -66,18 +69,6 @@ def get_s3fs():
)


def is_gce_instance():
    """Check if code is running inside a GCE instance.

    Detection is done via a DNS lookup of the metadata server hostname
    (``metadata.google.internal``).

    Returns
    -------
    bool
        ``True`` if the hostname resolves, ``False`` if the lookup raises
        ``socket.gaierror``.
    """
    try:
        socket.getaddrinfo("metadata.google.internal", 80)
    except socket.gaierror:
        # Name resolution failed: the metadata server is not reachable by name.
        return False
    return True


def get_gcsfs():
"""Initialize a GCS filesystem from environment variables.
Expand All @@ -97,12 +88,12 @@ def get_gcsfs():
return gcsfs.GCSFileSystem()


def ls(loc):
def ls(path):
"""List contents of a directory.
Simulates the behavior of os.listdir().
"""
location = Location(loc)
location = Location(path)

# local
if location.protocol == "local":
Expand All @@ -119,57 +110,57 @@ def ls(loc):
return [f.split("/")[-1] for f in fs.ls(location.path)]

else:
raise IOError(f"{location.protocol} not supported.")
raise IOError(f"ls for {location} is not supported.")


def cp(src_path, dst_path):
    """Copy a file.

    Supported routes are local -> local, S3 <-> local, S3 -> S3,
    GCS <-> local and GCS -> GCS. Copying a file from S3 to GCS (or the
    other way around) is not supported.

    Parameters
    ----------
    src_path : str
        Source location (file path or S3/GCS URL).
    dst_path : str
        Destination location (file path or S3/GCS URL).

    Raises
    ------
    IOError
        If the combination of source and destination protocols is not
        one of the supported routes.
    """
    src_location, dst_location = Location(src_path), Location(dst_path)
    # Read each protocol once; the (source, destination) pair selects the route.
    route = (src_location.protocol, dst_location.protocol)

    # local
    if route == ("local", "local"):
        shutil.copy(src_location.path, dst_location.path)

    # from S3 to local
    elif route == ("s3", "local"):
        fs = get_s3fs()
        fs.get(src_location.path, dst_location.path)

    # from local to S3
    elif route == ("local", "s3"):
        fs = get_s3fs()
        fs.put(src_location.path, dst_location.path)

    # from S3 to S3
    elif route == ("s3", "s3"):
        fs = get_s3fs()
        fs.copy(src_location.path, dst_location.path)

    # from GCS to local
    elif route == ("gcs", "local"):
        fs = get_gcsfs()
        fs.get(src_location.path, dst_location.path)

    # from local to GCS
    elif route == ("local", "gcs"):
        fs = get_gcsfs()
        fs.put(src_location.path, dst_location.path)

    # from GCS to GCS
    elif route == ("gcs", "gcs"):
        fs = get_gcsfs()
        fs.copy(src_location.path, dst_location.path)

    else:
        raise IOError(f"cp from {src_location} to {dst_location} is not supported.")


def rm(loc):
def rm(path):
"""Remove a file."""
location = Location(loc)
location = Location(path)

# local
if location.protocol == "local":
Expand All @@ -186,40 +177,38 @@ def rm(loc):
fs.rm(location.path)

else:
raise IOError(f"{location.protocol} protocol not supported.")
raise IOError(f"fm for {location} is not supported.")


def mv(src_path, dst_path):
    """Move a file inside a filesystem.

    Moving files from a filesystem to another is not supported. Use
    cp() and rm() instead.

    Parameters
    ----------
    src_path : str
        Source location (file path or S3/GCS URL).
    dst_path : str
        Destination location (file path or S3/GCS URL). Must use the
        same protocol as the source.

    Raises
    ------
    IOError
        If source and destination do not share one of the supported
        protocols (local, s3, gcs).
    """
    src_location, dst_location = Location(src_path), Location(dst_path)
    # Read each protocol once; both ends must be on the same filesystem.
    route = (src_location.protocol, dst_location.protocol)

    # local
    if route == ("local", "local"):
        shutil.move(src_location.path, dst_location.path)

    # s3
    elif route == ("s3", "s3"):
        fs = get_s3fs()
        fs.move(src_location.path, dst_location.path)

    # gcs
    elif route == ("gcs", "gcs"):
        fs = get_gcsfs()
        fs.move(src_location.path, dst_location.path)

    else:
        raise IOError(f"mv from {src_location} to {dst_location} is not supported.")


def exists(loc):
def exists(path):
"""Check if a file exists."""
location = Location(loc)
location = Location(path)

# local
if location.protocol == "local":
Expand All @@ -236,27 +225,26 @@ def exists(loc):
return fs.exists(location.path)

else:
raise IOError(
f'The "{location.protocol}" protocol is not supported ({location}).'
)
raise IOError(f"exists for {location} is not supported.")


def mkdir(path):
    """Create directories recursively, ignore if they already exist.

    This is not needed for S3 and GCS as directories cannot be created and
    are not needed anyway. For any non-local location the call is a no-op.

    Parameters
    ----------
    path : str
        Directory location (file path or S3/GCS URL).
    """
    location = Location(path)
    if location.protocol == "local":
        os.makedirs(location.path, exist_ok=True)


def size(loc):
def size(path):
"""Get size of a file in bytes."""
location = Location(loc)
if not exists(location.location):
raise FileNotFoundError(f"No file found at {location.location}.")
if not exists(path):
raise FileNotFoundError(f"No file found at {path}.")

location = Location(path)

if location.protocol == "local":
return os.path.getsize(location.path)
Expand All @@ -270,7 +258,7 @@ def size(loc):
return fs.size(location.path)

else:
raise IOError(f"{location.protocol} not supported.")
raise IOError(f"size for {location} is not supported.")


def glob(pattern):
Expand All @@ -293,12 +281,12 @@ def glob(pattern):
return [f"gcs://{path}" for path in fs.glob(location.path)]

else:
raise IOError(f"{location.protocol} not supported.")
raise IOError(f"glob for {location} is not supported.")


def open_(loc, mode="r"):
def open_(path, mode="r"):
"""Return a file-like object regardless of the file system."""
location = Location(loc)
location = Location(path)

if location.protocol == "local":
return open(location.path, mode)
Expand All @@ -312,37 +300,39 @@ def open_(loc, mode="r"):
return fs.open(location.path, mode)

else:
raise IOError(f"{location.protocol} not supported.")
raise IOError(f"open_ for {location} is not supported.")


def unzip(src_file_path, dst_dir_path):
    """Extract contents of a .zip archive in dst_dir_path.

    Can read .zip file from a cloud filesystem and copy its contents
    to another cloud filesystem, but processing is performed locally:
    the archive is extracted into a temporary directory, then each
    extracted entry is copied to the destination with cp().

    Parameters
    ----------
    src_file_path : str
        Location of the .zip archive (file path or S3/GCS URL).
    dst_dir_path : str
        Location of the destination directory (file path or S3/GCS URL).

    Raises
    ------
    IOError
        If the protocol of the source archive is not supported.
    """
    src_file_location = Location(src_file_path)
    protocol = src_file_location.protocol

    with TemporaryDirectory(prefix="geohealthaccess_") as tmp_dir:

        if protocol == "local":
            with zipfile.ZipFile(src_file_location.path, "r") as z:
                z.extractall(tmp_dir)

        elif protocol == "s3":
            fs = get_s3fs()
            with fs.open(src_file_location.path) as archive:
                with zipfile.ZipFile(archive, "r") as z:
                    z.extractall(tmp_dir)

        elif protocol == "gcs":
            fs = get_gcsfs()
            with fs.open(src_file_location.path) as archive:
                with zipfile.ZipFile(archive, "r") as z:
                    z.extractall(tmp_dir)

        else:
            raise IOError(f"unzip for {src_file_location} is not supported.")

        # The copy must happen while the temporary directory still exists,
        # i.e. inside the ``with`` block.
        # NOTE(review): only the top-level entries of the extracted archive
        # are handed to cp(); archives containing sub-directories may need
        # recursive handling -- confirm against expected inputs.
        for f in os.listdir(tmp_dir):
            cp(os.path.join(tmp_dir, f), os.path.join(dst_dir_path, f))

0 comments on commit d408560

Please sign in to comment.