Skip to content

Commit

Permalink
Merge pull request #33 from calico/loss-weight
Browse files Browse the repository at this point in the history
spatial weighting for poisson multinomial
  • Loading branch information
davek44 authored May 27, 2024
2 parents 0a3075f + 313b94b commit 33a33dc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
27 changes: 22 additions & 5 deletions src/baskerville/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def poisson_multinomial(
y_true,
y_pred,
total_weight: float = 1,
weight_range: float = 1,
weight_exp: int = 4,
epsilon: float = 1e-7,
rescale: bool = False,
):
Expand All @@ -126,14 +128,26 @@ def poisson_multinomial(
rescale (bool): Rescale loss after re-weighting.
"""
seq_len = y_true.shape[1]
pos_start = -(seq_len / 2 - 0.5)
pos_end = seq_len / 2 + 0.5
sigma = -pos_start / (np.log(weight_range)) ** (1 / weight_exp)

positions = tf.range(pos_start, pos_end, dtype=tf.float32)
position_weights = tf.exp(-((positions / sigma) ** weight_exp))
position_weights /= tf.reduce_max(position_weights)
position_weights = tf.expand_dims(position_weights, axis=0)
position_weights = tf.expand_dims(position_weights, axis=-1)

y_true = tf.math.multiply(y_true, position_weights)
y_pred = tf.math.multiply(y_pred, position_weights)

# sum across lengths
s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True)
s_pred = tf.math.reduce_sum(y_pred, axis=-2, keepdims=True)

# total count poisson loss, mean across targets
poisson_term = poisson(s_true, s_pred) # B x T
poisson_term /= seq_len
poisson_term /= tf.reduce_sum(position_weights)

# add epsilon to protect against tiny values
y_true += epsilon
Expand All @@ -146,7 +160,7 @@ def poisson_multinomial(
pl_pred = tf.math.log(p_pred) # B x L x T
multinomial_dot = -tf.math.multiply(y_true, pl_pred) # B x L x T
multinomial_term = tf.math.reduce_sum(multinomial_dot, axis=-2) # B x T
multinomial_term /= seq_len
multinomial_term /= tf.reduce_sum(position_weights)

# normalize to scale of 1:1 term ratio
loss_raw = multinomial_term + total_weight * poisson_term
Expand All @@ -167,12 +181,15 @@ class PoissonMultinomial(LossFunctionWrapper):

def __init__(
self,
total_weight=1,
total_weight: float = 1,
weight_range: float = 1,
weight_exp: int = 4,
reduction=losses_utils.ReductionV2.AUTO,
name: str = "poisson_multinomial",
):
self.total_weight = total_weight
pois_mn = lambda yt, yp: poisson_multinomial(yt, yp, self.total_weight)
pois_mn = lambda yt, yp: poisson_multinomial(
yt, yp, total_weight, weight_range, weight_exp
)
super(PoissonMultinomial, self).__init__(
pois_mn, name=name, reduction=reduction
)
Expand Down
23 changes: 20 additions & 3 deletions src/baskerville/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def parse_loss(
keras_fit: bool = True,
spec_weight: float = 1,
total_weight: float = 1,
weight_range: float = 1,
weight_exp: int = 1,
):
"""Parse loss function from label, strategy, and fitting method.
Expand All @@ -51,7 +53,10 @@ def parse_loss(
)
elif loss_label == "poisson_mn":
loss_fn = metrics.PoissonMultinomial(
total_weight, reduction=tf.keras.losses.Reduction.NONE
total_weight=total_weight,
weight_range=weight_range,
weight_exp=weight_exp,
reduction=tf.keras.losses.Reduction.NONE,
)
else:
loss_fn = tf.keras.losses.Poisson(reduction=tf.keras.losses.Reduction.NONE)
Expand All @@ -65,7 +70,11 @@ def parse_loss(
elif loss_label == "poisson_kl":
loss_fn = metrics.PoissonKL(spec_weight)
elif loss_label == "poisson_mn":
loss_fn = metrics.PoissonMultinomial(total_weight)
loss_fn = metrics.PoissonMultinomial(
total_weight=total_weight,
weight_range=weight_range,
weight_exp=weight_exp,
)
else:
loss_fn = tf.keras.losses.Poisson()

Expand Down Expand Up @@ -127,9 +136,17 @@ def __init__(
# loss
self.spec_weight = self.params.get("spec_weight", 1)
self.total_weight = self.params.get("total_weight", 1)
self.weight_range = self.params.get("weight_range", 1)
self.weight_exp = self.params.get("weight_exp", 1)
self.loss = self.params.get("loss", "poisson").lower()
self.loss_fn = parse_loss(
self.loss, self.strategy, keras_fit, self.spec_weight, self.total_weight
self.loss,
self.strategy,
keras_fit,
self.spec_weight,
self.total_weight,
self.weight_range,
self.weight_exp,
)

# optimizer
Expand Down

0 comments on commit 33a33dc

Please sign in to comment.