Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

running black #41

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 60 additions & 65 deletions basenji/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,93 +25,88 @@
compute accuracy statistics.
"""

class Accuracy:
    """Compute accuracy statistics comparing predictions to targets.

    Attributes mirror the constructor arguments; ``num_targets`` is derived
    from the length of ``target_losses``.
    """

    def __init__(self, targets, preds, targets_na=None, loss=None, target_losses=None):
        self.targets = targets
        self.preds = preds
        # targets_na: boolean mask of missing positions, or None when all valid.
        self.targets_na = targets_na
        self.loss = loss
        self.target_losses = target_losses

        # NOTE(review): len(None) raises TypeError when target_losses is omitted;
        # callers appear to always supply it — confirm before relying on the default.
        self.num_targets = len(self.target_losses)

    def pearsonr(self, log=False, pseudocount=1, clip=None):
        """Compute target PearsonR vector.

        Args:
          log: apply log2(x + pseudocount) to both series before correlating.
          pseudocount: additive constant used with `log`.
          clip: if not None, clip both series to [0, clip].
        Returns:
          numpy array of per-target Pearson correlations.
        """
        pcor = np.zeros(self.num_targets)

        for ti in range(self.num_targets):
            if self.targets_na is not None:
                preds_ti = self.preds[~self.targets_na, ti].astype("float64")
                targets_ti = self.targets[~self.targets_na, ti].astype("float64")
            else:
                # assumes preds/targets are (batch, seq, targets) — TODO confirm
                preds_ti = self.preds[:, :, ti].flatten().astype("float64")
                targets_ti = self.targets[:, :, ti].flatten().astype("float64")

            if clip is not None:
                preds_ti = np.clip(preds_ti, 0, clip)
                targets_ti = np.clip(targets_ti, 0, clip)

            if log:
                preds_ti = np.log2(preds_ti + pseudocount)
                targets_ti = np.log2(targets_ti + pseudocount)

            pc, _ = stats.pearsonr(targets_ti, preds_ti)
            pcor[ti] = pc

        return pcor

    def r2(self, log=False, pseudocount=1, clip=None):
        """Compute target R2 vector.

        Args:
          log: apply log2(x + pseudocount) to both series first.
          pseudocount: additive constant used with `log`.
          clip: if not None, clip both series to [0, clip].
        Returns:
          numpy array of per-target R^2 scores (sklearn.metrics.r2_score).
        """
        r2_vec = np.zeros(self.num_targets)

        for ti in range(self.num_targets):
            if self.targets_na is not None:
                preds_ti = self.preds[~self.targets_na, ti].astype("float64")
                targets_ti = self.targets[~self.targets_na, ti].astype("float64")
            else:
                preds_ti = self.preds[:, :, ti].flatten().astype("float64")
                targets_ti = self.targets[:, :, ti].flatten().astype("float64")

            if clip is not None:
                preds_ti = np.clip(preds_ti, 0, clip)
                targets_ti = np.clip(targets_ti, 0, clip)

            if log:
                preds_ti = np.log2(preds_ti + pseudocount)
                targets_ti = np.log2(targets_ti + pseudocount)

            r2_vec[ti] = metrics.r2_score(targets_ti, preds_ti)

        return r2_vec

    def spearmanr(self):
        """Compute target SpearmanR vector.

        Returns:
          numpy array of per-target Spearman rank correlations.
        """
        scor = np.zeros(self.num_targets)

        for ti in range(self.num_targets):
            if self.targets_na is not None:
                preds_ti = self.preds[~self.targets_na, ti]
                targets_ti = self.targets[~self.targets_na, ti]
            else:
                preds_ti = self.preds[:, :, ti].flatten()
                targets_ti = self.targets[:, :, ti].flatten()

            sc, _ = stats.spearmanr(targets_ti, preds_ti)
            scor[ti] = sc

        return scor


################################################################################
# __main__
################################################################################
if __name__ == "__main__":
    main()
147 changes: 78 additions & 69 deletions basenji/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,39 @@

from basenji import ops


def shift_sequence(seq, shift_amount, pad_value=0.25):
    """Shift a sequence left or right by shift_amount.

    Args:
      seq: a [batch_size, sequence_length, sequence_depth] sequence to shift
      shift_amount: the signed amount to shift (tf.int32 or int)
      pad_value: value to fill the padding (primitive or scalar tf.Tensor)
    Raises:
      ValueError: if `seq` is not rank 3.
    """
    if seq.shape.ndims != 3:
        raise ValueError("input sequence should be rank 3")
    input_shape = seq.shape

    # Padding block the same width as the shift, filled with pad_value.
    pad = pad_value * tf.ones_like(seq[:, 0 : tf.abs(shift_amount), :])

    def _shift_right(_seq):
        # Drop the rightmost shift_amount positions, pad on the left.
        sliced_seq = _seq[:, :-shift_amount:, :]
        return tf.concat([pad, sliced_seq], axis=1)

    def _shift_left(_seq):
        # Drop the leftmost |shift_amount| positions, pad on the right.
        sliced_seq = _seq[:, -shift_amount:, :]
        return tf.concat([sliced_seq, pad], axis=1)

    # Branch on the sign of the shift at graph-execution time.
    output = tf.cond(
        tf.greater(shift_amount, 0), lambda: _shift_right(seq), lambda: _shift_left(seq)
    )

    # tf.cond loses the static shape; restore it.
    output.set_shape(input_shape)
    return output

def augment_deterministic_set(data_ops, augment_rc=False, augment_shifts=[0]):
    """Build the full set of deterministic augmentations of data_ops.

    Args:
      data_ops: dict with keys 'sequence,' 'label,' and 'na.'
      augment_rc: presumably whether to add a reverse-complement variant per
        shift — TODO confirm against hidden docstring lines.
      augment_shifts: list of int shift offsets to enumerate.
    Returns
      data_ops_list: list of augmented data_ops dicts, one per
        (reverse-complement, shift) pair.
    """
    # NOTE(review): mutable default [0] is shared across calls; harmless since
    # it is never mutated, but a None sentinel would be safer.
    augment_pairs = []
    for ashift in augment_shifts:
        augment_pairs.append((False, ashift))
        if augment_rc:
            augment_pairs.append((True, ashift))

    data_ops_list = []
    for arc, ashift in augment_pairs:
        data_ops_aug = augment_deterministic(data_ops, arc, ashift)
        data_ops_list.append(data_ops_aug)

    return data_ops_list


def augment_deterministic(data_ops, augment_rc=False, augment_shift=0):
    """Apply a deterministic augmentation, specified by the parameters.

    Args:
      data_ops: dict with keys 'sequence,' 'label,' and 'na.'
      augment_rc: presumably whether to reverse-complement — TODO confirm
        against hidden docstring lines.
      augment_shift: int offset to shift the sequence by.
    Returns
      data_ops: augmented data, with all existing keys transformed
        and 'reverse_preds' bool added.
    """
    # Copy everything except the sequence, which gets transformed below.
    data_ops_aug = {}
    for key in data_ops:
        if key not in ["sequence"]:
            data_ops_aug[key] = data_ops[key]

    if augment_shift == 0:
        data_ops_aug["sequence"] = data_ops["sequence"]
    else:
        shift_amount = tf.constant(augment_shift, shape=(), dtype=tf.int64)
        data_ops_aug["sequence"] = shift_sequence(data_ops["sequence"], shift_amount)

    if augment_rc:
        # augment_deterministic_rc sets 'reverse_preds' to True itself.
        data_ops_aug = augment_deterministic_rc(data_ops_aug)
    else:
        data_ops_aug["reverse_preds"] = tf.zeros((), dtype=tf.bool)

    return data_ops_aug


def augment_deterministic_rc(data_ops):
    """Apply a deterministic reverse complement augmentation.

    Args:
      data_ops: dict with keys 'sequence,' 'label,' and 'na.'
    Returns
      data_ops_aug: augmented data ops, with 'reverse_preds' set to True so
        downstream code knows to un-reverse predictions.
    """
    data_ops_aug = ops.reverse_complement_transform(data_ops)
    data_ops_aug["reverse_preds"] = tf.ones((), dtype=tf.bool)
    return data_ops_aug


def augment_stochastic_rc(data_ops):
    """Apply a stochastic reverse complement augmentation.

    Args:
      data_ops: dict with keys 'sequence,' 'label,' and 'na.'
    Returns
      data_ops_aug: augmented data, reverse-complemented with probability 0.5,
        with 'reverse_preds' recording which branch was taken.
    """
    reverse_preds = tf.random_uniform(shape=[]) > 0.5
    data_ops_aug = tf.cond(
        reverse_preds,
        lambda: ops.reverse_complement_transform(data_ops),
        lambda: data_ops.copy(),
    )
    data_ops_aug["reverse_preds"] = reverse_preds
    return data_ops_aug


def augment_stochastic_shifts(seq, augment_shifts):
    """Apply a stochastic shift augmentation.

    Args:
      seq: input sequence of size [batch_size, length, depth]
      augment_shifts: list of int offsets to sample from
    Returns:
      shifted and padded sequence of size [batch_size, length, depth]
    """
    # Sample one of the candidate offsets uniformly.
    shift_index = tf.random_uniform(
        shape=[], minval=0, maxval=len(augment_shifts), dtype=tf.int64
    )
    shift_value = tf.gather(tf.constant(augment_shifts), shift_index)

    # Skip the shift entirely when the sampled offset is zero.
    seq = tf.cond(
        tf.not_equal(shift_value, 0),
        lambda: shift_sequence(seq, shift_value),
        lambda: seq,
    )

    return seq


def augment_stochastic(data_ops, augment_rc=False, augment_shifts=[]):
    """Apply stochastic augmentations,

    Args:
      data_ops: dict with keys 'sequence,' 'label,' and 'na.'
      augment_rc: presumably whether to reverse-complement with p=0.5 — TODO
        confirm against hidden docstring lines.
      augment_shifts: list of int offsets to sample a shift from.
    Returns:
      data_ops_aug: augmented data
    """
    if augment_shifts:
        data_ops["sequence"] = augment_stochastic_shifts(
            data_ops["sequence"], augment_shifts
        )

    if augment_rc:
        # augment_stochastic_rc sets 'reverse_preds' itself.
        data_ops = augment_stochastic_rc(data_ops)
    else:
        data_ops["reverse_preds"] = tf.zeros((), dtype=tf.bool)

    return data_ops
Loading