Skip to content

Commit

Permalink
Merge pull request #41 from calico/metric-tests
Browse files Browse the repository at this point in the history
metrics tests
  • Loading branch information
davek44 authored Sep 7, 2024
2 parents 699b377 + 2ad2523 commit 16815fa
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/baskerville/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def poisson_multinomial(
rescale (bool): Rescale loss after re-weighting.
"""
seq_len = y_true.shape[1]

if weight_range < 1:
raise ValueError("Poisson Multinomial weight_range must be >=1")
elif weight_range == 1:
Expand All @@ -147,8 +147,8 @@ def poisson_multinomial(
y_pred = tf.math.multiply(y_pred, position_weights)

# sum across lengths
s_true = tf.math.reduce_sum(y_true, axis=-2) # B x T
s_pred = tf.math.reduce_sum(y_pred, axis=-2) # B x T
s_true = tf.math.reduce_sum(y_true, axis=-2) # B x T
s_pred = tf.math.reduce_sum(y_pred, axis=-2) # B x T

# total count poisson loss, mean across targets
poisson_term = poisson(s_true, s_pred) # B x T
Expand All @@ -159,7 +159,7 @@ def poisson_multinomial(
y_pred += epsilon

# normalize to sum to one
p_pred = y_pred / tf.expand_dims(s_pred, axis=-2) # B x L x T
p_pred = y_pred / tf.expand_dims(s_pred, axis=-2) # B x L x T

# multinomial loss
pl_pred = tf.math.log(p_pred) # B x L x T
Expand All @@ -168,7 +168,7 @@ def poisson_multinomial(
multinomial_term /= tf.reduce_sum(position_weights)

# normalize to scale of 1:1 term ratio
loss_raw = multinomial_term + total_weight * poisson_term
loss_raw = multinomial_term + total_weight * poisson_term # B x T
if rescale:
loss_rescale = loss_raw * 2 / (1 + total_weight)
else:
Expand Down
71 changes: 71 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest
from scipy import stats
from sklearn.metrics import r2_score
import tensorflow as tf

from baskerville.metrics import *

# data dimensions
N, L, T = 6, 8, 4


@pytest.fixture
def sample_data():
y_true = tf.random.uniform((N, L, T), minval=0, maxval=10, dtype=tf.float32)
y_pred = y_true + tf.random.normal((N, L, T), mean=0, stddev=0.1)
return y_true, y_pred


def test_PearsonR(sample_data):
y_true, y_pred = sample_data
pearsonr = PearsonR(num_targets=T, summarize=False)
pearsonr.update_state(y_true, y_pred)
tf_result = pearsonr.result().numpy()

# Compute SciPy result
scipy_result = np.zeros(T)
y_true_np = y_true.numpy().reshape(-1, T)
y_pred_np = y_pred.numpy().reshape(-1, T)
for i in range(T):
scipy_result[i], _ = stats.pearsonr(y_true_np[:, i], y_pred_np[:, i])

np.testing.assert_allclose(tf_result, scipy_result, rtol=1e-5, atol=1e-5)

# Test summarized result
pearsonr_summarized = PearsonR(num_targets=T, summarize=True)
pearsonr_summarized.update_state(y_true, y_pred)
tf_result_summarized = pearsonr_summarized.result().numpy()
assert tf_result_summarized.shape == ()
assert np.isclose(tf_result_summarized, np.mean(scipy_result), rtol=1e-5, atol=1e-5)


def test_R2(sample_data):
y_true, y_pred = sample_data
r2 = R2(num_targets=T, summarize=False)
r2.update_state(y_true, y_pred)
tf_result = r2.result().numpy()

# Compute sklearn result
sklearn_result = np.zeros(T)
y_true_np = y_true.numpy().reshape(-1, T)
y_pred_np = y_pred.numpy().reshape(-1, T)
for i in range(T):
sklearn_result[i] = r2_score(y_true_np[:, i], y_pred_np[:, i])

np.testing.assert_allclose(tf_result, sklearn_result, rtol=1e-5, atol=1e-5)

# Test summarized result
r2_summarized = R2(num_targets=T, summarize=True)
r2_summarized.update_state(y_true, y_pred)
tf_result_summarized = r2_summarized.result().numpy()
assert tf_result_summarized.shape == ()
assert np.isclose(
tf_result_summarized, np.mean(sklearn_result), rtol=1e-5, atol=1e-5
)


def test_poisson_multinomial_shape(sample_data):
y_true, y_pred = sample_data
loss = poisson_multinomial(y_true, y_pred)
assert loss.shape == (N, T)

0 comments on commit 16815fa

Please sign in to comment.