Skip to content

Commit

Permalink
fix ref panel upper case
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Sep 26, 2023
1 parent b7ad7c4 commit 340ff9e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 36 deletions.
111 changes: 77 additions & 34 deletions src/baskerville/helpers/gcs_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Adapted (with modifications) from the calicolabs-ukbb-mri-ml repo:
# https://github.com/calico/calicolabs-ukbb-mri-ml/tree/main/src/ukbb_mri_ml/helpers
# =========================================================================

import os
import logging

Expand All @@ -11,29 +15,36 @@
from google.auth.exceptions import DefaultCredentialsError

logger = logging.getLogger(__name__)
PREFIX = ""

logger = logging.getLogger(__name__)

try:
# Attempt to infer credentials from environment
storage_client = Client()
logger.info("Inferred credentials from environment")
except DefaultCredentialsError:

def _get_storage_client() -> Client:
"""
Returns: Google Cloud Storage Client
"""
try:
# Attempt to load JSON credentials from GOOGLE_APPLICATION_CREDENTIALS
storage_client = Client.from_service_account_info(
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
)
logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS")
except AttributeError:
# Attempt to load JSON credentials from base64 encoded string
storage_client = Client.from_service_account_info(
loads(
b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode("utf-8")
# Attempt to infer credentials from environment
storage_client = Client()
logger.info("Inferred credentials from environment")
except DefaultCredentialsError:
try:
# Attempt to load JSON credentials from GOOGLE_APPLICATION_CREDENTIALS
storage_client = Client.from_service_account_info(
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
)
)
logger.info("Loaded credentials from base64 encoded string")
logger.info("Loaded credentials from GOOGLE_APPLICATION_CREDENTIALS")
except AttributeError:
# Attempt to load JSON credentials from base64 encoded string
storage_client = Client.from_service_account_info(
loads(
b64decode(os.environ["GOOGLE_APPLICATION_CREDENTIALS"]).decode(
"utf-8"
)
)
)
logger.info("Loaded credentials from base64 encoded string")
return storage_client


def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
Expand All @@ -47,6 +58,7 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
Returns: None
"""
storage_client = _get_storage_client()
write_mode = "wb" if bytes else "w"
with open(local_path, write_mode) as o:
storage_client.download_blob_to_file(gcs_path, o)
Expand All @@ -66,6 +78,7 @@ def sync_dir_to_gcs(
Returns: None
"""
storage_client = _get_storage_client()
if not is_gcs_path(gcs_dir):
raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")

Expand Down Expand Up @@ -105,6 +118,7 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
gcs_dir: string GCS destination path. Will create folders that do not exist.
Returns: None
"""
storage_client = _get_storage_client()
bucket_name = gcs_dir.split("//")[1].split("/")[0]
gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
local_prefix = local_dir.split("/")[-1]
Expand All @@ -116,6 +130,24 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
blob.upload_from_filename(local_file)


def upload_file_gcs(local_path: str, gcs_path: str, bytes=True) -> None:
    """
    Upload a single local file into a GCS "folder".

    The object is written as $PREFIX/$FILENAME inside the bucket, where
    $FILENAME is the base name of local_path and $PREFIX is the full object
    prefix of gcs_path (nested prefixes are preserved; an empty prefix puts
    the file at the bucket root).
    Args:
        local_path: local path to file
        gcs_path: string GCS Uri of the destination folder, follows the format
            gs://$BUCKET_NAME/$PREFIX
        bytes: unused; kept so existing callers that pass it keep working
    Returns: None
    Raises:
        ValueError: if no bucket name can be parsed from gcs_path
    """
    # Split "gs://bucket/some/prefix" into bucket name and the *whole* prefix.
    # Taking only the first path component would silently drop nested folders.
    _, _, remainder = gcs_path.partition("//")
    bucket_name, _, object_prefix = remainder.partition("/")
    if not bucket_name:
        raise ValueError(f"gcs_path is not a valid GCS path: {gcs_path}")
    object_prefix = object_prefix.strip("/")

    filename = os.path.basename(local_path)
    # Avoid a leading "/" in the object name when uploading to the bucket root.
    blob_name = f"{object_prefix}/{filename}" if object_prefix else filename

    storage_client = _get_storage_client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.upload_from_filename(local_path)


def gcs_join(*args):
    """
    Join path fragments into a single GCS Uri.

    Every fragment has any "gs://" scheme removed and its leading/trailing
    slashes stripped before joining, so fragments may or may not carry them.
    Args:
        *args: string path fragments; the first is expected to start with the
            bucket name (with or without the "gs://" scheme)
    Returns: string of the form gs://fragment1/fragment2/...
    """
    # GCS object names always use "/" as the separator, so join with posixpath
    # rather than os.path, which would insert "\\" on Windows hosts.
    import posixpath

    cleaned = [arg.replace("gs://", "").strip("/") for arg in args]
    return "gs://" + posixpath.join(*cleaned)
Expand Down Expand Up @@ -156,34 +188,45 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
"""
# currently only Niftidataset receives gs bucket paths, so this isn't necessary
# commenting out for now even though it is functional (lots of files)
storage_client = _get_storage_client()
if is_gcs_path(files_dir):
bucket_name, object_name = split_gcs_uri(files_dir)
blob_iterator = storage_client.list_blobs(bucket_name, prefix=object_name)
# return [str(blob) for blob in blob_iterator if "/" not in blob.name]
return [str(blob) for blob in blob_iterator]
else:
dir_contents = os.listdir(files_dir)
files = []
for entry in dir_contents:
entry_path = join(files_dir, entry)
if isfile(entry_path):
files.append(entry_path)
elif recursive:
files.extend(get_filename_in_dir(entry_path, recursive=recursive))
else:
pass
return files
dir_contents = os.listdir(files_dir)
files = []
for entry in dir_contents:
entry_path = join(files_dir, entry)
if isfile(entry_path):
files.append(entry_path)
elif recursive:
files.extend(get_filename_in_dir(entry_path, recursive=recursive))
else:
print("Nothing happened here")
pass
return files


def download_rename_inputs(filepath: str, temp_dir: str):
def download_rename_inputs(filepath: str, temp_dir: str) -> str:
    """
    Fetch a GCS object into a local directory, keeping only its base name.
    Args:
        filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME
        temp_dir: local directory the file is downloaded into
    Returns: path of the downloaded file on the local machine
    """
    object_name = split_gcs_uri(filepath)[1]
    # Drop any "folder" components so the file lands directly in temp_dir.
    filename = object_name.rsplit("/", 1)[-1]
    local_path = f"{temp_dir}/{filename}"
    download_from_gcs(filepath, local_path)
    return local_path


def gcs_file_exist(gcs_path: str) -> bool:
    """
    Check whether an object exists in GCS.
    Args:
        gcs_path: GCS Uri of the object; it is split into bucket and object
            name with split_gcs_uri
    Returns: the result of the blob's exists() check (true/false)
    """
    client = _get_storage_client()
    bucket_name, object_name = split_gcs_uri(gcs_path)
    return client.bucket(bucket_name).blob(object_name).exists()
3 changes: 1 addition & 2 deletions src/baskerville/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,7 @@ def vcf_snps(
if validate_ref_fasta is not None:
ref_n = len(snps[-1].ref_allele)
snp_pos = snps[-1].pos - 1
ref_snp = genome_open.fetch(snps[-1].chr, snp_pos, snp_pos + ref_n)

ref_snp = genome_open.fetch(snps[-1].chr, snp_pos, snp_pos + ref_n).upper()
if snps[-1].ref_allele != ref_snp:
if not flip_ref:
# bail
Expand Down

0 comments on commit 340ff9e

Please sign in to comment.