Skip to content

Commit

Permalink
black cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Dec 19, 2023
1 parent 7b3fc9f commit ca640ba
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 14 deletions.
13 changes: 9 additions & 4 deletions src/baskerville/scripts/hound_ism_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
given in a VCF file.
"""


################################################################################
# main
################################################################################
Expand Down Expand Up @@ -211,7 +212,7 @@ def main():
ref_1hot = np.expand_dims(seqs_1hot[si], axis=0)

# save sequence
scores_h5['seqs'][si] = ref_1hot[0, mut_start:mut_end].astype('bool')
scores_h5["seqs"][si] = ref_1hot[0, mut_start:mut_end].astype("bool")

# predict reference
ref_preds = []
Expand Down Expand Up @@ -250,11 +251,15 @@ def main():
options.untransform_old,
)
alt_preds.append(alt_preds_shift)
alt_preds = np.array(alt_preds)
alt_preds = np.array(alt_preds)

ism_scores = snps.compute_scores(ref_preds, alt_preds, options.snp_stats)
ism_scores = snps.compute_scores(
ref_preds, alt_preds, options.snp_stats
)
for snp_stat in options.snp_stats:
scores_h5[snp_stat][si, mi-mut_start, ni] = ism_scores[snp_stat]
scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[
snp_stat
]

# close output HDF5
scores_h5.close()
Expand Down
4 changes: 2 additions & 2 deletions src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main():
print("Running on CPU")
if options.require_gpu:
raise SystemExit("Job terminated because it's running on CPU")

#################################################################
# download input files from gcs to a local file
if options.gcs:
Expand All @@ -200,7 +200,7 @@ def main():
options.targets_file = download_rename_inputs(
options.targets_file, temp_dir
)

#################################################################
# calculate SAD scores:
if options.processes is not None:
Expand Down
12 changes: 6 additions & 6 deletions src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,14 +970,14 @@ def predict(
return preds

def predict_transform(
self,
seq_1hot: np.array,
self,
seq_1hot: np.array,
targets_df,
strand_transform: np.array = None,
untransform_old: bool = False,
):
"""Predict a single sequence and transform.
Args:
seq_1hot (np.array): 1-hot encoded sequence.
targets_df (pd.DataFrame): Targets dataframe.
Expand All @@ -986,19 +986,19 @@ def predict_transform(
"""
# predict
preds = self(seq_1hot)[0]

# untransform predictions
if untransform_old:
preds = dataset.untransform_preds1(preds, targets_df)
else:
preds = dataset.untransform_preds(preds, targets_df)

# sum strand pairs
if strand_transform is not None:
preds = preds * strand_transform

return preds

def restore(self, model_file, head_i=0, trunk=False):
"""Restore weights from saved model."""
if trunk:
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):

# read targets
if options.targets_file is None:
print('Must provide targets file to clarify stranded datasets', file=sys.stderr)
print("Must provide targets file to clarify stranded datasets", file=sys.stderr)
exit(1)
targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_ism.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
snp_out_dir = "tests/data/ism/snp_out"
bed_out_dir = "tests/data/ism/bed_out"


def test_snp():
cmd = [
"src/baskerville/scripts/hound_ism_snp.py",
Expand Down Expand Up @@ -46,6 +47,7 @@ def test_snp():
# verify variance
assert (score_var > 0).all()


def test_bed():
cmd = [
"src/baskerville/scripts/hound_ism_bed.py",
Expand Down Expand Up @@ -78,4 +80,4 @@ def test_bed():
assert not np.isnan(score).any()

# verify variance
assert (score_var > 0).all()
assert (score_var > 0).all()

0 comments on commit ca640ba

Please sign in to comment.