Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.3 #221

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Fixed
- Fixed arguments error in helmholtz notebook

### Changed
- `Medium` objects are now `jaxdf.Module`s, which is based on `equinox` modules. It is also a [parametric module for dispatching operators](https://beartype.github.io/plum/parametric.html), meaning that there's a type difference betwee `Medium[FourierSeries]` and `Medium[FiniteDifferences]`, for example.
- The settings of time domain acoustic simulations are now set using a `TimeWavePropagationSettings`. This also includes an attribute to explicity set the reference sound speed.

### Added
- Added a logger in `jwave.logger`

### Removed
- Removed `pressure_from_density` from `jwave.acoustics.conversion`, as it was a duplicate

## [0.1.5] - 2023-09-27
### Added
- Added `numbers_with_smallest_primes` utility to find grids with small primes for efficient FFT when using FourierSeries
Expand Down
46 changes: 45 additions & 1 deletion jwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,53 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from jaxdf.discretization import *
from jaxdf import (
operator,
Continuous,
Domain,
FiniteDifferences,
FourierSeries,
Field,
Linear,
OnGrid
)

from .acoustics import (
angular_spectrum,
born_iteration,
born_series,
db2neper,
helmholtz_solver_verbose,
helmholtz_solver,
helmholtz,
homogeneous_helmholtz_green,
laplacian_with_pml,
mass_conservation_rhs,
momentum_conservation_rhs,
pml,
pressure_from_density,
rayleigh_integral,
scale_source_helmholtz,
scattering_potential,
simulate_wave_propagation,
spectral,
wave_propagation_symplectic_step,
wavevector,
TimeWavePropagationSettings,
)
from .geometry import (
BLISensors,
DistributedTransducer,
Medium,
Sensors,
Sources,
TimeAxis,
TimeHarmonicSource,
)

from jwave import acoustics as ac
from jwave import geometry as geometry
from jwave import logger as logger
from jwave import phantoms as phantoms
from jwave import signal_processing as signal_processing
from jwave import utils as utils
31 changes: 28 additions & 3 deletions jwave/acoustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from .operators import *
from .time_harmonic import *
from .time_varying import *
from .conversion import db2neper
from .operators import (
helmholtz,
laplacian_with_pml,
scale_source_helmholtz,
wavevector,
)
from .time_harmonic import (
angular_spectrum,
born_iteration,
born_series,
helmholtz_solver,
helmholtz_solver_verbose,
homogeneous_helmholtz_green,
rayleigh_integral,
scattering_potential
)
from .time_varying import (
mass_conservation_rhs,
momentum_conservation_rhs,
pressure_from_density,
simulate_wave_propagation,
wave_propagation_symplectic_step,
TimeWavePropagationSettings,
)

from . import spectral
from . import pml
23 changes: 0 additions & 23 deletions jwave/acoustics/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,6 @@
import numpy as np
from jax import numpy as jnp

from jwave.geometry import Sensors


def pressure_from_density(sensors_data: jnp.ndarray, sound_speed: jnp.ndarray,
sensors: Sensors) -> jnp.ndarray:
r"""
Calculate pressure from acoustic density given by the raw output of the
timestepping scheme.

Args:
sensors_data: Raw output of the timestepping scheme.
sound_speed: Sound speed of the medium.
sensors: Sensors object.

Returns:
jnp.ndarray: Pressure time traces at sensor locations
"""
if sensors is None:
return jnp.sum(sensors_data[1], -1) * (sound_speed**2)
else:
return jnp.sum(sensors_data[1], -1) * (sound_speed[sensors.positions]**
2)


def db2neper(
alpha: jnp.ndarray,
Expand Down
18 changes: 9 additions & 9 deletions jwave/acoustics/time_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,9 @@
# Update fields
src = FourierSeries(src.on_grid, domain)
if issubclass(type(medium.sound_speed), FourierSeries):
medium.sound_speed = FourierSeries(c, domain)
else:
medium.sound_speed = c
c = FourierSeries(c, domain)

medium = medium.replace("sound_speed", c)

# Update k0
k0 = k0 * _conversion["dx"]
Expand Down Expand Up @@ -342,7 +342,7 @@

out_field = _cbs_unnorm_units(out_field, _conversion)

return out_field, None
return out_field


@operator
Expand Down Expand Up @@ -377,7 +377,7 @@
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)

return field - (1j / epsilon) * V2, params
return field - (1j / epsilon) * V2


@operator
Expand All @@ -401,7 +401,7 @@

k = k_sq - k0**2 - 1j * epsilon
out = field * k
return out, params
return out


@operator
Expand Down Expand Up @@ -430,7 +430,7 @@
u_fft = jnp.fft.fftn(u)
Gu_fft = g_fourier * u_fft
Gu = jnp.fft.ifftn(Gu_fft)
return field.replace_params(Gu), params
return field.replace_params(Gu)


@operator
Expand Down Expand Up @@ -500,7 +500,7 @@
# Weights of the Rayleigh integral
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
return jnp.sum(weights * pressure.on_grid) * area, None
return jnp.sum(weights * pressure.on_grid) * area

Check warning on line 503 in jwave/acoustics/time_harmonic.py

View check run for this annotation

Codecov / codecov/patch

jwave/acoustics/time_harmonic.py#L503

Added line #L503 was not covered by tests


@operator
Expand Down Expand Up @@ -560,7 +560,7 @@
)[0]
elif method == "bicgstab":
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
return -1j * omega * out, None
return -1j * omega * out


def helmholtz_solver_verbose(
Expand Down
Loading
Loading