diff --git a/CHANGELOG.md b/CHANGELOG.md index ba2c808..d92a775 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Fixed - Fixed arguments error in helmholtz notebook +### Changed +- `Medium` objects are now `jaxdf.Module`s, which is based on `equinox` modules. It is also a [parametric module for dispatching operators](https://beartype.github.io/plum/parametric.html), meaning that there's a type difference betwee `Medium[FourierSeries]` and `Medium[FiniteDifferences]`, for example. +- The settings of time domain acoustic simulations are now set using a `TimeWavePropagationSettings`. This also includes an attribute to explicity set the reference sound speed. + +### Added +- Added a logger in `jwave.logger` + +### Removed +- Removed `pressure_from_density` from `jwave.acoustics.conversion`, as it was a duplicate + ## [0.1.5] - 2023-09-27 ### Added - Added `numbers_with_smallest_primes` utility to find grids with small primes for efficient FFT when using FourierSeries diff --git a/jwave/acoustics/conversion.py b/jwave/acoustics/conversion.py index b52da8f..b7b576a 100644 --- a/jwave/acoustics/conversion.py +++ b/jwave/acoustics/conversion.py @@ -16,29 +16,6 @@ import numpy as np from jax import numpy as jnp -from jwave.geometry import Sensors - - -def pressure_from_density(sensors_data: jnp.ndarray, sound_speed: jnp.ndarray, - sensors: Sensors) -> jnp.ndarray: - r""" - Calculate pressure from acoustic density given by the raw output of the - timestepping scheme. - - Args: - sensors_data: Raw output of the timestepping scheme. - sound_speed: Sound speed of the medium. - sensors: Sensors object. - - Returns: - jnp.ndarray: Pressure time traces at sensor locations - """ - if sensors is None: - return jnp.sum(sensors_data[1], -1) * (sound_speed**2) - else: - return jnp.sum(sensors_data[1], -1) * (sound_speed[sensors.positions]** - 2) - def db2neper( alpha: jnp.ndarray, diff --git a/jwave/acoustics/time_harmonic.py b/jwave/acoustics/time_harmonic.py index 7e16a99..47845c2 100644 --- a/jwave/acoustics/time_harmonic.py +++ b/jwave/acoustics/time_harmonic.py @@ -185,9 +185,9 @@ def _cbs_norm_units(medium, omega, k0, src): # Update fields src = FourierSeries(src.on_grid, domain) if issubclass(type(medium.sound_speed), FourierSeries): - medium.sound_speed = FourierSeries(c, domain) - else: - medium.sound_speed = c + c = FourierSeries(c, domain) + + medium = medium.replace("sound_speed", c) # Update k0 k0 = k0 * _conversion["dx"] diff --git a/jwave/acoustics/time_varying.py b/jwave/acoustics/time_varying.py index 15dc15c..3477f10 100755 --- a/jwave/acoustics/time_varying.py +++ b/jwave/acoustics/time_varying.py @@ -13,20 +13,21 @@ # You should have received a copy of the GNU Lesser General Public # License along with j-Wave. If not, see . -from typing import Dict, Tuple, TypeVar, Union +from typing import Callable, Dict, Tuple, TypeVar, Union +import equinox as eqx import numpy as np from jax import checkpoint as jax_checkpoint from jax import numpy as jnp from jax.lax import scan from jaxdf import Field, operator from jaxdf.discretization import FourierSeries, Linear, OnGrid -from jaxdf.operators import (diag_jacobian, functional, shift_operator, - sum_over_dims) +from jaxdf.mods import Module +from jaxdf.operators import diag_jacobian, shift_operator, sum_over_dims from jwave.acoustics.spectral import kspace_op -from jwave.geometry import (Medium, MediumAllScalars, MediumOnGrid, Sources, - TimeAxis) +from jwave.geometry import Medium, Sources, TimeAxis +from jwave.logger import logger from jwave.signal_processing import smooth from .pml import td_pml_on_grid @@ -34,6 +35,56 @@ Any = TypeVar("Any") +class TimeWavePropagationSettings(Module): + """ + TimeWavePropagationSettings configures the settings for + time domain wave solvers. This class serves as a container + for settings that influence how wave propagation is + simulated. + + !!! example + ```python + >>> settings = TimeWavePropagationSettings( + ... c_ref = lambda m: m.min_sound_speed) + >>> print(settings.checkpoint) + True + + ``` + """ + + c_ref: Callable = eqx.field(static=True) + checkpoint: bool = eqx.field(static=True) + smooth_initial: bool = eqx.field(static=True) + + def __init__( + self, + c_ref: Callable = lambda m: m.max_sound_speed, + checkpoint: bool = True, + smooth_initial: bool = True, + ): + """ + Initializes a new instance of the TimeWavePropagationSettings class. + + Args: + c_ref (Callable, static): A callable that determines + the reference speed of the wave solver. This is a + expected to be a function that takes the `medium` + variable and returns the reference sound speed + checkpoint (bool, static): Flag indicating whether to + use checkpointing to save memory during backpropagation. + Defaults to True. + smooth_initial (bool, static): Flag to determine + whether to smooth initial pressure and velocity + fields. Defaults to True. + """ + self.c_ref = c_ref + self.checkpoint = checkpoint + self.smooth_initial = smooth_initial + + +default_time_wave_prop_settings = TimeWavePropagationSettings() + + def _shift_rho(rho0, direction, dx): if isinstance(rho0, OnGrid): rho0_params = rho0.params[..., 0] @@ -42,7 +93,8 @@ def linear_interp(u, axis): return 0.5 * (jnp.roll(u, -direction, axis) + u) rho0 = jnp.stack( - [linear_interp(rho0_params, n) for n in range(rho0.ndim)], axis=-1) + [linear_interp(rho0_params, n) for n in range(rho0.domain.ndim)], + axis=-1) elif isinstance(rho0, Field): rho0 = shift_operator(rho0, direction * dx) else: @@ -75,7 +127,7 @@ def momentum_conservation_rhs(p: OnGrid, dx = np.asarray(u.domain.dx) rho0 = _shift_rho(medium.density, 1, dx) dp = diag_jacobian(p, stagger=[0.5]) - return -dp / rho0, params + return -dp / rho0 @operator @@ -128,10 +180,10 @@ def single_grad(axis): iku = jnp.moveaxis(Fx * shift_and_k_op[axis] * k_op, -1, axis) return jnp.fft.ifftn(iku).real - dp = jnp.stack([single_grad(i) for i in range(p.ndim)], axis=-1) + dp = jnp.stack([single_grad(i) for i in range(p.domain.ndim)], axis=-1) update = -p.replace_params(dp) / rho0 - return update, params + return update @operator @@ -166,7 +218,7 @@ def mass_conservation_rhs(p: OnGrid, # Staggered implementation du = diag_jacobian(u, stagger=[-0.5]) - update = -du * rho0 + 2 * mass_source / (c0 * p.ndim * dx) + update = -du * rho0 + 2 * mass_source / (c0 * p.domain.ndim * dx) return update, params @@ -222,17 +274,19 @@ def single_grad(axis, u): iku = jnp.moveaxis(Fx * shift_and_k_op[axis] * k_op, -1, axis) return jnp.fft.ifftn(iku).real - du = jnp.stack([single_grad(i, u.params[..., i]) for i in range(p.ndim)], - axis=-1) - update = -p.replace_params(du) * rho0 + 2 * mass_source / (c0 * p.ndim * - dx) + du = jnp.stack( + [single_grad(i, u.params[..., i]) for i in range(p.domain.ndim)], + axis=-1) + update = -p.replace_params(du) * rho0 + 2 * mass_source / ( + c0 * p.domain.ndim * dx) - return update, params + return update @operator def pressure_from_density(rho: Field, medium: Medium, *, params=None) -> Field: - r"""Compute the pressure field from the density field. + r"""Calculate pressure from acoustic density given by the raw output of the + timestepping scheme. Args: rho (Field): The density field. @@ -244,41 +298,7 @@ def pressure_from_density(rho: Field, medium: Medium, *, params=None) -> Field: """ rho_sum = sum_over_dims(rho) c0 = medium.sound_speed - return (c0**2) * rho_sum, params - - -def ongrid_wave_prop_params( - medium: OnGrid, - time_axis: TimeAxis, - *args, - **kwargs, -): - # Check which elements of medium are a field - x = [ - x for x in [medium.sound_speed, medium.density, medium.attenuation] - if isinstance(x, Field) - ][0] - - dt = time_axis.dt - c_ref = functional(medium.sound_speed)(jnp.amax) - - # Making PML on grid for rho and u - def make_pml(staggering=0.0): - pml_grid = td_pml_on_grid(medium, - dt, - c0=c_ref, - dx=medium.domain.dx[0], - coord_shift=staggering) - pml = x.replace_params(pml_grid) - return pml - - pml_rho = make_pml() - pml_u = make_pml(staggering=0.5) - - return { - "pml_rho": pml_rho, - "pml_u": pml_u, - } + return (c0**2) * rho_sum @operator @@ -319,18 +339,54 @@ def wave_propagation_symplectic_step( return [p, u, rho] -@operator +def ongrid_wave_prop_params( + medium: OnGrid, + time_axis: TimeAxis, + *, + settings: TimeWavePropagationSettings, + **kwargs, +): + # Check which elements of medium are a field + x = [ + x for x in [medium.sound_speed, medium.density, medium.attenuation] + if isinstance(x, Field) + ][0] + + dt = time_axis.dt + + # Use settings to determine reference sound speed + c_ref = settings.c_ref(medium) + + # Making PML on grid for rho and u + def make_pml(staggering=0.0): + pml_grid = td_pml_on_grid(medium, + dt, + c0=c_ref, + dx=medium.domain.dx[0], + coord_shift=staggering) + pml = x.replace_params(pml_grid) + return pml + + pml_rho = make_pml() + pml_u = make_pml(staggering=0.5) + + return { + "pml_rho": pml_rho, + "pml_u": pml_u, + "c_ref": c_ref, + } + + +@operator(init_params=ongrid_wave_prop_params) def simulate_wave_propagation( - medium: MediumOnGrid, + medium: Medium[OnGrid], time_axis: TimeAxis, *, + settings: TimeWavePropagationSettings = default_time_wave_prop_settings, sources=None, sensors=None, u0=None, p0=None, - checkpoint: bool = True, - max_unroll_checkpoint: int = 10, - smooth_initial=True, params=None, ): r"""Simulate the wave propagation operator. @@ -368,12 +424,9 @@ def simulate_wave_propagation( # Setup parameters output_steps = jnp.arange(0, time_axis.Nt, 1) dt = time_axis.dt - c_ref = functional(medium.sound_speed)(jnp.amax) - - if params == None: - params = ongrid_wave_prop_params(medium, time_axis) # Get parameters + c_ref = params["c_ref"] pml_rho = params["pml_rho"] pml_u = params["pml_u"] @@ -391,7 +444,7 @@ def simulate_wave_propagation( if p0 is None: p0 = pml_rho.replace_params(jnp.zeros(shape_one)) else: - if smooth_initial: + if settings.smooth_initial: p0_params = p0.params[..., 0] p0_params = jnp.expand_dims(smooth(p0_params), -1) p0 = p0.replace_params(p0_params) @@ -403,7 +456,7 @@ def simulate_wave_propagation( # Initialize acoustic density rho = (p0.replace_params( jnp.stack([p0.params[..., i] - for i in range(p0.ndim)], axis=-1)) / p0.ndim) + for i in range(p0.domain.ndim)], axis=-1)) / p0.domain.ndim) rho = rho / (medium.sound_speed**2) # define functions to integrate @@ -430,22 +483,26 @@ def scan_fun(fields, n): p = pressure_from_density(rho, medium) return [p, u, rho], sensors(p, u, rho) - if checkpoint: + if settings.checkpoint: scan_fun = jax_checkpoint(scan_fun) + logger.debug("Starting simulation using generic OnGrid code") _, ys = scan(scan_fun, fields, output_steps) return ys def fourier_wave_prop_params( - medium: Union[MediumAllScalars, MediumOnGrid], + medium: Medium[FourierSeries], time_axis: TimeAxis, - *args, + *, + settings: TimeWavePropagationSettings, **kwargs, ): dt = time_axis.dt - c_ref = functional(medium.sound_speed)(jnp.amax) + + # Use settings to determine reference sound speed + c_ref = settings.c_ref(medium) # Making PML on grid for rho and u def make_pml(staggering=0.0): @@ -467,21 +524,20 @@ def make_pml(staggering=0.0): "pml_rho": pml_rho, "pml_u": pml_u, "fourier": fourier, + "c_ref": c_ref } @operator(init_params=fourier_wave_prop_params) def simulate_wave_propagation( - medium: Union[MediumAllScalars, MediumOnGrid], + medium: Medium[FourierSeries], time_axis: TimeAxis, *, + settings: TimeWavePropagationSettings = default_time_wave_prop_settings, sources=None, sensors=None, u0=None, p0=None, - checkpoint: bool = True, - max_unroll_checkpoint: int = 10, - smooth_initial=True, params=None, ): r"""Simulates the wave propagation operator using the PSTD method. This @@ -520,11 +576,9 @@ def simulate_wave_propagation( # Setup parameters output_steps = jnp.arange(0, time_axis.Nt, 1) dt = time_axis.dt - c_ref = functional(medium.sound_speed)(jnp.amax) - if params == None: - params = fourier_wave_prop_params(medium, time_axis) # Get parameters + c_ref = params["c_ref"] pml_rho = params["pml_rho"] pml_u = params["pml_u"] @@ -542,7 +596,7 @@ def simulate_wave_propagation( if p0 is None: p0 = pml_rho.replace_params(jnp.zeros(shape_one)) else: - if smooth_initial: + if settings.smooth_initial: p0_params = p0.params[..., 0] p0_params = jnp.expand_dims(smooth(p0_params), -1) p0 = p0.replace_params(p0_params) @@ -554,7 +608,7 @@ def simulate_wave_propagation( # Initialize acoustic density rho = (p0.replace_params( jnp.stack([p0.params[..., i] - for i in range(p0.ndim)], axis=-1)) / p0.ndim) + for i in range(p0.domain.ndim)], axis=-1)) / p0.domain.ndim) rho = rho / (medium.sound_speed**2) # define functions to integrate @@ -588,9 +642,10 @@ def scan_fun(fields, n): return [p, u, rho], sensors(p, u, rho) # Define the scanning function according to the checkpoint type - if checkpoint: + if settings.checkpoint: scan_fun = jax_checkpoint(scan_fun) + logger.debug("Starting simulation using FourierSeries code") _, ys = scan(scan_fun, fields, output_steps) return ys diff --git a/jwave/geometry.py b/jwave/geometry.py index 728fd67..fa4b003 100755 --- a/jwave/geometry.py +++ b/jwave/geometry.py @@ -17,97 +17,226 @@ from dataclasses import dataclass from typing import List, Tuple, Union +import equinox as eqx import numpy as np from jax import numpy as jnp from jax.tree_util import register_pytree_node_class -from jaxdf import Field, FourierSeries, OnGrid +from jaxdf import Field, FourierSeries from jaxdf.geometry import Domain +from jaxdf.mods import Module from jaxdf.operators import dot_product, functional -from plum import parametric, type_of +from jaxtyping import Array +from plum import parametric -Number = Union[float, int] +from jwave.logger import logger +Number = Union[float, int] -@register_pytree_node_class -class Medium: - r""" - Medium structure - Attributes: - domain (Domain): domain of the medium - sound_speed (jnp.ndarray): speed of sound map, can be a scalar - density (jnp.ndarray): density map, can be a scalar - attenuation (jnp.ndarray): attenuation map, can be a scalar - pml_size (int): size of the PML layer in grid-points +@parametric +class Medium(Module): + """_summary_ - !!! example + Args: + eqx (_type_): _description_ - ```python - N = (128,356) - medium = Medium( - sound_speed = jnp.ones(N), - density = jnp.ones(N),. - attenuation = 0.0, - pml_size = 15 - ) - ``` + Raises: + ValueError: _description_ + TypeError: _description_ + ValueError: _description_ + Returns: + _type_: _description_ """ domain: Domain - sound_speed: Union[Number, Field] = 1.0 - density: Union[Number, Field] = 1.0 - attenuation: Union[Number, Field] = 0.0 - pml_size: Number = 20.0 + sound_speed: Union[Array, Field, float] + density: Union[Array, Field, float] + attenuation: Union[Array, Field, float] + pml_size: float = eqx.field(default=20.0, static=True) def __init__(self, - domain, - sound_speed=1.0, - density=1.0, - attenuation=0.0, - pml_size=20): - # Check that all domains are the same - for field in [sound_speed, density, attenuation]: - if isinstance(field, Field): - assert domain == field.domain, "All domains must be the same" - - # Set the attributes + domain: Domain, + sound_speed: Union[Array, Field, float] = 1.0, + density: Union[Array, Field, float] = 1.0, + attenuation: Union[Array, Field, float] = 1.0, + pml_size: float = 20.0): self.domain = domain + + # Check if any input is an Array and none are subclasses of Field + inputs_are_arrays = [ + isinstance(x, Array) and not jnp.isscalar(x) + for x in [sound_speed, density, attenuation] + ] + inputs_are_fields = [ + issubclass(type(x), Field) + for x in [sound_speed, density, attenuation] + ] + + if any(inputs_are_arrays) and any(inputs_are_fields): + raise ValueError( + "Ambiguous inputs for Medium: cannot mix Arrays and Field subclasses." + ) + + if all(inputs_are_arrays): + logger.warning( + "All inputs are Arrays. This is not recommended for performance reasons. Consider using Fields instead." + ) + self.sound_speed = sound_speed self.density = density self.attenuation = attenuation + + # Converting if needed + for field_name in ["sound_speed", "density", "attenuation"]: + # Convert to Fourier Series if it is a jax Array and is not a scalar + if isinstance( + self.__dict__[field_name], + Array) and not jnp.isscalar(self.__dict__[field_name]): + #logger.info(f"Converting {field_name}, which is an Array, to a FourierSeries before storing it in the Medium object.") + self.__dict__[field_name] = FourierSeries( + self.__dict__[field_name], domain) + + # Other parameters self.pml_size = pml_size + def __check_init__(self): + # Check that all domains are the same + for field in [self.sound_speed, self.density, self.attenuation]: + if isinstance(field, Field): + assert self.domain == field.domain, "The domain of all fields must be the same as the domain of the Medium object." + + @classmethod + def __init_type_parameter__(self, t: type): + """Check whether the type parameters is valid.""" + if issubclass(t, Field): + return t + else: + raise TypeError( + f"The type parameter of a Medium object must be a subclass of Field. Got {t}" + ) + @property - def int_pml_size(self) -> int: - r"""Returns the size of the PML layer as an integer""" - return int(self.pml_size) + def max_sound_speed(self): + """ + Calculate and return the maximum sound speed. - def tree_flatten(self): - children = (self.sound_speed, self.density, self.attenuation) - aux = (self.domain, self.pml_size) - return (children, aux) + This property uses the `sound_speed` method/function and applies the `amax` + function from JAX's numpy (jnp) library to find the maximum sound speed value. - @classmethod - def tree_unflatten(cls, aux, children): - sound_speed, density, attenuation = children - domain, pml_size = aux - a = cls(domain, sound_speed, density, attenuation, pml_size) - return a + Returns: + The maximum sound speed value. + """ + return functional(self.sound_speed)(jnp.amax) - def __str__(self) -> str: - return self.__repr__() + @property + def min_sound_speed(self): + """ + Calculate and return the minimum sound speed. - def __repr__(self) -> str: + This property uses the `sound_speed` method/function and applies the `amin` + function from JAX's numpy (jnp) library to find the minimum sound speed value. - def show_param(pname): - attr = getattr(self, pname) - return f"{pname}: " + str(attr) + Returns: + The minimum sound speed value. + """ + return functional(self.sound_speed)(jnp.amin) - all_params = [ - "domain", "sound_speed", "density", "attenuation", "pml_size" - ] - strings = list(map(lambda x: show_param(x), all_params)) - return "Medium:\n - " + "\n - ".join(strings) + @property + def max_density(self): + """ + Calculate and return the maximum density. + + This property uses the `density` method/function and applies the `amax` + function from JAX's numpy (jnp) library to find the maximum density value. + + Returns: + The maximum density value. + """ + return functional(self.density)(jnp.amax) + + @property + def min_density(self): + """ + Calculate and return the minimum density. + + This property uses the `density` method/function and applies the `amin` + function from JAX's numpy (jnp) library to find the minimum density value. + + Returns: + The minimum density value. + """ + return functional(self.density)(jnp.amin) + + @property + def max_attenuation(self): + """ + Calculate and return the maximum attenuation. + + This property uses the `attenuation` method/function and applies the `amax` + function from JAX's numpy (jnp) library to find the maximum attenuation value. + + Returns: + The maximum attenuation value. + """ + return functional(self.attenuation)(jnp.amax) + + @property + def min_attenuation(self): + """ + Calculate and return the minimum attenuation. + + This property uses the `attenuation` method/function and applies the `amin` + function from JAX's numpy (jnp) library to find the minimum attenuation value. + + Returns: + The minimum attenuation value. + """ + return functional(self.attenuation)(jnp.amin) + + @classmethod + def __infer_type_parameter__(self, *args, **kwargs): + """Inter the type parameter from the arguments. Defaults to FourierSeries if + the parameters are all floats""" + # Reconstruct kwargs from args + keys = self.__init__.__code__.co_varnames[1:] + extra_kwargs = dict(zip(keys, args)) + kwargs.update(extra_kwargs) + + # Get fields types + field_inputs = ["sound_speed", "density", "attenuation"] + input_types = [] + for field_name in field_inputs: + if field_name in kwargs: + field = kwargs[field_name] + + if isinstance(field, Field): + input_types.append(type(field)) + + # Keep only unique + input_types = set(input_types) + + has_fields = len(input_types) > 0 + if not has_fields: + return FourierSeries + + # Check that there are no more than one field type + if len(input_types) > 1: + raise ValueError( + f"All fields must be of the same type or scalars for a Medium object. Got {input_types}" + ) + + return input_types.pop() + + @classmethod + def __le_type_parameter__(self, left, right): + assert len(left) == 1 and len( + right) == 1, "Medium type parameters can't be tuples." + return issubclass(left[0], right[0]) + + @property + def int_pml_size(self) -> int: + r"""Returns the size of the PML layer as an integer""" + return int(self.pml_size) def points_on_circle( @@ -140,36 +269,6 @@ def points_on_circle( return x, y -@parametric(runtime_type_of=True) -class MediumType(Medium): - """A type for Medium objects that depends on the discretization of its components""" - - -@type_of.dispatch -def type_of(m: Medium): - return MediumType[type(m.sound_speed), - type(m.density), - type(m.attenuation)] - - -MediumAllScalars = MediumType[object, object, object] -"""A type for Medium objects that have all scalar components""" - -MediumFourierSeries = Union[ - MediumType[FourierSeries, object, object], - MediumType[object, FourierSeries, object], - MediumType[object, object, FourierSeries], -] -"""A type for Medium objects that have at least one FourierSeries component""" - -MediumOnGrid = Union[ - MediumType[OnGrid, object, object], - MediumType[object, OnGrid, object], - MediumType[object, object, OnGrid], -] -"""A type for Medium objects that have at least one OnGrid component""" - - def unit_fibonacci_sphere( samples: int = 128) -> List[Tuple[float, float, float]]: """ diff --git a/jwave/logger.py b/jwave/logger.py new file mode 100644 index 0000000..4aadd2e --- /dev/null +++ b/jwave/logger.py @@ -0,0 +1,40 @@ +import logging + +# Initialize the logger +logger = logging.getLogger(__name__.split(".")[0]) +logger.setLevel(logging.INFO) + +# Create a console handler +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) + +# Create a formatter and add it to the handler +formatter = logging.Formatter( + '%(asctime)s - %(name)s [%(levelname)s]: %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') +ch.setFormatter(formatter) + +# Add the handler to the logger +logger.addHandler(ch) + + +# Function to set logging level +def set_logging_level(level: int) -> None: + """ + Set the logging level for both the logger and all its handlers. + + This function updates the logging level of the logger to the specified + level and also iterates through all the handlers associated with the logger, + updating their logging levels to match the specified level. + + Parameters: + level (int): An integer representing the logging level. This should be one + of the logging level constants defined in the logging module, such as + logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, or logging.CRITICAL. + + Returns: + None + """ + logger.setLevel(level) + for handler in logger.handlers: + handler.setLevel(level) diff --git a/pyproject.toml b/pyproject.toml index 42818c8..c404323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.9" -jaxdf = "^0.2.4" +jaxdf = "^0.2.7" matplotlib = "^3.0.0" [tool.poetry.group.dev.dependencies] diff --git a/tests/acoustics/test_simulate_wave_propagation.py b/tests/acoustics/test_simulate_wave_propagation.py new file mode 100644 index 0000000..d4203cb --- /dev/null +++ b/tests/acoustics/test_simulate_wave_propagation.py @@ -0,0 +1,44 @@ +import logging +from io import StringIO + +from jax import numpy as jnp + +from jwave.acoustics import simulate_wave_propagation +from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis +from jwave.logger import logger, set_logging_level + + +def test_correct_call(): + domain = Domain((100, ), (1., )) + fs = FourierSeries(jnp.ones((100, )), domain) + medium = Medium(domain, sound_speed=fs) + p0 = FourierSeries(jnp.zeros((100, )), domain) + tax = TimeAxis.from_medium(medium) + tax.t_end = 1.0 + + # Create a StringIO object to capture log output + log_capture_string = StringIO() + ch = logging.StreamHandler(log_capture_string) + + # Add the custom handler to the logger + logger.addHandler(ch) + set_logging_level(logging.DEBUG) + + # Run the function + p = simulate_wave_propagation(medium, tax, p0=p0) + + # Remove the handler after capturing the logs + logger.removeHandler(ch) + + # Get the log output from the StringIO object + log_contents = log_capture_string.getvalue() + + # Restore logging level + set_logging_level(logging.INFO) + + # Perform assertions on the log contents + assert "Starting simulation using FourierSeries code" in log_contents + + +if __name__ == "__main__": + test_correct_call() diff --git a/tests/geometry/test_geometry.py b/tests/geometry/test_geometry.py index 0e12fef..e05b67a 100644 --- a/tests/geometry/test_geometry.py +++ b/tests/geometry/test_geometry.py @@ -1,29 +1,9 @@ import numpy as np -from jax import numpy as jnp -from jwave.geometry import (Domain, Medium, fibonacci_sphere, points_on_circle, +from jwave.geometry import (fibonacci_sphere, points_on_circle, unit_fibonacci_sphere) -def test_repr(): - # Create Domain object. Replace with correct constructor based on your implementation. - domain = Domain() - - N = (8, 9) - medium = Medium(domain=domain, - sound_speed=jnp.ones(N), - density=jnp.ones(N), - attenuation=0.0, - pml_size=15) - - expected_output = "Medium:\n - domain: {}\n - sound_speed: {}\n - density: {}\n - attenuation: {}\n - pml_size: {}".format( - str(medium.domain), str(medium.sound_speed), str(medium.density), - str(medium.attenuation), str(medium.pml_size)) - - # Check that the __repr__ method output matches the expected output - assert str(medium) == expected_output - - def testpoints_on_circle(): n = 5 radius = 10.0 @@ -68,3 +48,9 @@ def testfibonacci_sphere(): (y[i] - centre[1])**2 + (z[i] - centre[2])**2) assert np.isclose(distance_from_centre, radius, atol=1e-5) + + +if __name__ == "__main__": + testpoints_on_circle() + testunit_fibonacci_sphere() + testfibonacci_sphere() diff --git a/tests/geometry/test_medium.py b/tests/geometry/test_medium.py new file mode 100644 index 0000000..80a40ba --- /dev/null +++ b/tests/geometry/test_medium.py @@ -0,0 +1,51 @@ +import pytest +from jax import numpy as jnp + +from jwave import FourierSeries, OnGrid +from jwave.geometry import Domain, Medium + + +# Tests for Medium class +def test_medium_type_with_fourier_series(): + domain = Domain((10, ), (1., )) + fs = FourierSeries(jnp.zeros((10, )), domain) + + m = Medium(domain, sound_speed=fs) + assert isinstance( + m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]" + + m = Medium(domain, sound_speed=fs, density=10.0) + assert isinstance( + m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]" + + m = Medium(domain) + assert isinstance( + m, Medium[FourierSeries]), "Type should be Medium[FourierSeries]" + + +def test_medium_type_with_on_grid(): + domain = Domain((10, ), (1., )) + params = jnp.ones((10, )) + fd = OnGrid(params, domain) + + m = Medium(domain, density=fd) + assert isinstance(m, Medium[OnGrid]), "Type should be Medium[OnGrid]" + + +def test_medium_type_mismatch(): + domain = Domain((10, ), (1., )) + fs = FourierSeries(jnp.zeros((10, )), domain) + params = jnp.ones((10, )) + fd = OnGrid(params, domain) + + with pytest.raises(ValueError): + m = Medium(domain, sound_speed=fs, density=fd) + + with pytest.raises(TypeError): + m = Medium[int](domain, sound_speed=fs) + + +if __name__ == "__main__": + test_medium_type_with_fourier_series() + test_medium_type_with_on_grid() + test_medium_type_mismatch() diff --git a/tests/test_kwave_3d_ivp.py b/tests/test_kwave_3d_ivp.py index 5fbb2f6..80f4c60 100644 --- a/tests/test_kwave_3d_ivp.py +++ b/tests/test_kwave_3d_ivp.py @@ -25,7 +25,8 @@ from scipy.io import loadmat, savemat from jwave import FourierSeries -from jwave.acoustics import simulate_wave_propagation +from jwave.acoustics import (TimeWavePropagationSettings, + simulate_wave_propagation) from jwave.geometry import Domain, Medium, TimeAxis, sphere_mask from jwave.utils import plot_comparison @@ -135,14 +136,17 @@ def test_ivp(test_name, use_plots=False): ) time_axis = TimeAxis.from_medium(medium, cfl=0.5, t_end=2.5e-6) + # Define simulation settings + sim_settings = TimeWavePropagationSettings( + smooth_initial=settings["smooth_initial"]) + # Run simulation @partial(jit, backend="cpu") def run_simulation(p0): - return simulate_wave_propagation( - medium, - time_axis, - p0=p0, - smooth_initial=settings["smooth_initial"]) + return simulate_wave_propagation(medium, + time_axis, + p0=p0, + settings=sim_settings) # Extract last field p_final = run_simulation(p0)[-1].on_grid[..., 0] diff --git a/tests/test_kwave_ivp.py b/tests/test_kwave_ivp.py index 62b8b59..989c78b 100644 --- a/tests/test_kwave_ivp.py +++ b/tests/test_kwave_ivp.py @@ -25,7 +25,8 @@ from scipy.io import loadmat, savemat from jwave import FourierSeries -from jwave.acoustics import simulate_wave_propagation +from jwave.acoustics import (TimeWavePropagationSettings, + simulate_wave_propagation) from jwave.geometry import Domain, Medium, TimeAxis, circ_mask from jwave.utils import plot_comparison @@ -219,14 +220,17 @@ def test_ivp(test_name, use_plots=False): ) time_axis = TimeAxis.from_medium(medium, cfl=0.5, t_end=5e-6) + # Define simulation settings + sim_settings = TimeWavePropagationSettings( + smooth_initial=settings["smooth_initial"]) + # Run simulation @partial(jit, backend="cpu") def run_simulation(p0): - return simulate_wave_propagation( - medium, - time_axis, - p0=p0, - smooth_initial=settings["smooth_initial"]) + return simulate_wave_propagation(medium, + time_axis, + p0=p0, + settings=sim_settings) # Extract last field p_final = run_simulation(p0)[-1].on_grid[:, :, 0] diff --git a/tests/test_kwave_ivp_fd.py b/tests/test_kwave_ivp_fd.py index 1bbaaa1..d6aeeec 100644 --- a/tests/test_kwave_ivp_fd.py +++ b/tests/test_kwave_ivp_fd.py @@ -25,7 +25,8 @@ from scipy.io import loadmat, savemat from jwave import FiniteDifferences -from jwave.acoustics import simulate_wave_propagation +from jwave.acoustics import (TimeWavePropagationSettings, + simulate_wave_propagation) from jwave.geometry import Domain, Medium, TimeAxis, circ_mask from jwave.utils import plot_comparison @@ -151,14 +152,17 @@ def test_ivp(test_name, use_plots=False): ) time_axis = TimeAxis.from_medium(medium, cfl=0.1, t_end=4e-6) + # Define simulation settings + sim_settings = TimeWavePropagationSettings( + smooth_initial=settings["smooth_initial"]) + # Run simulation @partial(jit, backend="cpu") def run_simulation(p0): - return simulate_wave_propagation( - medium, - time_axis, - p0=p0, - smooth_initial=settings["smooth_initial"]) + return simulate_wave_propagation(medium, + time_axis, + p0=p0, + settings=sim_settings) # Extract last field p_final = run_simulation(p0)[-1].on_grid[:, :, 0]