Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updated public api #222

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion jwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,53 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from jaxdf.discretization import *
from jaxdf import (
operator,
Continuous,
Domain,
FiniteDifferences,
FourierSeries,
Field,
Linear,
OnGrid
)

from .acoustics import (
angular_spectrum,
born_iteration,
born_series,
db2neper,
helmholtz_solver_verbose,
helmholtz_solver,
helmholtz,
homogeneous_helmholtz_green,
laplacian_with_pml,
mass_conservation_rhs,
momentum_conservation_rhs,
pml,
pressure_from_density,
rayleigh_integral,
scale_source_helmholtz,
scattering_potential,
simulate_wave_propagation,
spectral,
wave_propagation_symplectic_step,
wavevector,
TimeWavePropagationSettings,
)
from .geometry import (
BLISensors,
DistributedTransducer,
Medium,
Sensors,
Sources,
TimeAxis,
TimeHarmonicSource,
)

from jwave import acoustics as ac
from jwave import geometry as geometry
from jwave import logger as logger
from jwave import phantoms as phantoms
from jwave import signal_processing as signal_processing
from jwave import utils as utils
31 changes: 28 additions & 3 deletions jwave/acoustics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@
# License along with j-Wave. If not, see <https://www.gnu.org/licenses/>.

# nopycln: file
from .operators import *
from .time_harmonic import *
from .time_varying import *
from .conversion import db2neper
from .operators import (
helmholtz,
laplacian_with_pml,
scale_source_helmholtz,
wavevector,
)
from .time_harmonic import (
angular_spectrum,
born_iteration,
born_series,
helmholtz_solver,
helmholtz_solver_verbose,
homogeneous_helmholtz_green,
rayleigh_integral,
scattering_potential
)
from .time_varying import (
mass_conservation_rhs,
momentum_conservation_rhs,
pressure_from_density,
simulate_wave_propagation,
wave_propagation_symplectic_step,
TimeWavePropagationSettings,
)

from . import spectral
from . import pml
12 changes: 6 additions & 6 deletions jwave/acoustics/time_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@

out_field = _cbs_unnorm_units(out_field, _conversion)

return out_field, None
return out_field


@operator
Expand Down Expand Up @@ -377,7 +377,7 @@
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)

return field - (1j / epsilon) * V2, params
return field - (1j / epsilon) * V2


@operator
Expand All @@ -401,7 +401,7 @@

k = k_sq - k0**2 - 1j * epsilon
out = field * k
return out, params
return out


@operator
Expand Down Expand Up @@ -430,7 +430,7 @@
u_fft = jnp.fft.fftn(u)
Gu_fft = g_fourier * u_fft
Gu = jnp.fft.ifftn(Gu_fft)
return field.replace_params(Gu), params
return field.replace_params(Gu)


@operator
Expand Down Expand Up @@ -500,7 +500,7 @@
# Weights of the Rayleigh integral
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
return jnp.sum(weights * pressure.on_grid) * area, None
return jnp.sum(weights * pressure.on_grid) * area

Check warning on line 503 in jwave/acoustics/time_harmonic.py

View check run for this annotation

Codecov / codecov/patch

jwave/acoustics/time_harmonic.py#L503

Added line #L503 was not covered by tests


@operator
Expand Down Expand Up @@ -560,7 +560,7 @@
)[0]
elif method == "bicgstab":
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
return -1j * omega * out, None
return -1j * omega * out


def helmholtz_solver_verbose(
Expand Down
6 changes: 2 additions & 4 deletions jwave/acoustics/time_varying.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def __init__(
self.smooth_initial = smooth_initial


default_time_wave_prop_settings = TimeWavePropagationSettings()


def _shift_rho(rho0, direction, dx):
if isinstance(rho0, OnGrid):
Expand Down Expand Up @@ -382,7 +380,7 @@ def simulate_wave_propagation(
medium: Medium[OnGrid],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
Expand Down Expand Up @@ -533,7 +531,7 @@ def simulate_wave_propagation(
medium: Medium[FourierSeries],
time_axis: TimeAxis,
*,
settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,12 @@ split_before_logical_operator = true

[tool.pytest.ini_options]
addopts = """\
--doctest-modules \
--doctest-modules\
"""

[tool.pytest_env]
CUDA_VISIBLE_DEVICES = ""

[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',
Expand Down