Skip to content

Commit

Permalink
black format
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Nov 10, 2023
1 parent 38c874c commit bbbb2fd
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 56 deletions.
21 changes: 3 additions & 18 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,12 +1772,7 @@ def dense_block(
# flatten
if flatten:
_, seq_len, seq_depth = current.shape
current = tf.keras.layers.Reshape(
(
1,
seq_len * seq_depth,
)
)(current)
current = tf.keras.layers.Reshape((1, seq_len * seq_depth,))(current)

# dense
current = tf.keras.layers.Dense(
Expand Down Expand Up @@ -1879,12 +1874,7 @@ def dense_nac(
# flatten
if flatten:
_, seq_len, seq_depth = current.shape
current = tf.keras.layers.Reshape(
(
1,
seq_len * seq_depth,
)
)(current)
current = tf.keras.layers.Reshape((1, seq_len * seq_depth,))(current)

# dense
current = tf.keras.layers.Dense(
Expand Down Expand Up @@ -1934,12 +1924,7 @@ def final(
# flatten
if flatten:
_, seq_len, seq_depth = current.shape
current = tf.keras.layers.Reshape(
(
1,
seq_len * seq_depth,
)
)(current)
current = tf.keras.layers.Reshape((1, seq_len * seq_depth,))(current)

# dense
current = tf.keras.layers.Dense(
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def untransform_preds(preds, targets_df, unscale=False):

# sqrt
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)

# scale
if unscale:
Expand Down
4 changes: 2 additions & 2 deletions src/baskerville/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def call(self, inputs, training=False):

# Scale the query by the square-root of key size.
if self._scaling:
q *= self._key_size**-0.5
q *= self._key_size ** -0.5

# [B, H, T', T]
content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
Expand Down Expand Up @@ -888,7 +888,7 @@ def call(self, inputs):

triu_tup = np.triu_indices(seq_len, self.diagonal_offset)
triu_index = list(triu_tup[0] + seq_len * triu_tup[1])
unroll_repr = tf.reshape(inputs, [-1, seq_len**2, output_dim])
unroll_repr = tf.reshape(inputs, [-1, seq_len ** 2, output_dim])
return tf.gather(unroll_repr, triu_index, axis=1)

def get_config(self):
Expand Down
51 changes: 36 additions & 15 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,12 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):

# setup output
scores_out = initialize_output_h5(
options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df, num_shifts
options.out_dir,
options.snp_stats,
snps,
targets_length,
targets_strand_df,
num_shifts,
)

# SNP index
Expand All @@ -150,9 +155,13 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
ref_preds_shift = dataset.untransform_preds1(ref_preds_shift, targets_df)
ref_preds_shift = dataset.untransform_preds1(
ref_preds_shift, targets_df
)
else:
ref_preds_shift = dataset.untransform_preds(ref_preds_shift, targets_df)
ref_preds_shift = dataset.untransform_preds(
ref_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
Expand Down Expand Up @@ -189,9 +198,13 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
alt_preds_shift = dataset.untransform_preds1(alt_preds_shift, targets_df)
alt_preds_shift = dataset.untransform_preds1(
alt_preds_shift, targets_df
)
else:
alt_preds_shift = dataset.untransform_preds(alt_preds_shift, targets_df)
alt_preds_shift = dataset.untransform_preds(
alt_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
Expand Down Expand Up @@ -253,7 +266,9 @@ def cluster_snps(snps, seq_len: int, center_pct: float):
return snp_clusters


def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df, num_shifts):
def initialize_output_h5(
out_dir, snp_stats, snps, targets_length, targets_df, num_shifts
):
"""Initialize an output HDF5 file for SAD stats.
Args:
Expand Down Expand Up @@ -307,7 +322,9 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df, n
for snp_stat in snp_stats:
if snp_stat in ["REF", "ALT"]:
scores_out.create_dataset(
snp_stat, shape=(num_snps, num_shifts, targets_length, num_targets), dtype="float16"
snp_stat,
shape=(num_snps, num_shifts, targets_length, num_targets),
dtype="float16",
)
else:
scores_out.create_dataset(
Expand Down Expand Up @@ -463,12 +480,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
alt_preds_sqrt = np.sqrt(alt_preds)

# sum across length
ref_preds_sum = ref_preds.sum(axis=(0,1))
alt_preds_sum = alt_preds.sum(axis=(0,1))
ref_preds_log_sum = ref_preds_log.sum(axis=(0,1))
alt_preds_log_sum = alt_preds_log.sum(axis=(0,1))
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0,1))
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0,1))
ref_preds_sum = ref_preds.sum(axis=(0, 1))
alt_preds_sum = alt_preds.sum(axis=(0, 1))
ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1))
alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1))
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1))
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1))

# difference
altref_diff = alt_preds - ref_preds
Expand Down Expand Up @@ -565,8 +582,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
# compare normalized JS
log_js_dist = []
for s in range(num_shifts):
ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum(axis=0)
ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum(
axis=0
)
alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum(
axis=0
)
log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2)
log_js_dist = np.mean(log_js_dist, axis=0)
scores_out["logJS"][si] = log_js_dist.astype("float16")
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def make_optimizer(self):

def compute_norm(x, axis, keepdims):
"""Compute L2 norm of a tensor across an axis."""
return tf.math.reduce_sum(x**2, axis=axis, keepdims=keepdims) ** 0.5
return tf.math.reduce_sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5


def unitwise_norm(x):
Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def get_alleles(self):
"""Return a list of all alleles"""
alleles = [self.ref_allele] + self.alt_alleles
return alleles

def indel_size(self):
"""Return the size of the indel."""
return len(self.alt_allele) - len(self.ref_allele)
Expand Down
21 changes: 3 additions & 18 deletions tests/test_dna.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,9 @@ def test_dna_rc():
)

dna_1hot_cases = [
(
"ACGT",
False,
False,
ACGT_ARRAY,
),
(
"ACNGT",
False,
False,
ACNGT_FF_ARRAY,
),
(
"ACNGT",
True,
False,
ACNGT_TF_ARRAY,
),
("ACGT", False, False, ACGT_ARRAY,),
("ACNGT", False, False, ACNGT_FF_ARRAY,),
("ACNGT", True, False, ACNGT_TF_ARRAY,),
]


Expand Down

0 comments on commit bbbb2fd

Please sign in to comment.