diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py
index 8aa549f..8948880 100644
--- a/src/baskerville/metrics.py
+++ b/src/baskerville/metrics.py
@@ -115,6 +115,8 @@ def poisson_multinomial(
     y_true,
     y_pred,
     total_weight: float = 1,
+    weight_range: float = 1,
+    weight_exp: int = 4,
     epsilon: float = 1e-7,
     rescale: bool = False,
 ):
@@ -126,6 +128,18 @@ def poisson_multinomial(
       rescale (bool): Rescale loss after re-weighting.
     """
     seq_len = y_true.shape[1]
+    pos_start = -(seq_len / 2 - 0.5)
+    pos_end = seq_len / 2 + 0.5
+    sigma = -pos_start / (np.log(weight_range)) ** (1 / weight_exp)
+
+    positions = tf.range(pos_start, pos_end, dtype=tf.float32)
+    position_weights = tf.exp(-((positions / sigma) ** weight_exp))
+    position_weights /= tf.reduce_max(position_weights)
+    position_weights = tf.expand_dims(position_weights, axis=0)
+    position_weights = tf.expand_dims(position_weights, axis=-1)
+
+    y_true = tf.math.multiply(y_true, position_weights)
+    y_pred = tf.math.multiply(y_pred, position_weights)
 
     # sum across lengths
     s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True)
@@ -133,7 +147,7 @@
 
     # total count poisson loss, mean across targets
     poisson_term = poisson(s_true, s_pred)  # B x T
-    poisson_term /= seq_len
+    poisson_term /= tf.reduce_sum(position_weights)
 
     # add epsilon to protect against tiny values
     y_true += epsilon
@@ -146,7 +160,7 @@
     pl_pred = tf.math.log(p_pred)  # B x L x T
     multinomial_dot = -tf.math.multiply(y_true, pl_pred)  # B x L x T
     multinomial_term = tf.math.reduce_sum(multinomial_dot, axis=-2)  # B x T
-    multinomial_term /= seq_len
+    multinomial_term /= tf.reduce_sum(position_weights)
 
     # normalize to scale of 1:1 term ratio
     loss_raw = multinomial_term + total_weight * poisson_term
@@ -167,12 +181,15 @@ class PoissonMultinomial(LossFunctionWrapper):
 
     def __init__(
         self,
-        total_weight=1,
+        total_weight: float = 1,
+        weight_range: float = 1,
+        weight_exp: int = 4,
         reduction=losses_utils.ReductionV2.AUTO,
         name: str = "poisson_multinomial",
     ):
-        self.total_weight = total_weight
-        pois_mn = lambda yt, yp: poisson_multinomial(yt, yp, self.total_weight)
+        pois_mn = lambda yt, yp: poisson_multinomial(
+            yt, yp, total_weight, weight_range, weight_exp
+        )
         super(PoissonMultinomial, self).__init__(
             pois_mn, name=name, reduction=reduction
         )
diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py
index 5c55f52..d779066 100644
--- a/src/baskerville/trainer.py
+++ b/src/baskerville/trainer.py
@@ -27,6 +27,8 @@ def parse_loss(
     keras_fit: bool = True,
     spec_weight: float = 1,
     total_weight: float = 1,
+    weight_range: float = 1,
+    weight_exp: int = 1,
 ):
     """Parse loss function from label, strategy, and fitting method.
 
@@ -51,7 +53,10 @@ def parse_loss(
             )
         elif loss_label == "poisson_mn":
             loss_fn = metrics.PoissonMultinomial(
-                total_weight, reduction=tf.keras.losses.Reduction.NONE
+                total_weight=total_weight,
+                weight_range=weight_range,
+                weight_exp=weight_exp,
+                reduction=tf.keras.losses.Reduction.NONE,
             )
         else:
             loss_fn = tf.keras.losses.Poisson(reduction=tf.keras.losses.Reduction.NONE)
@@ -65,7 +70,11 @@ def parse_loss(
         elif loss_label == "poisson_kl":
             loss_fn = metrics.PoissonKL(spec_weight)
         elif loss_label == "poisson_mn":
-            loss_fn = metrics.PoissonMultinomial(total_weight)
+            loss_fn = metrics.PoissonMultinomial(
+                total_weight=total_weight,
+                weight_range=weight_range,
+                weight_exp=weight_exp,
+            )
         else:
             loss_fn = tf.keras.losses.Poisson()
 
@@ -127,9 +136,17 @@ def __init__(
         # loss
         self.spec_weight = self.params.get("spec_weight", 1)
         self.total_weight = self.params.get("total_weight", 1)
+        self.weight_range = self.params.get("weight_range", 1)
+        self.weight_exp = self.params.get("weight_exp", 1)
         self.loss = self.params.get("loss", "poisson").lower()
         self.loss_fn = parse_loss(
-            self.loss, self.strategy, keras_fit, self.spec_weight, self.total_weight
+            self.loss,
+            self.strategy,
+            keras_fit,
+            self.spec_weight,
+            self.total_weight,
+            self.weight_range,
+            self.weight_exp,
         )
 
         # optimizer
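For context on the new `weight_range` / `weight_exp` parameters, the sketch below is a minimal NumPy mirror of the weighting math this patch adds to `poisson_multinomial`; it is not part of the patch, and the helper name `position_weights` and the demo values are illustrative. Assuming the same formulas as the TensorFlow code above, the weights follow a generalized-Gaussian profile over sequence positions: `weight_range` sets roughly how much the outermost positions are down-weighted relative to the center (edge weight ≈ 1 / weight_range), and `weight_exp` controls how flat the central plateau is.

```python
import numpy as np


def position_weights(seq_len: int, weight_range: float, weight_exp: int) -> np.ndarray:
    """Illustrative NumPy mirror of the position weighting added to poisson_multinomial."""
    # positions are symmetric about the sequence center, spaced one bin apart
    pos_start = -(seq_len / 2 - 0.5)
    pos_end = seq_len / 2 + 0.5
    positions = np.arange(pos_start, pos_end, dtype=np.float32)

    # sigma is chosen so exp(-((|pos_start| / sigma) ** weight_exp)) == 1 / weight_range,
    # i.e. the outermost positions receive ~1/weight_range of the central weight
    sigma = -pos_start / (np.log(weight_range)) ** (1 / weight_exp)

    weights = np.exp(-((positions / sigma) ** weight_exp))
    return weights / weights.max()


if __name__ == "__main__":
    print(np.round(position_weights(8, weight_range=4.0, weight_exp=4), 3))
    # center positions stay near 1.0, edge positions drop to ~0.25 (= 1/weight_range)
    print(np.round(position_weights(8, weight_range=4.0, weight_exp=8), 3))
    # a larger weight_exp flattens the center and sharpens the falloff at the edges
```

With the default `weight_range=1`, `np.log(weight_range)` is zero, so the division yields an infinite `sigma` (NumPy emits a divide-by-zero warning rather than raising) and every position weight evaluates to 1; dividing the Poisson and multinomial terms by `tf.reduce_sum(position_weights)` then reduces to the previous division by `seq_len`, so the defaults preserve the old loss behavior.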