diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py index 72a853e..2c5b40a 100644 --- a/src/baskerville/helpers/gcs_utils.py +++ b/src/baskerville/helpers/gcs_utils.py @@ -64,6 +64,43 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None: storage_client.download_blob_to_file(gcs_path, o) +def download_folder_from_gcs(gcs_dir: str, local_dir: str, bytes=True) -> None: + """ + Downloads a whole folder from GCS + Args: + gcs_dir: string path to GCS folder to download + local_dir: string path to download to + bytes: boolean flag indicating if gcs file contains bytes + + Returns: None + + """ + storage_client = _get_storage_client() + write_mode = "wb" if bytes else "w" + if not is_gcs_path(gcs_dir): + raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}") + bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir) + # Get the bucket from the client. + bucket = storage_client.bucket(bucket_name) + + # Ensure local folder exists + if not os.path.exists(local_dir): + os.makedirs(local_dir) + # List all blobs with the given prefix (i.e., folder path). + blobs = bucket.list_blobs(prefix=gcs_object_prefix) + # Download each blob. + for blob in blobs: + # Compute the full path to which we'll download the blob. + blob_rel_path = os.path.relpath(blob.name, gcs_object_prefix) + local_blob_path = os.path.join(local_dir, blob_rel_path) + + # Ensure the local directory structure exists + local_blob_dir = os.path.dirname(local_blob_path) + if not os.path.exists(local_blob_dir): + os.makedirs(local_blob_dir) + download_from_gcs(join(gcs_dir, blob_rel_path), local_blob_path, bytes=bytes) + + def sync_dir_to_gcs( local_dir: str, gcs_dir: str, verbose=False, recursive=False ) -> None: @@ -207,18 +244,25 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]: return files -def download_rename_inputs(filepath: str, temp_dir: str) -> str: +def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -> str: """ Download file from gcs to local dir Args: filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME + temp_dir: local dir to download to + is_dir: boolean flag indicating if the filepath is a directory Returns: new filepath in the local machine """ - _, filename = split_gcs_uri(filepath) - if "/" in filename: - filename = filename.split("/")[-1] - download_from_gcs(filepath, f"{temp_dir}/{filename}") - return f"{temp_dir}/{filename}" + if is_dir: + download_folder_from_gcs(filepath, temp_dir) + dir_name = filepath.split("/")[-1] + return f"{temp_dir}/{dir_name}" + else: + _, filename = split_gcs_uri(filepath) + if "/" in filename: + filename = filename.split("/")[-1] + download_from_gcs(filepath, f"{temp_dir}/{filename}") + return f"{temp_dir}/{filename}" def gcs_file_exist(gcs_path: str) -> bool: