diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index fab9a4b..2490f2e 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -18,6 +18,7 @@ from baskerville import layers + ############################################################ # Convolution ############################################################ @@ -892,7 +893,7 @@ def conv_tower( divisible_by=1, repeat=1, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -943,7 +944,7 @@ def conv_tower_nac( divisible_by=1, repeat=1, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -1000,7 +1001,7 @@ def res_tower( repeat=1, num_convs=2, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -1087,7 +1088,7 @@ def convnext_tower( repeat=1, num_convs=2, reprs=[], - **kwargs + **kwargs, ): """Abc. @@ -1129,7 +1130,7 @@ def _round(x): filters=rep_filters_int, kernel_size=kernel_size, dropout=dropout, - **kwargs + **kwargs, ) current0 = current @@ -1141,7 +1142,7 @@ def _round(x): filters=rep_filters_int, kernel_size=kernel_size, dropout=dropout, - **kwargs + **kwargs, ) # residual add @@ -1187,7 +1188,7 @@ def transformer( qkv_width=1, mha_initializer="he_normal", kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a transformer block. @@ -1255,7 +1256,7 @@ def transformer_split( qkv_width=1, mha_initializer="he_normal", kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a transformer block. @@ -1393,7 +1394,7 @@ def transformer2( dropout=0.25, dense_expansion=2.0, qkv_width=1, - **kwargs + **kwargs, ): """Construct a transformer block, with length-wise pooling before returning to full length. @@ -1416,7 +1417,7 @@ def transformer2( filters=min(4 * key_size, inputs.shape[-1]), kernel_size=3, pool_size=2, - **kwargs + **kwargs, ) # layer norm @@ -1517,7 +1518,7 @@ def squeeze_excite( additive=False, norm_type=None, bn_momentum=0.9, - **kwargs + **kwargs, ): return layers.SqueezeExcite( activation, additive, bottleneck_ratio, norm_type, bn_momentum @@ -1545,7 +1546,7 @@ def dilated_dense( conv_type="standard", dropout=0, repeat=1, - **kwargs + **kwargs, ): """Construct a residual dilated dense block. @@ -1570,7 +1571,7 @@ def dilated_dense( kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), conv_type=conv_type, - **kwargs + **kwargs, ) # dense concat @@ -1592,7 +1593,7 @@ def dilated_residual( conv_type="standard", norm_type=None, round=False, - **kwargs + **kwargs, ): """Construct a residual dilated convolution block. @@ -1619,7 +1620,7 @@ def dilated_residual( conv_type=conv_type, norm_type=norm_type, norm_gamma="ones", - **kwargs + **kwargs, ) # return @@ -1629,7 +1630,7 @@ def dilated_residual( dropout=dropout, norm_type=norm_type, norm_gamma="zeros", - **kwargs + **kwargs, ) # InitZero @@ -1672,7 +1673,7 @@ def dilated_residual_nac( filters=filters, kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), - **kwargs + **kwargs, ) # return @@ -1697,7 +1698,7 @@ def dilated_residual_2d( dropout=0, repeat=1, symmetric=True, - **kwargs + **kwargs, ): """Construct a residual dilated convolution block.""" @@ -1717,7 +1718,7 @@ def dilated_residual_2d( kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), norm_gamma="ones", - **kwargs + **kwargs, ) # return @@ -1726,7 +1727,7 @@ def dilated_residual_2d( filters=rep_input.shape[-1], dropout=dropout, norm_gamma="zeros", - **kwargs + **kwargs, ) # residual add @@ -1818,7 +1819,7 @@ def dense_block( bn_momentum=0.99, norm_gamma=None, kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a single convolution block. @@ -1909,7 +1910,7 @@ def dense_nac( bn_momentum=0.99, norm_gamma=None, kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a single convolution block. @@ -1991,7 +1992,7 @@ def final( kernel_initializer="he_normal", l2_scale=0, l1_scale=0, - **kwargs + **kwargs, ): """Final simple transformation before comparison to targets. diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index 8e947ba..29c0d99 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -24,6 +24,7 @@ for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True) + ################################################################################ # Losses ################################################################################ diff --git a/src/baskerville/scripts/hound_eval_spec.py b/src/baskerville/scripts/hound_eval_spec.py index e69d146..43d908c 100755 --- a/src/baskerville/scripts/hound_eval_spec.py +++ b/src/baskerville/scripts/hound_eval_spec.py @@ -35,6 +35,7 @@ Test the accuracy of a trained model on targets/predictions normalized across targets. """ + ################################################################################ # main ################################################################################ diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index 167335d..0eb8a7d 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -25,6 +25,7 @@ Compute variant effect predictions for SNPs in a VCF file. """ + ################################################################################ # main ################################################################################ diff --git a/src/baskerville/scripts/hound_snp_slurm.py b/src/baskerville/scripts/hound_snp_slurm.py index e6cb077..e2be5fd 100755 --- a/src/baskerville/scripts/hound_snp_slurm.py +++ b/src/baskerville/scripts/hound_snp_slurm.py @@ -33,6 +33,7 @@ parallelized across a slurm cluster. """ + ################################################################################ # main ################################################################################ diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py index beec2e2..e7ec150 100755 --- a/src/baskerville/scripts/hound_train.py +++ b/src/baskerville/scripts/hound_train.py @@ -163,7 +163,6 @@ def main(): strategy = tf.distribute.MirroredStrategy() with strategy.scope(): - if not args.keras_fit: # distribute data for di in range(len(args.data_dirs)): diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 12592ca..a00c843 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -524,7 +524,7 @@ def predict( stream: bool = False, step: int = 1, dtype: str = "float32", - **kwargs + **kwargs, ): """Predict targets for SeqDataset, with more options. diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 5c86e17..d89a6ab 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -224,7 +224,7 @@ def cluster_snps(snps, seq_len: int, center_pct: float): def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): """Initialize an output HDF5 file for SAD stats. - + Args: out_dir (str): Output directory. snp_stats [str]: List of SAD stats to compute. @@ -287,7 +287,7 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): """Return alternative allele one hot coding. - + Args: ref_1hot (np.array): Reference allele one hot coding. snp_seq_pos (int): SNP position in sequence. @@ -335,7 +335,7 @@ def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): def make_strand_transform(targets_df, targets_strand_df): """Make a sparse matrix to sum strand pairs. - + Args: targets_df (pd.DataFrame): Targets DataFrame. targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. @@ -365,7 +365,7 @@ def make_strand_transform(targets_df, targets_strand_df): def write_pct(scores_out, snp_stats): """Compute percentile values for each target and write to HDF5. - + Args: scores_out (h5py.File): Output HDF5 file. snp_stats [str]: List of SAD stats to compute. @@ -395,7 +395,7 @@ def write_pct(scores_out, snp_stats): def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): """Write SNP predictions to HDF, assuming the length dimension has been collapsed. - + Args: ref_preds_sum (np.array): Reference allele predictions. alt_preds_sum (np.array): Alternative allele predictions. @@ -413,7 +413,7 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): """Write SNP predictions to HDF, assuming the length dimension has been maintained. - + Args: ref_preds (np.array): Reference allele predictions. alt_preds (np.array): Alternative allele predictions. diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index d55c636..5c55f52 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -783,7 +783,7 @@ def adaptive_clip_grad( ): """Adaptive gradient clipping.""" new_grads = [] - for (params, grads) in zip(parameters, gradients): + for params, grads in zip(parameters, gradients): p_norm = unitwise_norm(params) max_norm = tf.math.maximum(p_norm, eps) * clip_factor grad_norm = unitwise_norm(grads) diff --git a/src/baskerville/vcf.py b/src/baskerville/vcf.py index 436bf8e..802b0b5 100644 --- a/src/baskerville/vcf.py +++ b/src/baskerville/vcf.py @@ -219,13 +219,11 @@ def snp_seq1(snp, seq_len, genome_open): seq_ref = seq[left_len : left_len + len(snp.ref_allele)] ref_found = True if seq_ref != snp.ref_allele: - # search for reference allele in alternatives ref_found = False # for each alternative allele for alt_al in snp.alt_alleles: - # grab reference sequence matching alt length seq_ref_alt = seq[left_len : left_len + len(alt_al)] if seq_ref_alt == alt_al: @@ -314,13 +312,11 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False): # verify that ref allele matches ref sequence seq_ref = seq[left_len : left_len + len(snp.ref_allele)] if seq_ref != snp.ref_allele: - # search for reference allele in alternatives ref_found = False # for each alternative allele for alt_al in snp.alt_alleles: - # grab reference sequence matching alt length seq_ref_alt = seq[left_len : left_len + len(alt_al)] if seq_ref_alt == alt_al: diff --git a/tests/test_snp.py b/tests/test_snp.py index e44cf44..be71463 100644 --- a/tests/test_snp.py +++ b/tests/test_snp.py @@ -15,12 +15,13 @@ fasta_file = "%s/assembly/ucsc/hg38.fa" % os.environ["HG38"] stat_keys = ["logSAD", "logD2"] + def test_snp(): test_out_dir = f"{out_dir}/full" scores_file = f"{test_out_dir}/scores.h5" if os.path.isfile(scores_file): os.remove(scores_file) - + cmd = [ "src/baskerville/scripts/hound_snp.py", "-f", @@ -34,7 +35,7 @@ def test_snp(): targets_file, params_file, model_file, - vcf_file + vcf_file, ] print(" ".join(cmd)) subprocess.run(cmd, check=True) @@ -42,8 +43,8 @@ def test_snp(): with h5py.File(scores_file, "r") as scores_h5: for sk in stat_keys: score = scores_h5[sk][:] - score_var = score.var(axis=0, dtype='float32') - assert (score_var> 0).all() + score_var = score.var(axis=0, dtype="float32") + assert (score_var > 0).all() def test_slice(): @@ -74,19 +75,21 @@ def test_slice(): targets_rna_file, params_file, model_file, - vcf_file + vcf_file, ] print(" ".join(cmd)) subprocess.run(cmd, check=True) # stranded mask targets_strand_df = targets_prep_strand(targets_df) - rna_strand_mask = np.array([desc.startswith("RNA") for desc in targets_strand_df.description]) + rna_strand_mask = np.array( + [desc.startswith("RNA") for desc in targets_strand_df.description] + ) - for sk in stat_keys: + for sk in stat_keys: with h5py.File(f"{test_full_dir}/scores.h5", "r") as scores_h5: - score_full = scores_h5[sk][:].astype('float32') - score_full = score_full[...,rna_strand_mask] + score_full = scores_h5[sk][:].astype("float32") + score_full = score_full[..., rna_strand_mask] with h5py.File(f"{test_slice_dir}/scores.h5", "r") as scores_h5: - score_slice = scores_h5[sk][:].astype('float32') - assert np.allclose(score_full, score_slice) \ No newline at end of file + score_slice = scores_h5[sk][:].astype("float32") + assert np.allclose(score_full, score_slice)