Skip to content

Commit

Permalink
black format
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Sep 13, 2023
1 parent 4799a4e commit 7537905
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 47 deletions.
47 changes: 24 additions & 23 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from baskerville import layers


############################################################
# Convolution
############################################################
Expand Down Expand Up @@ -892,7 +893,7 @@ def conv_tower(
divisible_by=1,
repeat=1,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -943,7 +944,7 @@ def conv_tower_nac(
divisible_by=1,
repeat=1,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -1000,7 +1001,7 @@ def res_tower(
repeat=1,
num_convs=2,
reprs=[],
**kwargs
**kwargs,
):
"""Construct a reducing convolution block.
Expand Down Expand Up @@ -1087,7 +1088,7 @@ def convnext_tower(
repeat=1,
num_convs=2,
reprs=[],
**kwargs
**kwargs,
):
"""Abc.
Expand Down Expand Up @@ -1129,7 +1130,7 @@ def _round(x):
filters=rep_filters_int,
kernel_size=kernel_size,
dropout=dropout,
**kwargs
**kwargs,
)
current0 = current

Expand All @@ -1141,7 +1142,7 @@ def _round(x):
filters=rep_filters_int,
kernel_size=kernel_size,
dropout=dropout,
**kwargs
**kwargs,
)

# residual add
Expand Down Expand Up @@ -1187,7 +1188,7 @@ def transformer(
qkv_width=1,
mha_initializer="he_normal",
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a transformer block.
Expand Down Expand Up @@ -1255,7 +1256,7 @@ def transformer_split(
qkv_width=1,
mha_initializer="he_normal",
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a transformer block.
Expand Down Expand Up @@ -1393,7 +1394,7 @@ def transformer2(
dropout=0.25,
dense_expansion=2.0,
qkv_width=1,
**kwargs
**kwargs,
):
"""Construct a transformer block, with length-wise pooling before
returning to full length.
Expand All @@ -1416,7 +1417,7 @@ def transformer2(
filters=min(4 * key_size, inputs.shape[-1]),
kernel_size=3,
pool_size=2,
**kwargs
**kwargs,
)

# layer norm
Expand Down Expand Up @@ -1517,7 +1518,7 @@ def squeeze_excite(
additive=False,
norm_type=None,
bn_momentum=0.9,
**kwargs
**kwargs,
):
return layers.SqueezeExcite(
activation, additive, bottleneck_ratio, norm_type, bn_momentum
Expand Down Expand Up @@ -1545,7 +1546,7 @@ def dilated_dense(
conv_type="standard",
dropout=0,
repeat=1,
**kwargs
**kwargs,
):
"""Construct a residual dilated dense block.
Expand All @@ -1570,7 +1571,7 @@ def dilated_dense(
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
conv_type=conv_type,
**kwargs
**kwargs,
)

# dense concat
Expand All @@ -1592,7 +1593,7 @@ def dilated_residual(
conv_type="standard",
norm_type=None,
round=False,
**kwargs
**kwargs,
):
"""Construct a residual dilated convolution block.
Expand All @@ -1619,7 +1620,7 @@ def dilated_residual(
conv_type=conv_type,
norm_type=norm_type,
norm_gamma="ones",
**kwargs
**kwargs,
)

# return
Expand All @@ -1629,7 +1630,7 @@ def dilated_residual(
dropout=dropout,
norm_type=norm_type,
norm_gamma="zeros",
**kwargs
**kwargs,
)

# InitZero
Expand Down Expand Up @@ -1672,7 +1673,7 @@ def dilated_residual_nac(
filters=filters,
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
**kwargs
**kwargs,
)

# return
Expand All @@ -1697,7 +1698,7 @@ def dilated_residual_2d(
dropout=0,
repeat=1,
symmetric=True,
**kwargs
**kwargs,
):
"""Construct a residual dilated convolution block."""

Expand All @@ -1717,7 +1718,7 @@ def dilated_residual_2d(
kernel_size=kernel_size,
dilation_rate=int(np.round(dilation_rate)),
norm_gamma="ones",
**kwargs
**kwargs,
)

# return
Expand All @@ -1726,7 +1727,7 @@ def dilated_residual_2d(
filters=rep_input.shape[-1],
dropout=dropout,
norm_gamma="zeros",
**kwargs
**kwargs,
)

# residual add
Expand Down Expand Up @@ -1818,7 +1819,7 @@ def dense_block(
bn_momentum=0.99,
norm_gamma=None,
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a single convolution block.
Expand Down Expand Up @@ -1909,7 +1910,7 @@ def dense_nac(
bn_momentum=0.99,
norm_gamma=None,
kernel_initializer="he_normal",
**kwargs
**kwargs,
):
"""Construct a single convolution block.
Expand Down Expand Up @@ -1991,7 +1992,7 @@ def final(
kernel_initializer="he_normal",
l2_scale=0,
l1_scale=0,
**kwargs
**kwargs,
):
"""Final simple transformation before comparison to targets.
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
for device in gpu_devices:
tf.config.experimental.set_memory_growth(device, True)


################################################################################
# Losses
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/scripts/hound_eval_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Test the accuracy of a trained model on targets/predictions normalized across targets.
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Compute variant effect predictions for SNPs in a VCF file.
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 1 addition & 0 deletions src/baskerville/scripts/hound_snp_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
parallelized across a slurm cluster.
"""


################################################################################
# main
################################################################################
Expand Down
1 change: 0 additions & 1 deletion src/baskerville/scripts/hound_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def main():
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():

if not args.keras_fit:
# distribute data
for di in range(len(args.data_dirs)):
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def predict(
stream: bool = False,
step: int = 1,
dtype: str = "float32",
**kwargs
**kwargs,
):
"""Predict targets for SeqDataset, with more options.
Expand Down
12 changes: 6 additions & 6 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def cluster_snps(snps, seq_len: int, center_pct: float):

def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):
"""Initialize an output HDF5 file for SAD stats.
Args:
out_dir (str): Output directory.
snp_stats [str]: List of SAD stats to compute.
Expand Down Expand Up @@ -287,7 +287,7 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):

def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele):
"""Return alternative allele one hot coding.
Args:
ref_1hot (np.array): Reference allele one hot coding.
snp_seq_pos (int): SNP position in sequence.
Expand Down Expand Up @@ -335,7 +335,7 @@ def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele):

def make_strand_transform(targets_df, targets_strand_df):
"""Make a sparse matrix to sum strand pairs.
Args:
targets_df (pd.DataFrame): Targets DataFrame.
targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed.
Expand Down Expand Up @@ -365,7 +365,7 @@ def make_strand_transform(targets_df, targets_strand_df):

def write_pct(scores_out, snp_stats):
"""Compute percentile values for each target and write to HDF5.
Args:
scores_out (h5py.File): Output HDF5 file.
snp_stats [str]: List of SAD stats to compute.
Expand Down Expand Up @@ -395,7 +395,7 @@ def write_pct(scores_out, snp_stats):
def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats):
"""Write SNP predictions to HDF, assuming the length dimension has
been collapsed.
Args:
ref_preds_sum (np.array): Reference allele predictions.
alt_preds_sum (np.array): Alternative allele predictions.
Expand All @@ -413,7 +413,7 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats):
def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
"""Write SNP predictions to HDF, assuming the length dimension has
been maintained.
Args:
ref_preds (np.array): Reference allele predictions.
alt_preds (np.array): Alternative allele predictions.
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def adaptive_clip_grad(
):
"""Adaptive gradient clipping."""
new_grads = []
for (params, grads) in zip(parameters, gradients):
for params, grads in zip(parameters, gradients):
p_norm = unitwise_norm(params)
max_norm = tf.math.maximum(p_norm, eps) * clip_factor
grad_norm = unitwise_norm(grads)
Expand Down
4 changes: 0 additions & 4 deletions src/baskerville/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,11 @@ def snp_seq1(snp, seq_len, genome_open):
seq_ref = seq[left_len : left_len + len(snp.ref_allele)]
ref_found = True
if seq_ref != snp.ref_allele:

# search for reference allele in alternatives
ref_found = False

# for each alternative allele
for alt_al in snp.alt_alleles:

# grab reference sequence matching alt length
seq_ref_alt = seq[left_len : left_len + len(alt_al)]
if seq_ref_alt == alt_al:
Expand Down Expand Up @@ -314,13 +312,11 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False):
# verify that ref allele matches ref sequence
seq_ref = seq[left_len : left_len + len(snp.ref_allele)]
if seq_ref != snp.ref_allele:

# search for reference allele in alternatives
ref_found = False

# for each alternative allele
for alt_al in snp.alt_alleles:

# grab reference sequence matching alt length
seq_ref_alt = seq[left_len : left_len + len(alt_al)]
if seq_ref_alt == alt_al:
Expand Down
Loading

0 comments on commit 7537905

Please sign in to comment.