merge from main 1.29.24
hy395 committed Jan 29, 2024
2 parents 6502e52 + 929993d commit 9fbfad3
Showing 42 changed files with 2,254 additions and 254 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/documentation.yml
@@ -0,0 +1,51 @@
name: Baskerville Docs

on:
  workflow_dispatch:
    inputs:
      python-version:
        default: "3.10"
        required: false
        type: string

defaults:
  run:
    shell: bash

permissions:
  contents: write
jobs:
  docs:
    runs-on: ubuntu-20.04
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      # You can test your matrix by printing the current Python version
      - name: Display Python version
        run: python -c "import sys; print(sys.version)"
      - name: Install dependencies
        run: |
          cd ${{ github.workspace }}/src/docs
          pip install -r requirements.txt
      - name: Sphinx build
        run: |
          cd ${{ github.workspace }}/src/docs/source
          rm -f *.rst make.bat
          cd ${{ github.workspace }}/src
          sphinx-apidoc -F -a -o docs/source baskerville
          cd ${{ github.workspace }}/src/docs
          make html
      - name: Deploy
        uses: peaceiris/actions-gh-pages@v3
        with:
          publish_branch: gh-pages
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: ${{ github.workspace }}/src/docs/build/html
          force_orphan: true
6 changes: 6 additions & 0 deletions README.md
@@ -10,6 +10,12 @@ Baskerville provides researchers with tools to:

---

### Documentation

Documentation page: https://calico.github.io/baskerville/index.html

---

### Installation

`git clone git@github.com:calico/baskerville.git`
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ authors = [
readme = "README.md"
requires-python = ">=3.8, <3.11"
classifiers = ["License :: OSI Approved :: Apache License"]
dynamic = ["version", "description"]
dynamic = ["version", "description", "dependencies"]

[project.optional-dependencies]
dev = [
105 changes: 12 additions & 93 deletions src/baskerville/blocks.py
@@ -473,7 +473,7 @@ def conv_next(
return current


def fpn_unet(
def unet_conv(
inputs,
unet_repr,
activation="relu",
@@ -486,6 +486,7 @@ def fpn_unet(
kernel_initializer="he_normal",
transfer_se=False,
se_ratio=16,
upsample_conv=False,
):
"""Construct a feature pyramid network block.
@@ -498,6 +499,7 @@ def fpn_unet(
dropout: Dropout rate probability
norm_type: Apply batch or layer normalization
bn_momentum: BatchNorm momentum
upsample_conv: Conv1D the upsampled input path
Returns:
[batch_size, seq_length, features] output sequence
@@ -529,11 +531,12 @@
filters = inputs.shape[-1]

# dense
current1 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current1)
if upsample_conv:
current1 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current1)
current2 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
@@ -544,7 +547,6 @@
current1 = tf.keras.layers.UpSampling1D(size=stride)(current1)

# add
# current2 = layers.Scale(initializer='ones')(current2)
current = tf.keras.layers.Add()([current1, current2])

# normalize?
@@ -577,83 +579,7 @@
return current
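
A minimal usage sketch of the renamed block (the shapes, import path, and keyword choices below are assumptions for illustration, not taken from the diff):

import tensorflow as tf
from baskerville import blocks

# Hypothetical tensors: `skip` is an earlier, higher-resolution U-Net representation;
# unet_conv upsamples `seq` by `stride` to match it before the add and convolution.
seq = tf.keras.Input(shape=(512, 128))
skip = tf.keras.Input(shape=(1024, 128))
out = blocks.unet_conv(seq, skip, stride=2, kernel_size=3, norm_type="batch")
model = tf.keras.Model(inputs=[seq, skip], outputs=out)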


def fpn1_unet(
inputs,
unet_repr,
activation="relu",
stride=2,
l2_scale=0,
dropout=0,
norm_type=None,
bn_momentum=0.99,
kernel_size=1,
kernel_initializer="he_normal",
):
"""Construct a feature pyramid network block.
Args:
inputs: [batch_size, seq_length, features] input sequence
kernel_size: Conv1D kernel_size
activation: relu/gelu/etc
stride: UpSample stride
l2_scale: L2 regularization weight.
dropout: Dropout rate probability
norm_type: Apply batch or layer normalization
bn_momentum: BatchNorm momentum
Returns:
[batch_size, seq_length, features] output sequence
"""

# variables
current1 = inputs
current2 = unet_repr

# normalize
if norm_type == "batch-sync":
current1 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current1)
elif norm_type == "batch":
current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1)
elif norm_type == "layer":
current1 = tf.keras.layers.LayerNormalization()(current1)

# activate
current1 = layers.activate(current1, activation)
# current2 = layers.activate(current2, activation)

# dense
current1 = tf.keras.layers.Dense(
units=unet_repr.shape[-1],
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current1)

# upsample
current1 = tf.keras.layers.UpSampling1D(size=stride)(current1)

# add
current2 = layers.Scale(initializer="ones")(current2)
current = tf.keras.layers.Add()([current1, current2])

# convolution
current = tf.keras.layers.SeparableConv1D(
filters=unet_repr.shape[-1],
kernel_size=kernel_size,
padding="same",
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current)

# dropout
if dropout > 0:
current = tf.keras.layers.Dropout(dropout)(current)

return current


def upsample_unet(
def unet_concat(
inputs,
unet_repr,
activation="relu",
@@ -815,11 +741,6 @@ def tconv_nac(
return current


def concat_unet(inputs, unet_repr, **kwargs):
current = tf.keras.layers.Concatenate()([inputs, unet_repr])
return current


def conv_block_2d(
inputs,
filters=128,
@@ -2100,7 +2021,6 @@ def final(
"center_average": center_average,
"concat_dist_2d": concat_dist_2d,
"concat_position": concat_position,
"concat_unet": concat_unet,
"conv_block": conv_block,
"conv_dna": conv_dna,
"conv_nac": conv_nac,
@@ -2127,10 +2047,9 @@
"tconv_nac": tconv_nac,
"transformer": transformer,
"transformer_tower": transformer_tower,
"unet_conv": unet_conv,
"unet_concat": unet_concat,
"upper_tri": upper_tri,
"fpn_unet": fpn_unet,
"fpn1_unet": fpn1_unet,
"upsample_unet": upsample_unet,
"wheeze_excite": wheeze_excite,
}
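
This registry maps block names to constructor functions, so configurations that referenced the removed "fpn_unet", "fpn1_unet", "upsample_unet", or "concat_unet" entries would now use "unet_conv" and "unet_concat". An illustrative fragment (the structure is assumed for illustration, not a real Baskerville config):

# Illustrative only: block specs that pick the renamed U-Net blocks by registry key.
trunk = [
    {"name": "conv_dna", "filters": 288, "kernel_size": 15},
    {"name": "unet_conv", "kernel_size": 3, "upsample_conv": False},
]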

35 changes: 33 additions & 2 deletions src/baskerville/dataset.py
@@ -20,6 +20,7 @@
from natsort import natsorted
import numpy as np
import pandas as pd
from scipy.sparse import dok_matrix
import tensorflow as tf

gpu_devices = tf.config.experimental.list_physical_devices("GPU")
@@ -310,6 +311,36 @@ def numpy(
return targets


def make_strand_transform(targets_df, targets_strand_df):
    """Make a sparse matrix to sum strand pairs.
    Args:
        targets_df (pd.DataFrame): Targets DataFrame.
        targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed.
    Returns:
        scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs.
    """

    # initialize sparse matrix
    strand_transform = dok_matrix((targets_df.shape[0], targets_strand_df.shape[0]))

    # fill in matrix
    ti = 0
    sti = 0
    for _, target in targets_df.iterrows():
        strand_transform[ti, sti] = True
        if target.strand_pair == target.name:
            sti += 1
        else:
            if target.identifier[-1] == "-":
                sti += 1
        ti += 1
    strand_transform = strand_transform.tocsr()

    return strand_transform
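
A minimal usage sketch for the new helper (the toy targets table is invented; only the index and the "identifier" and "strand_pair" columns that the loop reads are filled in):

import numpy as np
import pandas as pd
from baskerville.dataset import make_strand_transform

# Toy table: targets 0 and 1 are a +/- strand pair; target 2 is unstranded (its own pair).
targets_df = pd.DataFrame(
    {"identifier": ["CNhs0001+", "CNhs0001-", "ATAC0001"], "strand_pair": [1, 0, 2]}
)
targets_strand_df = targets_df.iloc[[0, 2]]  # collapsed: one row per strand pair

strand_transform = make_strand_transform(targets_df, targets_strand_df)  # sparse (3, 2)
preds = np.random.rand(8, 3)                       # [positions, targets]
preds_strand = preds @ strand_transform.toarray()  # [positions, strand-summed targets]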


def targets_prep_strand(targets_df):
"""Adjust targets table for merged stranded datasets.
@@ -351,9 +382,9 @@ def untransform_preds(preds, targets_df, unscale=False):
preds_unclip = cs - 1 + (preds - cs + 1) ** 2
preds = np.where(preds > cs, preds_unclip, preds)

# ** 0.75
# sqrt
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** (4 / 3)
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)
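# Assuming the forward "_sqrt" squash was y = (x + 1) ** 0.5 - 1, the line above applies
# its exact inverse, x = (y + 1) ** 2 - 1; the previous (4 / 3) exponent instead inverted
# a ** 0.75 squash.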

# scale
if unscale:
Empty file.
60 changes: 52 additions & 8 deletions src/baskerville/helpers/gcs_utils.py
@@ -64,6 +64,43 @@ def download_from_gcs(gcs_path: str, local_path: str, bytes=True) -> None:
storage_client.download_blob_to_file(gcs_path, o)


def download_folder_from_gcs(gcs_dir: str, local_dir: str, bytes=True) -> None:
    """
    Downloads a whole folder from GCS
    Args:
        gcs_dir: string path to GCS folder to download
        local_dir: string path to download to
        bytes: boolean flag indicating if gcs file contains bytes
    Returns: None
    """
    storage_client = _get_storage_client()
    write_mode = "wb" if bytes else "w"
    if not is_gcs_path(gcs_dir):
        raise ValueError(f"gcs_dir is not a valid GCS path: {gcs_dir}")
    bucket_name, gcs_object_prefix = split_gcs_uri(gcs_dir)
    # Get the bucket from the client.
    bucket = storage_client.bucket(bucket_name)

    # Ensure local folder exists
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)
    # List all blobs with the given prefix (i.e., folder path).
    blobs = bucket.list_blobs(prefix=gcs_object_prefix)
    # Download each blob.
    for blob in blobs:
        # Compute the full path to which we'll download the blob.
        blob_rel_path = os.path.relpath(blob.name, gcs_object_prefix)
        local_blob_path = os.path.join(local_dir, blob_rel_path)

        # Ensure the local directory structure exists
        local_blob_dir = os.path.dirname(local_blob_path)
        if not os.path.exists(local_blob_dir):
            os.makedirs(local_blob_dir)
        download_from_gcs(join(gcs_dir, blob_rel_path), local_blob_path, bytes=bytes)
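
A minimal usage sketch (the bucket and paths are placeholders, not from the repository):

from baskerville.helpers.gcs_utils import download_folder_from_gcs

# Mirrors every object under gs://my-bucket/experiments/run1/ into /tmp/run1/,
# creating local subdirectories as needed.
download_folder_from_gcs("gs://my-bucket/experiments/run1", "/tmp/run1")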


def sync_dir_to_gcs(
local_dir: str, gcs_dir: str, verbose=False, recursive=False
) -> None:
@@ -120,7 +157,7 @@ def upload_folder_gcs(local_dir: str, gcs_dir: str) -> None:
"""
storage_client = _get_storage_client()
bucket_name = gcs_dir.split("//")[1].split("/")[0]
gcs_object_prefix = gcs_dir.split("//")[1].split("/")[1]
gcs_object_prefix = "/".join(gcs_dir.split("//")[1].split("/")[1:])
local_prefix = local_dir.split("/")[-1]
bucket = storage_client.bucket(bucket_name)
for filename in os.listdir(local_dir):
@@ -142,7 +179,7 @@ def upload_file_gcs(local_path: str, gcs_path: str, bytes=True) -> None:
storage_client = _get_storage_client()
bucket_name = gcs_path.split("//")[1].split("/")[0]
bucket = storage_client.bucket(bucket_name)
gcs_object_prefix = gcs_path.split("//")[1].split("/")[1]
gcs_object_prefix = "/".join(gcs_path.split("//")[1].split("/")[1:])
filename = local_path.split("/")[-1]
blob = bucket.blob(f"{gcs_object_prefix}/{filename}")
blob.upload_from_filename(local_path)
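
The prefix fix is the same in upload_folder_gcs and upload_file_gcs; an illustrative before/after for a nested destination (example values only):

gcs_dir = "gs://my-bucket/models/exp1/run3"
old_prefix = gcs_dir.split("//")[1].split("/")[1]              # "models": nested folders were dropped
new_prefix = "/".join(gcs_dir.split("//")[1].split("/")[1:])   # "models/exp1/run3": full prefix kept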
@@ -207,18 +244,25 @@ def get_filename_in_dir(files_dir: str, recursive: bool = False) -> List[str]:
return files


def download_rename_inputs(filepath: str, temp_dir: str) -> str:
def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) -> str:
"""
Download file from gcs to local dir
Args:
filepath: GCS Uri follows the format gs://$BUCKET_NAME/OBJECT_NAME
temp_dir: local dir to download to
is_dir: boolean flag indicating if the filepath is a directory
Returns: new filepath in the local machine
"""
_, filename = split_gcs_uri(filepath)
if "/" in filename:
filename = filename.split("/")[-1]
download_from_gcs(filepath, f"{temp_dir}/{filename}")
return f"{temp_dir}/{filename}"
if is_dir:
download_folder_from_gcs(filepath, temp_dir)
dir_name = filepath.split("/")[-1]
return temp_dir
else:
_, filename = split_gcs_uri(filepath)
if "/" in filename:
filename = filename.split("/")[-1]
download_from_gcs(filepath, f"{temp_dir}/{filename}")
return f"{temp_dir}/{filename}"


def gcs_file_exist(gcs_path: str) -> bool: