From 034b4c546ac0377e6643fd2567e45d54395e922b Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Thu, 12 Sep 2024 09:56:31 +0200 Subject: [PATCH] Quicker spectrum in 1D --- exponax/_spectral.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/exponax/_spectral.py b/exponax/_spectral.py index 8c808dc..6673fc6 100644 --- a/exponax/_spectral.py +++ b/exponax/_spectral.py @@ -909,17 +909,21 @@ def get_spectrum( mode="reconstruction", # because of rfft ) + if power: + magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2 + else: + magnitude = jnp.abs(state_hat_scaled) + + if num_spatial_dims == 1: + # 1D does not need any binning and can be returned directly + return magnitude + wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points) wavenumbers_1d = build_wavenumbers(1, num_points) wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True) dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0] - if power: - magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2 - else: - magnitude = jnp.abs(state_hat_scaled) - spectrum = [] def power_in_bucket(p, k):