From 52a191d1daaa777ee220b2346b0c1e561f66b54b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felix=20K=C3=B6hler?= <27728103+Ceyron@users.noreply.github.com>
Date: Mon, 14 Oct 2024 13:03:12 +0200
Subject: [PATCH] Symmetric Metrics & Metrics improvements (#44)

* Extend docstring for symmetric metrics
* Improve aggregator docstring
* Fix that identity is only guaranteed for L^p norm
* Fix docstring
* Implement symmetric mode
* Fix missing sqrt
* Add symmetric versions for MAE, MSE, and RMSE
* Add new symmetric metrics to documentation
* Add test for symmetric metrics
* Ensure mean is only taken over axis zero in case there are additional axes
* Add details
---
 docs/api/utilities/metrics/spatial.md |  12 ++
 docs/examples/on_metrics_simple.ipynb |   5 +-
 exponax/metrics/__init__.py           |   6 +
 exponax/metrics/_fourier.py           |   2 +-
 exponax/metrics/_spatial.py           | 207 ++++++++++++++++++++++++--
 exponax/metrics/_utils.py             |   3 +-
 tests/test_metrics.py                 |  12 ++
 7 files changed, 232 insertions(+), 15 deletions(-)

diff --git a/docs/api/utilities/metrics/spatial.md b/docs/api/utilities/metrics/spatial.md
index 53776d9..907270b 100644
--- a/docs/api/utilities/metrics/spatial.md
+++ b/docs/api/utilities/metrics/spatial.md
@@ -24,6 +24,18 @@
 ---
 
+::: exponax.metrics.sMAE
+
+---
+
+::: exponax.metrics.sMSE
+
+---
+
+::: exponax.metrics.sRMSE
+
+---
+
 ::: exponax.metrics.spatial_norm
 
 ---

diff --git a/docs/examples/on_metrics_simple.ipynb b/docs/examples/on_metrics_simple.ipynb
index fbbfcd1..5a5014e 100644
--- a/docs/examples/on_metrics_simple.ipynb
+++ b/docs/examples/on_metrics_simple.ipynb
@@ -20,10 +20,13 @@
     "3. Rooted metrics (i.e., related to the RMSE)\n",
     "\n",
     "Then for each of the three, there is both the absolute version and a\n",
-    "relative/normalized version\n",
+    "relative/normalized version. For the spatial-based metrics, MAE, MSE, and RMSE\n",
+    "also come with a symmetric version.\n",
     "\n",
     "All metric computations work on single state arrays, i.e., arrays with a leading channel axis and one, two, or three subsequent spatial axes. **The arrays shall not have leading batch axes.** To work with batched arrays use `jax.vmap` and then reduce, e.g., by `jnp.mean`. Alternatively, use the convenience wrapper [`exponax.metrics.mean_metric`][].\n",
     "\n",
+    "All metrics **sum over the channel axis**.\n",
+    "\n",
     " ⚠️ ⚠️ ⚠️ ⚠️ ⚠️ This notebook is a WIP, it will come with a future release of Exponax ⚠️ ⚠️ ⚠️ ⚠️ ⚠️"
    ]
   },
diff --git a/exponax/metrics/__init__.py b/exponax/metrics/__init__.py
index 53520ea..3fffe39 100644
--- a/exponax/metrics/__init__.py
+++ b/exponax/metrics/__init__.py
@@ -17,8 +17,11 @@
     nMAE,
     nMSE,
     nRMSE,
+    sMAE,
+    sMSE,
     spatial_aggregator,
     spatial_norm,
+    sRMSE,
 )
 from ._utils import mean_metric
 
@@ -31,6 +34,9 @@
     "nMAE",
     "nMSE",
     "nRMSE",
+    "sMAE",
+    "sMSE",
+    "sRMSE",
     "fourier_aggregator",
     "fourier_norm",
     "fourier_MAE",
diff --git a/exponax/metrics/_fourier.py b/exponax/metrics/_fourier.py
index fc9fe54..6f4746f 100644
--- a/exponax/metrics/_fourier.py
+++ b/exponax/metrics/_fourier.py
@@ -36,7 +36,7 @@ def fourier_aggregator(
     !!! info
         The result of this function (under default settings) is (up to rounding
         errors) identical to [`exponax.metrics.spatial_aggregator`][] for
-        `inner_exponent=1.0`. As such, it can be a consistent counterpart for
+        `inner_exponent=2.0`. As such, it can be a consistent counterpart for
         metrics based on the `L²(Ω)` functional norm.
 
     !!! tip
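The notebook cell edited above describes the intended batching pattern: the metric functions act on a single state of shape `(channel, *spatial)`, and batches are handled either with `jax.vmap` plus a manual reduction or with the `mean_metric` wrapper. Below is a minimal, illustrative sketch of that pattern; the array shapes and values are made up for the example and only the calls shown in this patch (`exponax.metrics.MSE`, `exponax.metrics.mean_metric`) are used.

```python
import jax
import jax.numpy as jnp

import exponax as ex

# A batch of eight 1D states: (batch, channel, num_points).
# The metric functions themselves expect only (channel, num_points).
u_pred_batch = 2.0 * jnp.ones((8, 1, 64))
u_ref_batch = 4.0 * jnp.ones((8, 1, 64))

# Option 1: vmap over the leading batch axis, then reduce manually.
mse_per_sample = jax.vmap(ex.metrics.MSE)(u_pred_batch, u_ref_batch)
mse_manual = jnp.mean(mse_per_sample)

# Option 2: the convenience wrapper bundles the same vmap + mean.
mse_wrapped = ex.metrics.mean_metric(ex.metrics.MSE, u_pred_batch, u_ref_batch)
```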
diff --git a/exponax/metrics/_spatial.py b/exponax/metrics/_spatial.py
index af15ebb..fac0b21 100644
--- a/exponax/metrics/_spatial.py
+++ b/exponax/metrics/_spatial.py
@@ -25,13 +25,16 @@ def spatial_aggregator(
     and the right is not, there is the following relation between a continuous
     function `u(x)` and its discretely sampled counterpart `uₕ`
 
-        ‖ u(x) ‖ᵖ_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) = ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)
+        ‖ u(x) ‖_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) ≈ ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)
 
     where the summation `∑ᵢ` must be understood as a sum over all `Nᴰ` points
     across all spatial dimensions. The `inner_exponent` corresponds to `p` in
-    the above formula. This function allows setting the outer exponent `q`
-    manually. If it is not specified, it is set to `1/q = 1/p` to get a valid
-    norm.
+    the above formula. This function also allows setting the outer exponent `q`,
+    which is applied via
+
+        ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^q
+
+    If it is not specified, it is set to `q = 1/p` to get a valid norm.
 
     !!! tip
         To apply this function to a state tensor with a leading channel axis,
@@ -40,7 +43,7 @@
     **Arguments:**
 
     - `state_no_channel`: The state tensor **without a leading channel
-        dimension**.
+        axis**.
    - `num_spatial_dims`: The number of spatial dimensions. If not specified,
        it is inferred from the number of axes in `state_no_channel`.
    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
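As a quick illustration of the discrete formula in the docstring touched above: with `inner_exponent=2.0` and the default outer exponent `1/p`, the aggregator reproduces the discrete `L²(Ω)` norm. The sketch below is illustrative only and assumes the call signature documented in this patch (first argument without a channel axis, `num_spatial_dims` inferred, keyword `domain_extent`/`inner_exponent`).

```python
import jax.numpy as jnp

import exponax as ex

N = 64
L = 3.0
u = jnp.linspace(0.0, 1.0, N)  # one 1D state without a channel axis

# Discrete counterpart of ‖u‖_L²(Ω): ( (L/N)ᴰ ∑ᵢ |uᵢ|² )^(1/2) with D = 1
manual_l2 = (L / N * jnp.sum(jnp.abs(u) ** 2)) ** 0.5

# Same quantity via the aggregator (outer exponent defaults to 1/inner_exponent)
via_aggregator = ex.metrics.spatial_aggregator(u, domain_extent=L, inner_exponent=2.0)

# Both agree up to floating-point rounding.
```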
@@ -84,7 +87,7 @@ def spatial_norm(
     state: Float[Array, "C ... N"],
     state_ref: Optional[Float[Array, "C ... N"]] = None,
     *,
-    mode: Literal["absolute", "normalized"] = "absolute",
+    mode: Literal["absolute", "normalized", "symmetric"] = "absolute",
     domain_extent: float = 1.0,
     inner_exponent: float = 2.0,
     outer_exponent: Optional[float] = None,
@@ -97,13 +100,18 @@
     control, consider using [`exponax.metrics.spatial_aggregator`][] directly.
 
     This function allows providing a second state (`state_ref`) to compute
-    either the absolute or normalized difference. The `"absolute"` mode computes
+    either the absolute, normalized, or symmetric difference. The `"absolute"`
+    mode computes
 
-        (‖|uₕ − uₕʳ|ᵖ ‖_L²(Ω))^q
+        (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p)
 
     while the `"normalized"` mode computes
 
-        (‖|uₕ − uₕʳ|ᵖ‖_ L²(Ω))^q / (‖|uₕʳ|ᵖ‖_ L²(Ω))^q
+        (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕʳ‖_L^p(Ω))^(q*p))
+
+    and the `"symmetric"` mode computes
+
+        2 * (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕ‖_L^p(Ω))^(q*p) + (‖uₕʳ‖_L^p(Ω))^(q*p))
 
     In either way, the channels are summed **after** the aggregation. The
     `inner_exponent` corresponds to `p` in the above formulas. The
@@ -124,7 +132,8 @@
     - `state_ref`: The reference state tensor. Must have the same shape as
         `state`. If not specified, only the absolute norm of `state` is
         computed.
-    - `mode`: The mode of the norm. Either `"absolute"` or `"normalized"`.
+    - `mode`: The mode of the norm. Either `"absolute"`, `"normalized"`, or
+        `"symmetric"`.
     - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
     - `inner_exponent`: The exponent `p` in the L^p norm.
     - `outer_exponent`: The exponent `q` the result after aggregation is raised
@@ -133,6 +142,8 @@
     if state_ref is None:
         if mode == "normalized":
             raise ValueError("mode 'normalized' requires state_ref")
+        if mode == "symmetric":
+            raise ValueError("mode 'symmetric' requires state_ref")
         diff = state
     else:
         diff = state - state_ref
@@ -157,6 +168,27 @@
         )(state_ref)
         normalized_diff_per_channel = diff_norm_per_channel / ref_norm_per_channel
         norm_per_channel = normalized_diff_per_channel
+    elif mode == "symmetric":
+        state_norm_per_channel = jax.vmap(
+            lambda s: spatial_aggregator(
+                s,
+                domain_extent=domain_extent,
+                inner_exponent=inner_exponent,
+                outer_exponent=outer_exponent,
+            ),
+        )(state)
+        ref_norm_per_channel = jax.vmap(
+            lambda r: spatial_aggregator(
+                r,
+                domain_extent=domain_extent,
+                inner_exponent=inner_exponent,
+                outer_exponent=outer_exponent,
+            ),
+        )(state_ref)
+        symmetric_diff_per_channel = (
+            2 * diff_norm_per_channel / (state_norm_per_channel + ref_norm_per_channel)
+        )
+        norm_per_channel = symmetric_diff_per_channel
     else:
         norm_per_channel = diff_norm_per_channel
 
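To make the three modes implemented above concrete, here is a small illustrative comparison on constant states. It uses the defaults documented in this patch (`domain_extent=1.0`, `inner_exponent=2.0`) and assumes the unspecified outer exponent falls back to `1/inner_exponent`, so each mode reduces to the `L²(Ω)` expressions from the docstring; the shapes are arbitrary.

```python
import jax.numpy as jnp

import exponax as ex

u = 4.0 * jnp.ones((1, 64))      # (channel, num_points)
u_ref = 2.0 * jnp.ones((1, 64))

# "absolute":   ‖u - uʳ‖_L²                        -> 2.0 for these constant states
abs_err = ex.metrics.spatial_norm(u, u_ref, mode="absolute")

# "normalized": ‖u - uʳ‖_L² / ‖uʳ‖_L²              -> 2.0 / 2.0 = 1.0
rel_err = ex.metrics.spatial_norm(u, u_ref, mode="normalized")

# "symmetric":  2 ‖u - uʳ‖_L² / (‖u‖_L² + ‖uʳ‖_L²) -> 2 * 2 / (4 + 2) = 2/3
sym_err = ex.metrics.spatial_norm(u, u_ref, mode="symmetric")
```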
@@ -255,6 +287,55 @@
     )
 
 
+def sMAE(
+    u_pred: Float[Array, "C ... N"],
+    u_ref: Float[Array, "C ... N"],
+    *,
+    domain_extent: float = 1.0,
+) -> float:
+    """
+    Compute the symmetric mean absolute error (sMAE) between two states.
+
+        ∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / (∑_(space) (L/N)ᴰ |uₕ| + ∑_(space) (L/N)ᴰ |uₕʳ|)]
+
+    Given the correct `domain_extent`, this is consistent with the following
+    functional norm:
+
+        2 ∫_Ω |u(x) - uʳ(x)| dx / (∫_Ω |u(x)| dx + ∫_Ω |uʳ(x)| dx)
+
+    The channel axis is summed **after** the aggregation.
+
+    !!! tip
+        To apply this function to a state tensor with a leading batch axis, use
+        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
+        a helper for this, [`exponax.metrics.mean_metric`][] is provided.
+
+    !!! info
+        This symmetric metric is bounded between 0 and 2 C with C being the
+        number of channels.
+
+
+    **Arguments:**
+
+    - `u_pred`: The state array, must follow the `Exponax` convention with a
+        leading channel axis, and either one, two, or three subsequent spatial
+        axes.
+    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
+    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
+        provided to get the correctly consistent norm. If this metric is used
+        as an optimization objective, it can often be ignored since it only
+        contributes a multiplicative factor.
+    """
+    return spatial_norm(
+        u_pred,
+        u_ref,
+        mode="symmetric",
+        domain_extent=domain_extent,
+        inner_exponent=1.0,
+        outer_exponent=1.0,
+    )
+
+
 def MSE(
     u_pred: Float[Array, "C ... N"],
     u_ref: Optional[Float[Array, "C ... N"]] = None,
@@ -347,6 +428,55 @@
     )
 
 
+def sMSE(
+    u_pred: Float[Array, "C ... N"],
+    u_ref: Float[Array, "C ... N"],
+    *,
+    domain_extent: float = 1.0,
+) -> float:
+    """
+    Compute the symmetric mean squared error (sMSE) between two states.
+
+        ∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / (∑_(space) (L/N)ᴰ |uₕ|² + ∑_(space) (L/N)ᴰ |uₕʳ|²)]
+
+    Given the correct `domain_extent`, this is consistent with the following
+    functional norm:
+
+        2 ∫_Ω |u(x) - uʳ(x)|² dx / (∫_Ω |u(x)|² dx + ∫_Ω |uʳ(x)|² dx)
+
+    The channel axis is summed **after** the aggregation.
+
+    !!! tip
+        To apply this function to a state tensor with a leading batch axis, use
+        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
+        a helper for this, [`exponax.metrics.mean_metric`][] is provided.
+
+    !!! info
+        This symmetric metric is bounded between 0 and 4 C with C being the
+        number of channels.
+
+
+    **Arguments:**
+
+    - `u_pred`: The state array, must follow the `Exponax` convention with a
+        leading channel axis, and either one, two, or three subsequent spatial
+        axes.
+    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
+    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
+        provided to get the correctly consistent norm. If this metric is used
+        as an optimization objective, it can often be ignored since it only
+        contributes a multiplicative factor.
+    """
+    return spatial_norm(
+        u_pred,
+        u_ref,
+        mode="symmetric",
+        domain_extent=domain_extent,
+        inner_exponent=2.0,
+        outer_exponent=1.0,
+    )
+
+
 def RMSE(
     u_pred: Float[Array, "C ... N"],
     u_ref: Optional[Float[Array, "C ... N"]] = None,
@@ -361,7 +491,7 @@
     Given the correct `domain_extent`, this is consistent with the following
     functional norm:
 
-        (‖ u - uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx)
+        (‖ u - uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx)
 
     The channel axis is summed **after** the aggregation. Hence, it is also
     summed **after** the square root. If you need the RMSE per channel, consider
@@ -411,7 +541,7 @@
     Given the correct `domain_extent`, this is consistent with the following
     functional norm:
 
-        (‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
+        (‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = √(∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
         |uʳ(x)|² dx)
 
     The channel axis is summed **after** the aggregation. Hence, it is also
@@ -444,3 +574,56 @@
         inner_exponent=2.0,
         outer_exponent=0.5,
     )
+
+
+def sRMSE(
+    u_pred: Float[Array, "C ... N"],
+    u_ref: Float[Array, "C ... N"],
+    *,
+    domain_extent: float = 1.0,
+) -> float:
+    """
+    Compute the symmetric root mean squared error (sRMSE) between two states.
+
+        ∑_(channels) [2 √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / (√(∑_(space) (L/N)ᴰ |uₕ|²) + √(∑_(space) (L/N)ᴰ |uₕʳ|²))]
+
+    Given the correct `domain_extent`, this is consistent with the following
+    functional norm:
+
+        2 √(∫_Ω |u(x) - uʳ(x)|² dx) / (√(∫_Ω |u(x)|² dx) + √(∫_Ω |uʳ(x)|² dx))
+
+    The channel axis is summed **after** the aggregation. Hence, it is also
+    summed **after** the square root and after normalization. If you need more
+    fine-grained control, consider using
+    [`exponax.metrics.spatial_aggregator`][] directly.
+
+    !!! tip
+        To apply this function to a state tensor with a leading batch axis, use
+        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
+        a helper for this, [`exponax.metrics.mean_metric`][] is provided.
+
+    !!! info
+        This symmetric metric is bounded between 0 and 2 C with C being the
+        number of channels.
+
+
+    **Arguments:**
+
+    - `u_pred`: The state array, must follow the `Exponax` convention with a
+        leading channel axis, and either one, two, or three subsequent spatial
+        axes.
+    - `u_ref`: The reference state array. Must have the same shape as `u_pred`.
+    - `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
+        provided to get the correctly consistent norm. If this metric is used
+        as an optimization objective, it can often be ignored since it only
+        contributes a multiplicative factor.
+    """
+    return spatial_norm(
+        u_pred,
+        u_ref,
+        mode="symmetric",
+        domain_extent=domain_extent,
+        inner_exponent=2.0,
+        outer_exponent=0.5,
+    )
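For a quick sanity check of the new symmetric wrappers, constant states give closed-form values that match the test added below. This is an illustrative sketch; the shapes are arbitrary and only the functions introduced in this patch (`sMSE`, `sRMSE`) are used.

```python
import jax.numpy as jnp

import exponax as ex

# Two constant 1D states with a single channel, mirroring the new test case.
u_0 = 2.0 * jnp.ones((1, 64))
u_1 = 4.0 * jnp.ones((1, 64))

# sMSE = 2 * (4 - 2)² / (2² + 4²) = 0.4, independent of the argument order.
print(ex.metrics.sMSE(u_1, u_0))   # ≈ 0.4
print(ex.metrics.sMSE(u_0, u_1))   # ≈ 0.4

# sRMSE = 2 * |4 - 2| / (2 + 4) = 2/3.
print(ex.metrics.sRMSE(u_1, u_0))  # ≈ 0.6667
```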
diff --git a/exponax/metrics/_utils.py b/exponax/metrics/_utils.py
index 22a93a1..8647331 100644
--- a/exponax/metrics/_utils.py
+++ b/exponax/metrics/_utils.py
@@ -11,4 +11,5 @@ def mean_metric(
     'meanifies' a metric function to operate on arrays with a leading batch axis
     """
     wrapped_fn = lambda *a: metric_fn(*a, **kwargs)
-    return jnp.mean(jax.vmap(wrapped_fn)(*args))
+    metric_per_sample = jax.vmap(wrapped_fn, in_axes=0)(*args)
+    return jnp.mean(metric_per_sample, axis=0)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 6b93d9e..0e9af73 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -35,6 +35,15 @@ def test_constant_offset(num_spatial_dims: int):
     assert ex.metrics.nMSE(u_0, u_1) == pytest.approx((2.0 - 4.0) ** 2 / (4.0) ** 2)
     assert ex.metrics.nMSE(u_0, u_1) == pytest.approx(1 / 4)
 
+    # == approx(0.4)
+    assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(
+        2.0 * (4.0 - 2.0) ** 2 / ((2.0) ** 2 + (4.0) ** 2)
+    )
+    assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(0.4)
+
+    # Symmetric metric must be symmetric
+    assert ex.metrics.sMSE(u_0, u_1) == ex.metrics.sMSE(u_1, u_0)
+
     assert ex.metrics.RMSE(u_1, u_0, domain_extent=1.0) == pytest.approx(2.0)
     assert ex.metrics.RMSE(u_1, u_0, domain_extent=DOMAIN_EXTENT) == pytest.approx(
         jnp.sqrt(DOMAIN_EXTENT**num_spatial_dims * 4.0)
@@ -60,6 +69,9 @@ def test_constant_offset(num_spatial_dims: int):
     )
     assert ex.metrics.nRMSE(u_0, u_1) == pytest.approx(0.5)
 
+    # == approx(2/3)
+    assert ex.metrics.sRMSE(u_1, u_0) == pytest.approx(2 / 3)
+
     # The Fourier nRMSE should be identical to the spatial nRMSE
     # assert ex.metrics.fourier_nRMSE(u_1, u_0) == ex.metrics.nRMSE(u_1, u_0)
     # assert ex.metrics.fourier_nRMSE(u_0, u_1) == ex.metrics.nRMSE(u_0, u_1)
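The `_utils.py` change above makes `mean_metric` average only over the leading batch axis, so a metric function that returns an array keeps its remaining axes. The sketch below illustrates this behavior; `rollout_mse` is a hypothetical helper written just for this example and is not part of the library, and the shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

import exponax as ex

# Hypothetical per-time-step metric: applies MSE to every snapshot of a
# trajectory of shape (time, channel, num_points), returning an array of shape (time,).
def rollout_mse(pred_trj, ref_trj):
    return jax.vmap(ex.metrics.MSE)(pred_trj, ref_trj)

pred = 1.0 * jnp.ones((8, 10, 1, 64))  # (batch, time, channel, num_points)
ref = 2.0 * jnp.ones((8, 10, 1, 64))

# mean_metric vmaps over the batch axis and averages only over axis 0,
# so the time axis survives: err_per_time.shape == (10,)
err_per_time = ex.metrics.mean_metric(rollout_mse, pred, ref)
```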