Skip to content

Commit

Permalink
Merge pull request #2 from calico/snpclust
Browse files: browse the repository at this point in the history
Snpclust
  • Loading branch information
lruizcalico authored Sep 13, 2023
2 parents 8688ad4 + a2c5e3a commit cf08b90
Show file tree
Hide file tree
Showing 10 changed files with 430 additions and 182 deletions.
47 changes: 24 additions & 23 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from baskerville import layers


############################################################
# Convolution
############################################################
Expand Down Expand Up @@ -892,7 +893,7 @@ def conv_tower(
divisible_by=1,
repeat=1,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -943,7 +944,7 @@ def conv_tower_nac(
divisible_by=1,
repeat=1,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def res_tower(
repeat=1,
num_convs=2,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -1087,7 +1088,7 @@ def convnext_tower(
repeat=1,
num_convs=2,
reprs=[],
**kwargs
**kwargs,
):
"""Abc.
Expand Down Expand Up @@ -1129,7 +1130,7 @@ def _round(x):
filters=rep_filters_int,
kernel_size=kernel_size,
dropout=dropout,
**kwargs
**kwargs,
)
current0 = current

Expand All @@ -1141,7 +1142,7 @@ def _round(x):
filters=rep_filters_int,
kernel_size=kernel_size,
dropout=dropout,
**kwargs
**kwargs,
)

# residual add
Expand Down Expand Up @@ -1187,7 +1188,7 @@ def transformer(
qkv_width=1,
mha_initializer="he_normal",
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a transformer block.
Expand Down Expand Up @@ -1255,7 +1256,7 @@ def transformer_split(
qkv_width=1,
mha_initializer="he_normal",
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a transformer block.
Expand Down Expand Up @@ -1393,7 +1394,7 @@ def transformer2(
dropout=0.25,
dense_expansion=2.0,
qkv_width=1,
**kwargs
**kwargs,
):
"""Construct a transformer block, with length-wise pooling before
returning to full length.
Expand All @@ -1416,7 +1417,7 @@ def transformer2(
filters=min(4 * key_size, inputs.shape[-1]),
kernel_size=3,
pool_size=2,
**kwargs
**kwargs,
)

# layer norm
Expand Down Expand Up @@ -1517,7 +1518,7 @@ def squeeze_excite(
additive=False,
norm_type=None,
bn_momentum=0.9,
**kwargs
**kwargs,
):
return layers.SqueezeExcite(
activation, additive, bottleneck_ratio, norm_type, bn_momentum
Expand Down Expand Up @@ -1545,7 +1546,7 @@ def dilated_dense(
conv_type="standard",
dropout=0,
repeat=1,
**kwargs
**kwargs,
):
"""Construct a residual dilated dense block.
Expand All @@ -1570,7 +1571,7 @@ def dilated_dense(
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
conv_type=conv_type,
**kwargs
**kwargs,
)

# dense concat
Expand All @@ -1592,7 +1593,7 @@ def dilated_residual(
conv_type="standard",
norm_type=None,
round=False,
**kwargs
**kwargs,
):
"""Construct a residual dilated convolution block.
Expand All @@ -1619,7 +1620,7 @@ def dilated_residual(
conv_type=conv_type,
norm_type=norm_type,
norm_gamma="ones",
**kwargs
**kwargs,
)

# return
Expand All @@ -1629,7 +1630,7 @@ def dilated_residual(
dropout=dropout,
norm_type=norm_type,
norm_gamma="zeros",
**kwargs
**kwargs,
)

# InitZero
Expand Down Expand Up @@ -1672,7 +1673,7 @@ def dilated_residual_nac(
filters=filters,
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
**kwargs
**kwargs,
)

# return
Expand All @@ -1697,7 +1698,7 @@ def dilated_residual_2d(
dropout=0,
repeat=1,
symmetric=True,
**kwargs
**kwargs,
):
"""Construct a residual dilated convolution block."""

Expand All @@ -1717,7 +1718,7 @@ def dilated_residual_2d(
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
norm_gamma="ones",
**kwargs
**kwargs,
)

# return
Expand All @@ -1726,7 +1727,7 @@ def dilated_residual_2d(
filters=rep_input.shape[-1],
dropout=dropout,
norm_gamma="zeros",
**kwargs
**kwargs,
)

# residual add
Expand Down Expand Up @@ -1818,7 +1819,7 @@ def dense_block(
bn_momentum=0.99,
norm_gamma=None,
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a single convolution block.
Expand Down Expand Up @@ -1909,7 +1910,7 @@ def dense_nac(
bn_momentum=0.99,
norm_gamma=None,
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a single convolution block.
Expand Down Expand Up @@ -1991,7 +1992,7 @@ def final(
kernel_initializer="he_normal",
l2_scale=0,
l1_scale=0,
**kwargs
**kwargs,
):
"""Final simple transformation before comparison to targets.
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
for device in gpu_devices:
tf.config.experimental.set_memory_growth(device, True)


################################################################################
# Losses
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/scripts/hound_eval_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Test the accuracy of a trained model on targets/predictions normalized across targets.
"""


################################################################################
# main
################################################################################
Expand Down
23 changes: 17 additions & 6 deletions src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,28 @@
import pdb
import pickle
import os
from baskerville.snps import calculate_sad
from baskerville.snps import score_snps

"""
hound_snp.py
Compute variant effect predictions for SNPs in a VCF file.
"""


################################################################################
# main
################################################################################
def main():
usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>"
parser = OptionParser(usage)
parser.add_option(
"-c",
dest="cluster_snps_pct",
default=0,
type="float",
help="Cluster SNPs within a %% of the seq length to make a single ref pred [Default: %default]",
)
parser.add_option(
"-f",
dest="genome_fasta",
Expand Down Expand Up @@ -66,8 +74,8 @@ def main():
)
parser.add_option(
"--stats",
dest="sad_stats",
default="SAD",
dest="snp_stats",
default="logSAD",
help="Comma-separated list of stats to save. [Default: %default]",
)
parser.add_option(
Expand Down Expand Up @@ -129,17 +137,20 @@ def main():
else:
parser.error("Must provide parameters and model files and QTL VCF file")

if options.targets_file is None:
parser.error("Must provide targets file")

if not os.path.isdir(options.out_dir):
os.mkdir(options.out_dir)

options.shifts = [int(shift) for shift in options.shifts.split(",")]
options.sad_stats = options.sad_stats.split(",")
options.snp_stats = options.snp_stats.split(",")

# calculate SAD scores:
if options.processes is not None:
calculate_sad(params_file, model_file, vcf_file, worker_index, options)
score_snps(params_file, model_file, vcf_file, worker_index, options)
else:
calculate_sad(params_file, model_file, vcf_file, 0, options)
score_snps(params_file, model_file, vcf_file, 0, options)


################################################################################
Expand Down
21 changes: 18 additions & 3 deletions src/baskerville/scripts/hound_snp_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
parallelized across a slurm cluster.
"""


################################################################################
# main
################################################################################
Expand All @@ -41,6 +42,13 @@ def main():
parser = OptionParser(usage)

# snp
parser.add_option(
"-c",
dest="cluster_snps_pct",
default=0,
type="float",
help="Cluster SNPs within a %% of the seq length to make a single ref pred [Default: %default]",
)
parser.add_option(
"-f",
dest="genome_fasta",
Expand Down Expand Up @@ -69,8 +77,8 @@ def main():
)
parser.add_option(
"--stats",
dest="sad_stats",
default="SAD",
dest="snp_stats",
default="logSAD",
help="Comma-separated list of stats to save. [Default: %default]",
)
parser.add_option(
Expand All @@ -80,12 +88,19 @@ def main():
type="str",
help="File specifying target indexes and labels in table format",
)
parser.add_option(
"-u",
dest="untransform_old",
default=False,
action="store_true",
help="Untransform old models [Default: %default]",
)

# multi
parser.add_option(
"-e",
dest="conda_env",
default="tf210",
default="tf12",
help="Anaconda environment [Default: %default]",
)
parser.add_option(
Expand Down
1 change: 0 additions & 1 deletion src/baskerville/scripts/hound_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def main():
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():

if not args.keras_fit:
# distribute data
for di in range(len(args.data_dirs)):
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def predict(
stream: bool = False,
step: int = 1,
dtype: str = "float32",
**kwargs
**kwargs,
):
"""Predict targets for SeqDataset, with more options.
Expand Down
Loading

0 comments on commit cf08b90

Please sign in to comment.