Skip to content

Commit

Permalink
Symmetric Metrics & Metrics improvements (#44)
Browse files Browse the repository at this point in the history
* Extend docstring for symmetric metrics

* Improve aggregator docstring

* Fix that identity is only guaranteed for L^p norm

* Fix docstring

* Implement symmetric mode

* Fix missing sqrt

* Add symmetric versions for MAE, MSE, and RMSE

* Add new symmetric metrics to documentation

* Add test for symmetric metrics

* Ensure mean is only taken over axis zero in case there are additional axes

* Add details
  • Loading branch information
Ceyron authored Oct 14, 2024
1 parent beea37e commit 52a191d
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 15 deletions.
12 changes: 12 additions & 0 deletions docs/api/utilities/metrics/spatial.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@

---

::: exponax.metrics.sMAE

---

::: exponax.metrics.sMSE

---

::: exponax.metrics.sRMSE

---

::: exponax.metrics.spatial_norm

---
Expand Down
5 changes: 4 additions & 1 deletion docs/examples/on_metrics_simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
"3. Rooted metrics (i.e., related to the RMSE)\n",
"\n",
"Then for each of the three, there is both the absolute version and a\n",
"relative/normalized version\n",
"relative/normalized version. For all spatial-based metrics, MAE, MSE, and RMSE\n",
"also come with a symmetric version.\n",
"\n",
"All metrics computation work on single state arrays, i.e., arrays with a leading channel axis and one, two, or three subsequent spatial axes. **The arrays shall not have leading batch axes.** To work with batched arrays use `jax.vmap` and then reduce, e.g., by `jnp.mean`. Alternatively, use the convinience wrapper [`exponax.metrics.mean_metric`][].\n",
"\n",
"All metrics **sum over the channel axis**.\n",
"\n",
" ⚠️ ⚠️ ⚠️ ⚠️ ⚠️ This notebook is a WIP, it will come with future release of Exponax ⚠️ ⚠️ ⚠️ ⚠️ ⚠️"
]
},
Expand Down
6 changes: 6 additions & 0 deletions exponax/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
nMAE,
nMSE,
nRMSE,
sMAE,
sMSE,
spatial_aggregator,
spatial_norm,
sRMSE,
)
from ._utils import mean_metric

Expand All @@ -31,6 +34,9 @@
"nMAE",
"nMSE",
"nRMSE",
"sMAE",
"sMSE",
"sRMSE",
"fourier_aggregator",
"fourier_norm",
"fourier_MAE",
Expand Down
2 changes: 1 addition & 1 deletion exponax/metrics/_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def fourier_aggregator(
!!! info
The result of this function (under default settings) is (up to rounding
errors) identical to [`exponax.metrics.spatial_aggregator`][] for
`inner_exponent=1.0`. As such, it can be a consistent counterpart for
`inner_exponent=2.0`. As such, it can be a consistent counterpart for
metrics based on the `L²(Ω)` functional norm.
!!! tip
Expand Down
207 changes: 195 additions & 12 deletions exponax/metrics/_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ def spatial_aggregator(
and the right is not, there is the following relation between a continuous
function `u(x)` and its discretely sampled counterpart `uₕ`
‖ u(x) ‖ᵖ_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) = ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)
‖ u(x) ‖_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)
where the summation `∑ᵢ` must be understood as a sum over all `Nᴰ` points
across all spatial dimensions. The `inner_exponent` corresponds to `p` in
the above formula. This function allows setting the outer exponent `q`
manually. If it is not specified, it is set to `1/q = 1/p` to get a valid
norm.
the above formula. This function also allows setting the outer exponent `q`
which via
( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^q
If it is not specified, it is set to `q = 1/p` to get a valid norm.
!!! tip
To apply this function to a state tensor with a leading channel axis,
Expand All @@ -40,7 +43,7 @@ def spatial_aggregator(
**Arguments:**
- `state_no_channel`: The state tensor **without a leading channel
dimension**.
axis**.
- `num_spatial_dims`: The number of spatial dimensions. If not specified,
it is inferred from the number of axes in `state_no_channel`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
Expand Down Expand Up @@ -84,7 +87,7 @@ def spatial_norm(
state: Float[Array, "C ... N"],
state_ref: Optional[Float[Array, "C ... N"]] = None,
*,
mode: Literal["absolute", "normalized"] = "absolute",
mode: Literal["absolute", "normalized", "symmetric"] = "absolute",
domain_extent: float = 1.0,
inner_exponent: float = 2.0,
outer_exponent: Optional[float] = None,
Expand All @@ -97,13 +100,18 @@ def spatial_norm(
control, consider using [`exponax.metrics.spatial_aggregator`][] directly.
This function allows providing a second state (`state_ref`) to compute
either the absolute or normalized difference. The `"absolute"` mode computes
either the absolute, normalized, or symmetric difference. The `"absolute"`
mode computes
(‖|uₕ uₕʳ|ᵖ ‖_L²(Ω))^q
(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p)
while the `"normalized"` mode computes
(‖|uₕ − uₕʳ|ᵖ‖_ L²(Ω))^q / (‖|uₕʳ|ᵖ‖_ L²(Ω))^q
(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕʳ‖_L^p(Ω))^(q*p))
and the `"symmetric"` mode computes
2 * (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕ‖_L^p(Ω))^(q*p) + (‖uₕʳ‖_L^p(Ω))^(q*p))
In either way, the channels are summed **after** the aggregation. The
`inner_exponent` corresponds to `p` in the above formulas. The
Expand All @@ -124,7 +132,8 @@ def spatial_norm(
- `state_ref`: The reference state tensor. Must have the same shape as
`state`. If not specified, only the absolute norm of `state` is
computed.
- `mode`: The mode of the norm. Either `"absolute"` or `"normalized"`.
- `mode`: The mode of the norm. Either `"absolute"`, `"normalized"`, or
`"symmetric"`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
- `inner_exponent`: The exponent `p` in the L^p norm.
- `outer_exponent`: The exponent `q` the result after aggregation is raised
Expand All @@ -133,6 +142,8 @@ def spatial_norm(
if state_ref is None:
if mode == "normalized":
raise ValueError("mode 'normalized' requires state_ref")
if mode == "symmetric":
raise ValueError("mode 'symmetric' requires state_ref")
diff = state
else:
diff = state - state_ref
Expand All @@ -157,6 +168,27 @@ def spatial_norm(
)(state_ref)
normalized_diff_per_channel = diff_norm_per_channel / ref_norm_per_channel
norm_per_channel = normalized_diff_per_channel
elif mode == "symmetric":
state_norm_per_channel = jax.vmap(
lambda s: spatial_aggregator(
s,
domain_extent=domain_extent,
inner_exponent=inner_exponent,
outer_exponent=outer_exponent,
),
)(state)
ref_norm_per_channel = jax.vmap(
lambda r: spatial_aggregator(
r,
domain_extent=domain_extent,
inner_exponent=inner_exponent,
outer_exponent=outer_exponent,
),
)(state_ref)
symmetric_diff_per_channel = (
2 * diff_norm_per_channel / (state_norm_per_channel + ref_norm_per_channel)
)
norm_per_channel = symmetric_diff_per_channel
else:
norm_per_channel = diff_norm_per_channel

Expand Down Expand Up @@ -255,6 +287,55 @@ def nMAE(
)


def sMAE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric mean absolute error (sMAE) between two states.
∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / (∑_(space) (L/N)ᴰ |uₕ| + ∑_(space) (L/N)ᴰ |uₕʳ|)]
Given the correct `domain_extent`, this is consistent to the following
functional norm:
2 ∫_Ω |u(x) - uʳ(x)| dx / (∫_Ω |u(x)| dx + ∫_Ω |uʳ(x)| dx)
The channel axis is summed **after** the aggregation.
!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.
!!! info
This symmetric metric is bounded between 0 and C with C being the number
of channels.
**Arguments:**
- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provide to get the correctly consistent norm. If this metric is used an
optimization objective, it can often be ignored since it only
contributes a multiplicative factor.
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=1.0,
outer_exponent=1.0,
)


def MSE(
u_pred: Float[Array, "C ... N"],
u_ref: Optional[Float[Array, "C ... N"]] = None,
Expand Down Expand Up @@ -347,6 +428,55 @@ def nMSE(
)


def sMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric mean squared error (sMSE) between two states.
∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / (∑_(space) (L/N)ᴰ |uₕ|² + ∑_(space) (L/N)ᴰ |uₕʳ|²)]
Given the correct `domain_extent`, this is consistent to the following
functional norm:
2 ∫_Ω |u(x) - uʳ(x)|² dx / (∫_Ω |u(x)|² dx + ∫_Ω |uʳ(x)|² dx)
The channel axis is summed **after** the aggregation.
!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.
!!! info
This symmetric metric is bounded between 0 and C with C being the number
of channels.
**Arguments:**
- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provide to get the correctly consistent norm. If this metric is used an
optimization objective, it can often be ignored since it only
contributes a multiplicative factor.
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=2.0,
outer_exponent=1.0,
)


def RMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Optional[Float[Array, "C ... N"]] = None,
Expand All @@ -361,7 +491,7 @@ def RMSE(
Given the correct `domain_extent`, this is consistent to the following
functional norm:
(‖ u - uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx)
(‖ u - uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx)
The channel axis is summed **after** the aggregation. Hence, it is also
summed **after** the square root. If you need the RMSE per channel, consider
Expand Down Expand Up @@ -411,7 +541,7 @@ def nRMSE(
Given the correct `domain_extent`, this is consistent to the following
functional norm:
(‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
(‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
|uʳ(x)|² dx
The channel axis is summed **after** the aggregation. Hence, it is also
Expand Down Expand Up @@ -444,3 +574,56 @@ def nRMSE(
inner_exponent=2.0,
outer_exponent=0.5,
)


def sRMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric root mean squared error (sRMSE) between two states.
∑_(channels) [2 √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / (√(∑_(space) (L/N)ᴰ
|uₕ|²) + √(∑_(space) (L/N)ᴰ |uₕʳ|²))]
Given the correct `domain_extent`, this is consistent to the following
functional norm:
2 √(∫_Ω |u(x) - uʳ(x)|² dx) / (√(∫_Ω |u(x)|² dx) + √(∫_Ω |uʳ(x)|² dx))
The channel axis is summed **after** the aggregation. Hence, it is also
summed **after** the square root and after normalization. If you need more
fine-grained control, consider using
[`exponax.metrics.spatial_aggregator`][] directly.
!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.
!!! info
This symmetric metric is bounded between 0 and C with C being the number
of channels.
**Arguments:**
- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provide to get the correctly consistent norm. If this metric is used an
optimization objective, it can often be ignored since it only contributes
a multiplicative factor
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=2.0,
outer_exponent=0.5,
)
3 changes: 2 additions & 1 deletion exponax/metrics/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ def mean_metric(
'meanifies' a metric function to operate on arrays with a leading batch axis
"""
wrapped_fn = lambda *a: metric_fn(*a, **kwargs)
return jnp.mean(jax.vmap(wrapped_fn)(*args))
metric_per_sample = jax.vmap(wrapped_fn, in_axes=0)(*args)
return jnp.mean(metric_per_sample, axis=0)
12 changes: 12 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def test_constant_offset(num_spatial_dims: int):
assert ex.metrics.nMSE(u_0, u_1) == pytest.approx((2.0 - 4.0) ** 2 / (4.0) ** 2)
assert ex.metrics.nMSE(u_0, u_1) == pytest.approx(1 / 4)

# == approx(0.4
assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(
2.0 * (4.0 - 2.0) ** 2 / ((2.0) ** 2 + (4.0) ** 2)
)
assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(0.4)

# Symmetric metric must be symmetric
assert ex.metrics.sMSE(u_0, u_1) == ex.metrics.sMSE(u_1, u_0)

assert ex.metrics.RMSE(u_1, u_0, domain_extent=1.0) == pytest.approx(2.0)
assert ex.metrics.RMSE(u_1, u_0, domain_extent=DOMAIN_EXTENT) == pytest.approx(
jnp.sqrt(DOMAIN_EXTENT**num_spatial_dims * 4.0)
Expand All @@ -60,6 +69,9 @@ def test_constant_offset(num_spatial_dims: int):
)
assert ex.metrics.nRMSE(u_0, u_1) == pytest.approx(0.5)

# == approx(2/3)
assert ex.metrics.sRMSE(u_1, u_0) == pytest.approx(2 / 3)

# The Fourier nRMSE should be identical to the spatial nRMSE
# assert ex.metrics.fourier_nRMSE(u_1, u_0) == ex.metrics.nRMSE(u_1, u_0)
# assert ex.metrics.fourier_nRMSE(u_0, u_1) == ex.metrics.nRMSE(u_0, u_1)
Expand Down

0 comments on commit 52a191d

Please sign in to comment.