diff --git a/CHANGELOG.md b/CHANGELOG.md
index ba2c808..d92a775 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Fixed
- Fixed arguments error in helmholtz notebook
+### Changed
+- `Medium` objects are now `jaxdf.Module`s, which is based on `equinox` modules. It is also a [parametric module for dispatching operators](https://beartype.github.io/plum/parametric.html), meaning that there's a type difference betwee `Medium[FourierSeries]` and `Medium[FiniteDifferences]`, for example.
+- The settings of time domain acoustic simulations are now set using a `TimeWavePropagationSettings`. This also includes an attribute to explicity set the reference sound speed.
+
+### Added
+- Added a logger in `jwave.logger`
+
+### Removed
+- Removed `pressure_from_density` from `jwave.acoustics.conversion`, as it was a duplicate
+
## [0.1.5] - 2023-09-27
### Added
- Added `numbers_with_smallest_primes` utility to find grids with small primes for efficient FFT when using FourierSeries
diff --git a/jwave/__init__.py b/jwave/__init__.py
index 2c2d8f0..72fdd6c 100755
--- a/jwave/__init__.py
+++ b/jwave/__init__.py
@@ -14,9 +14,53 @@
# License along with j-Wave. If not, see .
# nopycln: file
-from jaxdf.discretization import *
+from jaxdf import (
+ operator,
+ Continuous,
+ Domain,
+ FiniteDifferences,
+ FourierSeries,
+ Field,
+ Linear,
+ OnGrid
+)
+
+from .acoustics import (
+ angular_spectrum,
+ born_iteration,
+ born_series,
+ db2neper,
+ helmholtz_solver_verbose,
+ helmholtz_solver,
+ helmholtz,
+ homogeneous_helmholtz_green,
+ laplacian_with_pml,
+ mass_conservation_rhs,
+ momentum_conservation_rhs,
+ pml,
+ pressure_from_density,
+ rayleigh_integral,
+ scale_source_helmholtz,
+ scattering_potential,
+ simulate_wave_propagation,
+ spectral,
+ wave_propagation_symplectic_step,
+ wavevector,
+ TimeWavePropagationSettings,
+)
+from .geometry import (
+ BLISensors,
+ DistributedTransducer,
+ Medium,
+ Sensors,
+ Sources,
+ TimeAxis,
+ TimeHarmonicSource,
+)
from jwave import acoustics as ac
from jwave import geometry as geometry
+from jwave import logger as logger
+from jwave import phantoms as phantoms
from jwave import signal_processing as signal_processing
from jwave import utils as utils
diff --git a/jwave/acoustics/__init__.py b/jwave/acoustics/__init__.py
index 60d9fd6..736873e 100644
--- a/jwave/acoustics/__init__.py
+++ b/jwave/acoustics/__init__.py
@@ -14,6 +14,31 @@
# License along with j-Wave. If not, see .
# nopycln: file
-from .operators import *
-from .time_harmonic import *
-from .time_varying import *
+from .conversion import db2neper
+from .operators import (
+ helmholtz,
+ laplacian_with_pml,
+ scale_source_helmholtz,
+ wavevector,
+)
+from .time_harmonic import (
+ angular_spectrum,
+ born_iteration,
+ born_series,
+ helmholtz_solver,
+ helmholtz_solver_verbose,
+ homogeneous_helmholtz_green,
+ rayleigh_integral,
+ scattering_potential
+)
+from .time_varying import (
+ mass_conservation_rhs,
+ momentum_conservation_rhs,
+ pressure_from_density,
+ simulate_wave_propagation,
+ wave_propagation_symplectic_step,
+ TimeWavePropagationSettings,
+)
+
+from . import spectral
+from . import pml
\ No newline at end of file
diff --git a/jwave/acoustics/conversion.py b/jwave/acoustics/conversion.py
index b52da8f..b7b576a 100644
--- a/jwave/acoustics/conversion.py
+++ b/jwave/acoustics/conversion.py
@@ -16,29 +16,6 @@
import numpy as np
from jax import numpy as jnp
-from jwave.geometry import Sensors
-
-
-def pressure_from_density(sensors_data: jnp.ndarray, sound_speed: jnp.ndarray,
- sensors: Sensors) -> jnp.ndarray:
- r"""
- Calculate pressure from acoustic density given by the raw output of the
- timestepping scheme.
-
- Args:
- sensors_data: Raw output of the timestepping scheme.
- sound_speed: Sound speed of the medium.
- sensors: Sensors object.
-
- Returns:
- jnp.ndarray: Pressure time traces at sensor locations
- """
- if sensors is None:
- return jnp.sum(sensors_data[1], -1) * (sound_speed**2)
- else:
- return jnp.sum(sensors_data[1], -1) * (sound_speed[sensors.positions]**
- 2)
-
def db2neper(
alpha: jnp.ndarray,
diff --git a/jwave/acoustics/time_harmonic.py b/jwave/acoustics/time_harmonic.py
index 7e16a99..ca567e0 100644
--- a/jwave/acoustics/time_harmonic.py
+++ b/jwave/acoustics/time_harmonic.py
@@ -185,9 +185,9 @@ def _cbs_norm_units(medium, omega, k0, src):
# Update fields
src = FourierSeries(src.on_grid, domain)
if issubclass(type(medium.sound_speed), FourierSeries):
- medium.sound_speed = FourierSeries(c, domain)
- else:
- medium.sound_speed = c
+ c = FourierSeries(c, domain)
+
+ medium = medium.replace("sound_speed", c)
# Update k0
k0 = k0 * _conversion["dx"]
@@ -342,7 +342,7 @@ def body_fun(carry):
out_field = _cbs_unnorm_units(out_field, _conversion)
- return out_field, None
+ return out_field
@operator
@@ -377,7 +377,7 @@ def born_iteration(field: Field,
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)
- return field - (1j / epsilon) * V2, params
+ return field - (1j / epsilon) * V2
@operator
@@ -401,7 +401,7 @@ def scattering_potential(field: Field,
k = k_sq - k0**2 - 1j * epsilon
out = field * k
- return out, params
+ return out
@operator
@@ -430,7 +430,7 @@ def homogeneous_helmholtz_green(field: FourierSeries,
u_fft = jnp.fft.fftn(u)
Gu_fft = g_fourier * u_fft
Gu = jnp.fft.ifftn(Gu_fft)
- return field.replace_params(Gu), params
+ return field.replace_params(Gu)
@operator
@@ -500,7 +500,7 @@ def direc_exp_term(x, y, z):
# Weights of the Rayleigh integral
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
- return jnp.sum(weights * pressure.on_grid) * area, None
+ return jnp.sum(weights * pressure.on_grid) * area
@operator
@@ -560,7 +560,7 @@ def helm_func(u):
)[0]
elif method == "bicgstab":
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
- return -1j * omega * out, None
+ return -1j * omega * out
def helmholtz_solver_verbose(
diff --git a/jwave/acoustics/time_varying.py b/jwave/acoustics/time_varying.py
index 15dc15c..e008b8c 100755
--- a/jwave/acoustics/time_varying.py
+++ b/jwave/acoustics/time_varying.py
@@ -13,20 +13,21 @@
# You should have received a copy of the GNU Lesser General Public
# License along with j-Wave. If not, see .
-from typing import Dict, Tuple, TypeVar, Union
+from typing import Callable, Dict, Tuple, TypeVar, Union
+import equinox as eqx
import numpy as np
from jax import checkpoint as jax_checkpoint
from jax import numpy as jnp
from jax.lax import scan
from jaxdf import Field, operator
from jaxdf.discretization import FourierSeries, Linear, OnGrid
-from jaxdf.operators import (diag_jacobian, functional, shift_operator,
- sum_over_dims)
+from jaxdf.mods import Module
+from jaxdf.operators import diag_jacobian, shift_operator, sum_over_dims
from jwave.acoustics.spectral import kspace_op
-from jwave.geometry import (Medium, MediumAllScalars, MediumOnGrid, Sources,
- TimeAxis)
+from jwave.geometry import Medium, Sources, TimeAxis
+from jwave.logger import logger
from jwave.signal_processing import smooth
from .pml import td_pml_on_grid
@@ -34,6 +35,54 @@
Any = TypeVar("Any")
+class TimeWavePropagationSettings(Module):
+ """
+ TimeWavePropagationSettings configures the settings for
+ time domain wave solvers. This class serves as a container
+ for settings that influence how wave propagation is
+ simulated.
+
+ !!! example
+ ```python
+ >>> settings = TimeWavePropagationSettings(
+ ... c_ref = lambda m: m.min_sound_speed)
+ >>> print(settings.checkpoint)
+ True
+
+ ```
+ """
+
+ c_ref: Callable = eqx.field(static=True)
+ checkpoint: bool = eqx.field(static=True)
+ smooth_initial: bool = eqx.field(static=True)
+
+ def __init__(
+ self,
+ c_ref: Callable = lambda m: m.max_sound_speed,
+ checkpoint: bool = True,
+ smooth_initial: bool = True,
+ ):
+ """
+ Initializes a new instance of the TimeWavePropagationSettings class.
+
+ Args:
+ c_ref (Callable, static): A callable that determines
+ the reference speed of the wave solver. This is a
+ expected to be a function that takes the `medium`
+ variable and returns the reference sound speed
+ checkpoint (bool, static): Flag indicating whether to
+ use checkpointing to save memory during backpropagation.
+ Defaults to True.
+ smooth_initial (bool, static): Flag to determine
+ whether to smooth initial pressure and velocity
+ fields. Defaults to True.
+ """
+ self.c_ref = c_ref
+ self.checkpoint = checkpoint
+ self.smooth_initial = smooth_initial
+
+
+
def _shift_rho(rho0, direction, dx):
if isinstance(rho0, OnGrid):
rho0_params = rho0.params[..., 0]
@@ -42,7 +91,8 @@ def linear_interp(u, axis):
return 0.5 * (jnp.roll(u, -direction, axis) + u)
rho0 = jnp.stack(
- [linear_interp(rho0_params, n) for n in range(rho0.ndim)], axis=-1)
+ [linear_interp(rho0_params, n) for n in range(rho0.domain.ndim)],
+ axis=-1)
elif isinstance(rho0, Field):
rho0 = shift_operator(rho0, direction * dx)
else:
@@ -75,7 +125,7 @@ def momentum_conservation_rhs(p: OnGrid,
dx = np.asarray(u.domain.dx)
rho0 = _shift_rho(medium.density, 1, dx)
dp = diag_jacobian(p, stagger=[0.5])
- return -dp / rho0, params
+ return -dp / rho0
@operator
@@ -128,10 +178,10 @@ def single_grad(axis):
iku = jnp.moveaxis(Fx * shift_and_k_op[axis] * k_op, -1, axis)
return jnp.fft.ifftn(iku).real
- dp = jnp.stack([single_grad(i) for i in range(p.ndim)], axis=-1)
+ dp = jnp.stack([single_grad(i) for i in range(p.domain.ndim)], axis=-1)
update = -p.replace_params(dp) / rho0
- return update, params
+ return update
@operator
@@ -166,7 +216,7 @@ def mass_conservation_rhs(p: OnGrid,
# Staggered implementation
du = diag_jacobian(u, stagger=[-0.5])
- update = -du * rho0 + 2 * mass_source / (c0 * p.ndim * dx)
+ update = -du * rho0 + 2 * mass_source / (c0 * p.domain.ndim * dx)
return update, params
@@ -222,17 +272,19 @@ def single_grad(axis, u):
iku = jnp.moveaxis(Fx * shift_and_k_op[axis] * k_op, -1, axis)
return jnp.fft.ifftn(iku).real
- du = jnp.stack([single_grad(i, u.params[..., i]) for i in range(p.ndim)],
- axis=-1)
- update = -p.replace_params(du) * rho0 + 2 * mass_source / (c0 * p.ndim *
- dx)
+ du = jnp.stack(
+ [single_grad(i, u.params[..., i]) for i in range(p.domain.ndim)],
+ axis=-1)
+ update = -p.replace_params(du) * rho0 + 2 * mass_source / (
+ c0 * p.domain.ndim * dx)
- return update, params
+ return update
@operator
def pressure_from_density(rho: Field, medium: Medium, *, params=None) -> Field:
- r"""Compute the pressure field from the density field.
+ r"""Calculate pressure from acoustic density given by the raw output of the
+ timestepping scheme.
Args:
rho (Field): The density field.
@@ -244,41 +296,7 @@ def pressure_from_density(rho: Field, medium: Medium, *, params=None) -> Field:
"""
rho_sum = sum_over_dims(rho)
c0 = medium.sound_speed
- return (c0**2) * rho_sum, params
-
-
-def ongrid_wave_prop_params(
- medium: OnGrid,
- time_axis: TimeAxis,
- *args,
- **kwargs,
-):
- # Check which elements of medium are a field
- x = [
- x for x in [medium.sound_speed, medium.density, medium.attenuation]
- if isinstance(x, Field)
- ][0]
-
- dt = time_axis.dt
- c_ref = functional(medium.sound_speed)(jnp.amax)
-
- # Making PML on grid for rho and u
- def make_pml(staggering=0.0):
- pml_grid = td_pml_on_grid(medium,
- dt,
- c0=c_ref,
- dx=medium.domain.dx[0],
- coord_shift=staggering)
- pml = x.replace_params(pml_grid)
- return pml
-
- pml_rho = make_pml()
- pml_u = make_pml(staggering=0.5)
-
- return {
- "pml_rho": pml_rho,
- "pml_u": pml_u,
- }
+ return (c0**2) * rho_sum
@operator
@@ -319,18 +337,54 @@ def wave_propagation_symplectic_step(
return [p, u, rho]
-@operator
+def ongrid_wave_prop_params(
+ medium: OnGrid,
+ time_axis: TimeAxis,
+ *,
+ settings: TimeWavePropagationSettings,
+ **kwargs,
+):
+ # Check which elements of medium are a field
+ x = [
+ x for x in [medium.sound_speed, medium.density, medium.attenuation]
+ if isinstance(x, Field)
+ ][0]
+
+ dt = time_axis.dt
+
+ # Use settings to determine reference sound speed
+ c_ref = settings.c_ref(medium)
+
+ # Making PML on grid for rho and u
+ def make_pml(staggering=0.0):
+ pml_grid = td_pml_on_grid(medium,
+ dt,
+ c0=c_ref,
+ dx=medium.domain.dx[0],
+ coord_shift=staggering)
+ pml = x.replace_params(pml_grid)
+ return pml
+
+ pml_rho = make_pml()
+ pml_u = make_pml(staggering=0.5)
+
+ return {
+ "pml_rho": pml_rho,
+ "pml_u": pml_u,
+ "c_ref": c_ref,
+ }
+
+
+@operator(init_params=ongrid_wave_prop_params)
def simulate_wave_propagation(
- medium: MediumOnGrid,
+ medium: Medium[OnGrid],
time_axis: TimeAxis,
*,
+ settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
p0=None,
- checkpoint: bool = True,
- max_unroll_checkpoint: int = 10,
- smooth_initial=True,
params=None,
):
r"""Simulate the wave propagation operator.
@@ -368,12 +422,9 @@ def simulate_wave_propagation(
# Setup parameters
output_steps = jnp.arange(0, time_axis.Nt, 1)
dt = time_axis.dt
- c_ref = functional(medium.sound_speed)(jnp.amax)
-
- if params == None:
- params = ongrid_wave_prop_params(medium, time_axis)
# Get parameters
+ c_ref = params["c_ref"]
pml_rho = params["pml_rho"]
pml_u = params["pml_u"]
@@ -391,7 +442,7 @@ def simulate_wave_propagation(
if p0 is None:
p0 = pml_rho.replace_params(jnp.zeros(shape_one))
else:
- if smooth_initial:
+ if settings.smooth_initial:
p0_params = p0.params[..., 0]
p0_params = jnp.expand_dims(smooth(p0_params), -1)
p0 = p0.replace_params(p0_params)
@@ -403,7 +454,7 @@ def simulate_wave_propagation(
# Initialize acoustic density
rho = (p0.replace_params(
jnp.stack([p0.params[..., i]
- for i in range(p0.ndim)], axis=-1)) / p0.ndim)
+ for i in range(p0.domain.ndim)], axis=-1)) / p0.domain.ndim)
rho = rho / (medium.sound_speed**2)
# define functions to integrate
@@ -430,22 +481,26 @@ def scan_fun(fields, n):
p = pressure_from_density(rho, medium)
return [p, u, rho], sensors(p, u, rho)
- if checkpoint:
+ if settings.checkpoint:
scan_fun = jax_checkpoint(scan_fun)
+ logger.debug("Starting simulation using generic OnGrid code")
_, ys = scan(scan_fun, fields, output_steps)
return ys
def fourier_wave_prop_params(
- medium: Union[MediumAllScalars, MediumOnGrid],
+ medium: Medium[FourierSeries],
time_axis: TimeAxis,
- *args,
+ *,
+ settings: TimeWavePropagationSettings,
**kwargs,
):
dt = time_axis.dt
- c_ref = functional(medium.sound_speed)(jnp.amax)
+
+ # Use settings to determine reference sound speed
+ c_ref = settings.c_ref(medium)
# Making PML on grid for rho and u
def make_pml(staggering=0.0):
@@ -467,21 +522,20 @@ def make_pml(staggering=0.0):
"pml_rho": pml_rho,
"pml_u": pml_u,
"fourier": fourier,
+ "c_ref": c_ref
}
@operator(init_params=fourier_wave_prop_params)
def simulate_wave_propagation(
- medium: Union[MediumAllScalars, MediumOnGrid],
+ medium: Medium[FourierSeries],
time_axis: TimeAxis,
*,
+ settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
p0=None,
- checkpoint: bool = True,
- max_unroll_checkpoint: int = 10,
- smooth_initial=True,
params=None,
):
r"""Simulates the wave propagation operator using the PSTD method. This
@@ -520,11 +574,9 @@ def simulate_wave_propagation(
# Setup parameters
output_steps = jnp.arange(0, time_axis.Nt, 1)
dt = time_axis.dt
- c_ref = functional(medium.sound_speed)(jnp.amax)
- if params == None:
- params = fourier_wave_prop_params(medium, time_axis)
# Get parameters
+ c_ref = params["c_ref"]
pml_rho = params["pml_rho"]
pml_u = params["pml_u"]
@@ -542,7 +594,7 @@ def simulate_wave_propagation(
if p0 is None:
p0 = pml_rho.replace_params(jnp.zeros(shape_one))
else:
- if smooth_initial:
+ if settings.smooth_initial:
p0_params = p0.params[..., 0]
p0_params = jnp.expand_dims(smooth(p0_params), -1)
p0 = p0.replace_params(p0_params)
@@ -554,7 +606,7 @@ def simulate_wave_propagation(
# Initialize acoustic density
rho = (p0.replace_params(
jnp.stack([p0.params[..., i]
- for i in range(p0.ndim)], axis=-1)) / p0.ndim)
+ for i in range(p0.domain.ndim)], axis=-1)) / p0.domain.ndim)
rho = rho / (medium.sound_speed**2)
# define functions to integrate
@@ -588,9 +640,10 @@ def scan_fun(fields, n):
return [p, u, rho], sensors(p, u, rho)
# Define the scanning function according to the checkpoint type
- if checkpoint:
+ if settings.checkpoint:
scan_fun = jax_checkpoint(scan_fun)
+ logger.debug("Starting simulation using FourierSeries code")
_, ys = scan(scan_fun, fields, output_steps)
return ys
diff --git a/jwave/geometry.py b/jwave/geometry.py
index 728fd67..fa4b003 100755
--- a/jwave/geometry.py
+++ b/jwave/geometry.py
@@ -17,97 +17,226 @@
from dataclasses import dataclass
from typing import List, Tuple, Union
+import equinox as eqx
import numpy as np
from jax import numpy as jnp
from jax.tree_util import register_pytree_node_class
-from jaxdf import Field, FourierSeries, OnGrid
+from jaxdf import Field, FourierSeries
from jaxdf.geometry import Domain
+from jaxdf.mods import Module
from jaxdf.operators import dot_product, functional
-from plum import parametric, type_of
+from jaxtyping import Array
+from plum import parametric
-Number = Union[float, int]
+from jwave.logger import logger
+Number = Union[float, int]
-@register_pytree_node_class
-class Medium:
- r"""
- Medium structure
- Attributes:
- domain (Domain): domain of the medium
- sound_speed (jnp.ndarray): speed of sound map, can be a scalar
- density (jnp.ndarray): density map, can be a scalar
- attenuation (jnp.ndarray): attenuation map, can be a scalar
- pml_size (int): size of the PML layer in grid-points
+@parametric
+class Medium(Module):
+ """_summary_
- !!! example
+ Args:
+ eqx (_type_): _description_
- ```python
- N = (128,356)
- medium = Medium(
- sound_speed = jnp.ones(N),
- density = jnp.ones(N),.
- attenuation = 0.0,
- pml_size = 15
- )
- ```
+ Raises:
+ ValueError: _description_
+ TypeError: _description_
+ ValueError: _description_
+ Returns:
+ _type_: _description_
"""
domain: Domain
- sound_speed: Union[Number, Field] = 1.0
- density: Union[Number, Field] = 1.0
- attenuation: Union[Number, Field] = 0.0
- pml_size: Number = 20.0
+ sound_speed: Union[Array, Field, float]
+ density: Union[Array, Field, float]
+ attenuation: Union[Array, Field, float]
+ pml_size: float = eqx.field(default=20.0, static=True)
def __init__(self,
- domain,
- sound_speed=1.0,
- density=1.0,
- attenuation=0.0,
- pml_size=20):
- # Check that all domains are the same
- for field in [sound_speed, density, attenuation]:
- if isinstance(field, Field):
- assert domain == field.domain, "All domains must be the same"
-
- # Set the attributes
+ domain: Domain,
+ sound_speed: Union[Array, Field, float] = 1.0,
+ density: Union[Array, Field, float] = 1.0,
+ attenuation: Union[Array, Field, float] = 1.0,
+ pml_size: float = 20.0):
self.domain = domain
+
+ # Check if any input is an Array and none are subclasses of Field
+ inputs_are_arrays = [
+ isinstance(x, Array) and not jnp.isscalar(x)
+ for x in [sound_speed, density, attenuation]
+ ]
+ inputs_are_fields = [
+ issubclass(type(x), Field)
+ for x in [sound_speed, density, attenuation]
+ ]
+
+ if any(inputs_are_arrays) and any(inputs_are_fields):
+ raise ValueError(
+ "Ambiguous inputs for Medium: cannot mix Arrays and Field subclasses."
+ )
+
+ if all(inputs_are_arrays):
+ logger.warning(
+ "All inputs are Arrays. This is not recommended for performance reasons. Consider using Fields instead."
+ )
+
self.sound_speed = sound_speed
self.density = density
self.attenuation = attenuation
+
+ # Converting if needed
+ for field_name in ["sound_speed", "density", "attenuation"]:
+ # Convert to Fourier Series if it is a jax Array and is not a scalar
+ if isinstance(
+ self.__dict__[field_name],
+ Array) and not jnp.isscalar(self.__dict__[field_name]):
+ #logger.info(f"Converting {field_name}, which is an Array, to a FourierSeries before storing it in the Medium object.")
+ self.__dict__[field_name] = FourierSeries(
+ self.__dict__[field_name], domain)
+
+ # Other parameters
self.pml_size = pml_size
+ def __check_init__(self):
+ # Check that all domains are the same
+ for field in [self.sound_speed, self.density, self.attenuation]:
+ if isinstance(field, Field):
+ assert self.domain == field.domain, "The domain of all fields must be the same as the domain of the Medium object."
+
+ @classmethod
+ def __init_type_parameter__(self, t: type):
+ """Check whether the type parameters is valid."""
+ if issubclass(t, Field):
+ return t
+ else:
+ raise TypeError(
+ f"The type parameter of a Medium object must be a subclass of Field. Got {t}"
+ )
+
@property
- def int_pml_size(self) -> int:
- r"""Returns the size of the PML layer as an integer"""
- return int(self.pml_size)
+ def max_sound_speed(self):
+ """
+ Calculate and return the maximum sound speed.
- def tree_flatten(self):
- children = (self.sound_speed, self.density, self.attenuation)
- aux = (self.domain, self.pml_size)
- return (children, aux)
+ This property uses the `sound_speed` method/function and applies the `amax`
+ function from JAX's numpy (jnp) library to find the maximum sound speed value.
- @classmethod
- def tree_unflatten(cls, aux, children):
- sound_speed, density, attenuation = children
- domain, pml_size = aux
- a = cls(domain, sound_speed, density, attenuation, pml_size)
- return a
+ Returns:
+ The maximum sound speed value.
+ """
+ return functional(self.sound_speed)(jnp.amax)
- def __str__(self) -> str:
- return self.__repr__()
+ @property
+ def min_sound_speed(self):
+ """
+ Calculate and return the minimum sound speed.
- def __repr__(self) -> str:
+ This property uses the `sound_speed` method/function and applies the `amin`
+ function from JAX's numpy (jnp) library to find the minimum sound speed value.
- def show_param(pname):
- attr = getattr(self, pname)
- return f"{pname}: " + str(attr)
+ Returns:
+ The minimum sound speed value.
+ """
+ return functional(self.sound_speed)(jnp.amin)
- all_params = [
- "domain", "sound_speed", "density", "attenuation", "pml_size"
- ]
- strings = list(map(lambda x: show_param(x), all_params))
- return "Medium:\n - " + "\n - ".join(strings)
+ @property
+ def max_density(self):
+ """
+ Calculate and return the maximum density.
+
+ This property uses the `density` method/function and applies the `amax`
+ function from JAX's numpy (jnp) library to find the maximum density value.
+
+ Returns:
+ The maximum density value.
+ """
+ return functional(self.density)(jnp.amax)
+
+ @property
+ def min_density(self):
+ """
+ Calculate and return the minimum density.
+
+ This property uses the `density` method/function and applies the `amin`
+ function from JAX's numpy (jnp) library to find the minimum density value.
+
+ Returns:
+ The minimum density value.
+ """
+ return functional(self.density)(jnp.amin)
+
+ @property
+ def max_attenuation(self):
+ """
+ Calculate and return the maximum attenuation.
+
+ This property uses the `attenuation` method/function and applies the `amax`
+ function from JAX's numpy (jnp) library to find the maximum attenuation value.
+
+ Returns:
+ The maximum attenuation value.
+ """
+ return functional(self.attenuation)(jnp.amax)
+
+ @property
+ def min_attenuation(self):
+ """
+ Calculate and return the minimum attenuation.
+
+ This property uses the `attenuation` method/function and applies the `amin`
+ function from JAX's numpy (jnp) library to find the minimum attenuation value.
+
+ Returns:
+ The minimum attenuation value.
+ """
+ return functional(self.attenuation)(jnp.amin)
+
+ @classmethod
+ def __infer_type_parameter__(self, *args, **kwargs):
+ """Inter the type parameter from the arguments. Defaults to FourierSeries if
+ the parameters are all floats"""
+ # Reconstruct kwargs from args
+ keys = self.__init__.__code__.co_varnames[1:]
+ extra_kwargs = dict(zip(keys, args))
+ kwargs.update(extra_kwargs)
+
+ # Get fields types
+ field_inputs = ["sound_speed", "density", "attenuation"]
+ input_types = []
+ for field_name in field_inputs:
+ if field_name in kwargs:
+ field = kwargs[field_name]
+
+ if isinstance(field, Field):
+ input_types.append(type(field))
+
+ # Keep only unique
+ input_types = set(input_types)
+
+ has_fields = len(input_types) > 0
+ if not has_fields:
+ return FourierSeries
+
+ # Check that there are no more than one field type
+ if len(input_types) > 1:
+ raise ValueError(
+ f"All fields must be of the same type or scalars for a Medium object. Got {input_types}"
+ )
+
+ return input_types.pop()
+
+ @classmethod
+ def __le_type_parameter__(self, left, right):
+ assert len(left) == 1 and len(
+ right) == 1, "Medium type parameters can't be tuples."
+ return issubclass(left[0], right[0])
+
+ @property
+ def int_pml_size(self) -> int:
+ r"""Returns the size of the PML layer as an integer"""
+ return int(self.pml_size)
def points_on_circle(
@@ -140,36 +269,6 @@ def points_on_circle(
return x, y
-@parametric(runtime_type_of=True)
-class MediumType(Medium):
- """A type for Medium objects that depends on the discretization of its components"""
-
-
-@type_of.dispatch
-def type_of(m: Medium):
- return MediumType[type(m.sound_speed),
- type(m.density),
- type(m.attenuation)]
-
-
-MediumAllScalars = MediumType[object, object, object]
-"""A type for Medium objects that have all scalar components"""
-
-MediumFourierSeries = Union[
- MediumType[FourierSeries, object, object],
- MediumType[object, FourierSeries, object],
- MediumType[object, object, FourierSeries],
-]
-"""A type for Medium objects that have at least one FourierSeries component"""
-
-MediumOnGrid = Union[
- MediumType[OnGrid, object, object],
- MediumType[object, OnGrid, object],
- MediumType[object, object, OnGrid],
-]
-"""A type for Medium objects that have at least one OnGrid component"""
-
-
def unit_fibonacci_sphere(
samples: int = 128) -> List[Tuple[float, float, float]]:
"""
diff --git a/jwave/logger.py b/jwave/logger.py
new file mode 100644
index 0000000..4aadd2e
--- /dev/null
+++ b/jwave/logger.py
@@ -0,0 +1,40 @@
+import logging
+
+# Initialize the logger
+logger = logging.getLogger(__name__.split(".")[0])
+logger.setLevel(logging.INFO)
+
+# Create a console handler
+ch = logging.StreamHandler()
+ch.setLevel(logging.INFO)
+
+# Create a formatter and add it to the handler
+formatter = logging.Formatter(
+ '%(asctime)s - %(name)s [%(levelname)s]: %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S')
+ch.setFormatter(formatter)
+
+# Add the handler to the logger
+logger.addHandler(ch)
+
+
+# Function to set logging level
+def set_logging_level(level: int) -> None:
+ """
+ Set the logging level for both the logger and all its handlers.
+
+ This function updates the logging level of the logger to the specified
+ level and also iterates through all the handlers associated with the logger,
+ updating their logging levels to match the specified level.
+
+ Parameters:
+ level (int): An integer representing the logging level. This should be one
+ of the logging level constants defined in the logging module, such as
+ logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, or logging.CRITICAL.
+
+ Returns:
+ None
+ """
+ logger.setLevel(level)
+ for handler in logger.handlers:
+ handler.setLevel(level)
diff --git a/pyproject.toml b/pyproject.toml
index 42818c8..a8a0d65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,7 +51,7 @@ packages = [
[tool.poetry.dependencies]
python = "^3.9"
-jaxdf = "^0.2.4"
+jaxdf = "^0.2.7"
matplotlib = "^3.0.0"
[tool.poetry.group.dev.dependencies]
@@ -108,9 +108,12 @@ split_before_logical_operator = true
[tool.pytest.ini_options]
addopts = """\
- --doctest-modules \
+ --doctest-modules\
"""
+[tool.pytest_env]
+CUDA_VISIBLE_DEVICES = ""
+
[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',
diff --git a/tests/acoustics/test_simulate_wave_propagation.py b/tests/acoustics/test_simulate_wave_propagation.py
new file mode 100644
index 0000000..d4203cb
--- /dev/null
+++ b/tests/acoustics/test_simulate_wave_propagation.py
@@ -0,0 +1,44 @@
+import logging
+from io import StringIO
+
+from jax import numpy as jnp
+
+from jwave.acoustics import simulate_wave_propagation
+from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis
+from jwave.logger import logger, set_logging_level
+
+
+def test_correct_call():
+ domain = Domain((100, ), (1., ))
+ fs = FourierSeries(jnp.ones((100, )), domain)
+ medium = Medium(domain, sound_speed=fs)
+ p0 = FourierSeries(jnp.zeros((100, )), domain)
+ tax = TimeAxis.from_medium(medium)
+ tax.t_end = 1.0
+
+ # Create a StringIO object to capture log output
+ log_capture_string = StringIO()
+ ch = logging.StreamHandler(log_capture_string)
+
+ # Add the custom handler to the logger
+ logger.addHandler(ch)
+ set_logging_level(logging.DEBUG)
+
+ # Run the function
+ p = simulate_wave_propagation(medium, tax, p0=p0)
+
+ # Remove the handler after capturing the logs
+ logger.removeHandler(ch)
+
+ # Get the log output from the StringIO object
+ log_contents = log_capture_string.getvalue()
+
+ # Restore logging level
+ set_logging_level(logging.INFO)
+
+ # Perform assertions on the log contents
+ assert "Starting simulation using FourierSeries code" in log_contents
+
+
+if __name__ == "__main__":
+ test_correct_call()
diff --git a/tests/geometry/test_geometry.py b/tests/geometry/test_geometry.py
index 0e12fef..e05b67a 100644
--- a/tests/geometry/test_geometry.py
+++ b/tests/geometry/test_geometry.py
@@ -1,29 +1,9 @@
import numpy as np
-from jax import numpy as jnp
-from jwave.geometry import (Domain, Medium, fibonacci_sphere, points_on_circle,
+from jwave.geometry import (fibonacci_sphere, points_on_circle,
unit_fibonacci_sphere)
-def test_repr():
- # Create Domain object. Replace with correct constructor based on your implementation.
- domain = Domain()
-
- N = (8, 9)
- medium = Medium(domain=domain,
- sound_speed=jnp.ones(N),
- density=jnp.ones(N),
- attenuation=0.0,
- pml_size=15)
-
- expected_output = "Medium:\n - domain: {}\n - sound_speed: {}\n - density: {}\n - attenuation: {}\n - pml_size: {}".format(
- str(medium.domain), str(medium.sound_speed), str(medium.density),
- str(medium.attenuation), str(medium.pml_size))
-
- # Check that the __repr__ method output matches the expected output
- assert str(medium) == expected_output
-
-
def testpoints_on_circle():
n = 5
radius = 10.0
@@ -68,3 +48,9 @@ def testfibonacci_sphere():
(y[i] - centre[1])**2 +
(z[i] - centre[2])**2)
assert np.isclose(distance_from_centre, radius, atol=1e-5)
+
+
+if __name__ == "__main__":
+ testpoints_on_circle()
+ testunit_fibonacci_sphere()
+ testfibonacci_sphere()
diff --git a/tests/geometry/test_medium.py b/tests/geometry/test_medium.py
new file mode 100644
index 0000000..80a40ba
--- /dev/null
+++ b/tests/geometry/test_medium.py
@@ -0,0 +1,51 @@
+import pytest
+from jax import numpy as jnp
+
+from jwave import FourierSeries, OnGrid
+from jwave.geometry import Domain, Medium
+
+
+# Tests for Medium class
+def test_medium_type_with_fourier_series():
+ domain = Domain((10, ), (1., ))
+ fs = FourierSeries(jnp.zeros((10, )), domain)
+
+ m = Medium(domain, sound_speed=fs)
+ assert isinstance(
+ m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]"
+
+ m = Medium(domain, sound_speed=fs, density=10.0)
+ assert isinstance(
+ m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]"
+
+ m = Medium(domain)
+ assert isinstance(
+ m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]"
+
+
+def test_medium_type_with_on_grid():
+ domain = Domain((10, ), (1., ))
+ params = jnp.ones((10, ))
+ fd = OnGrid(params, domain)
+
+ m = Medium(domain, density=fd)
+ assert isinstance(m, Medium[OnGrid]), "Type should be Medium[OnGrid]"
+
+
+def test_medium_type_mismatch():
+ domain = Domain((10, ), (1., ))
+ fs = FourierSeries(jnp.zeros((10, )), domain)
+ params = jnp.ones((10, ))
+ fd = OnGrid(params, domain)
+
+ with pytest.raises(ValueError):
+ m = Medium(domain, sound_speed=fs, density=fd)
+
+ with pytest.raises(TypeError):
+ m = Medium[int](domain, sound_speed=fs)
+
+
+if __name__ == "__main__":
+ test_medium_type_with_fourier_series()
+ test_medium_type_with_on_grid()
+ test_medium_type_mismatch()
diff --git a/tests/test_kwave_3d_ivp.py b/tests/test_kwave_3d_ivp.py
index 5fbb2f6..80f4c60 100644
--- a/tests/test_kwave_3d_ivp.py
+++ b/tests/test_kwave_3d_ivp.py
@@ -25,7 +25,8 @@
from scipy.io import loadmat, savemat
from jwave import FourierSeries
-from jwave.acoustics import simulate_wave_propagation
+from jwave.acoustics import (TimeWavePropagationSettings,
+ simulate_wave_propagation)
from jwave.geometry import Domain, Medium, TimeAxis, sphere_mask
from jwave.utils import plot_comparison
@@ -135,14 +136,17 @@ def test_ivp(test_name, use_plots=False):
)
time_axis = TimeAxis.from_medium(medium, cfl=0.5, t_end=2.5e-6)
+ # Define simulation settings
+ sim_settings = TimeWavePropagationSettings(
+ smooth_initial=settings["smooth_initial"])
+
# Run simulation
@partial(jit, backend="cpu")
def run_simulation(p0):
- return simulate_wave_propagation(
- medium,
- time_axis,
- p0=p0,
- smooth_initial=settings["smooth_initial"])
+ return simulate_wave_propagation(medium,
+ time_axis,
+ p0=p0,
+ settings=sim_settings)
# Extract last field
p_final = run_simulation(p0)[-1].on_grid[..., 0]
diff --git a/tests/test_kwave_ivp.py b/tests/test_kwave_ivp.py
index 62b8b59..989c78b 100644
--- a/tests/test_kwave_ivp.py
+++ b/tests/test_kwave_ivp.py
@@ -25,7 +25,8 @@
from scipy.io import loadmat, savemat
from jwave import FourierSeries
-from jwave.acoustics import simulate_wave_propagation
+from jwave.acoustics import (TimeWavePropagationSettings,
+ simulate_wave_propagation)
from jwave.geometry import Domain, Medium, TimeAxis, circ_mask
from jwave.utils import plot_comparison
@@ -219,14 +220,17 @@ def test_ivp(test_name, use_plots=False):
)
time_axis = TimeAxis.from_medium(medium, cfl=0.5, t_end=5e-6)
+ # Define simulation settings
+ sim_settings = TimeWavePropagationSettings(
+ smooth_initial=settings["smooth_initial"])
+
# Run simulation
@partial(jit, backend="cpu")
def run_simulation(p0):
- return simulate_wave_propagation(
- medium,
- time_axis,
- p0=p0,
- smooth_initial=settings["smooth_initial"])
+ return simulate_wave_propagation(medium,
+ time_axis,
+ p0=p0,
+ settings=sim_settings)
# Extract last field
p_final = run_simulation(p0)[-1].on_grid[:, :, 0]
diff --git a/tests/test_kwave_ivp_fd.py b/tests/test_kwave_ivp_fd.py
index 1bbaaa1..d6aeeec 100644
--- a/tests/test_kwave_ivp_fd.py
+++ b/tests/test_kwave_ivp_fd.py
@@ -25,7 +25,8 @@
from scipy.io import loadmat, savemat
from jwave import FiniteDifferences
-from jwave.acoustics import simulate_wave_propagation
+from jwave.acoustics import (TimeWavePropagationSettings,
+ simulate_wave_propagation)
from jwave.geometry import Domain, Medium, TimeAxis, circ_mask
from jwave.utils import plot_comparison
@@ -151,14 +152,17 @@ def test_ivp(test_name, use_plots=False):
)
time_axis = TimeAxis.from_medium(medium, cfl=0.1, t_end=4e-6)
+ # Define simulation settings
+ sim_settings = TimeWavePropagationSettings(
+ smooth_initial=settings["smooth_initial"])
+
# Run simulation
@partial(jit, backend="cpu")
def run_simulation(p0):
- return simulate_wave_propagation(
- medium,
- time_axis,
- p0=p0,
- smooth_initial=settings["smooth_initial"])
+ return simulate_wave_propagation(medium,
+ time_axis,
+ p0=p0,
+ settings=sim_settings)
# Extract last field
p_final = run_simulation(p0)[-1].on_grid[:, :, 0]