diff --git a/jwave/__init__.py b/jwave/__init__.py
index 2c2d8f0..72fdd6c 100755
--- a/jwave/__init__.py
+++ b/jwave/__init__.py
@@ -14,9 +14,53 @@
# License along with j-Wave. If not, see .
# nopycln: file
-from jaxdf.discretization import *
+from jaxdf import (
+ operator,
+ Continuous,
+ Domain,
+ FiniteDifferences,
+ FourierSeries,
+ Field,
+ Linear,
+ OnGrid
+)
+
+from .acoustics import (
+ angular_spectrum,
+ born_iteration,
+ born_series,
+ db2neper,
+ helmholtz_solver_verbose,
+ helmholtz_solver,
+ helmholtz,
+ homogeneous_helmholtz_green,
+ laplacian_with_pml,
+ mass_conservation_rhs,
+ momentum_conservation_rhs,
+ pml,
+ pressure_from_density,
+ rayleigh_integral,
+ scale_source_helmholtz,
+ scattering_potential,
+ simulate_wave_propagation,
+ spectral,
+ wave_propagation_symplectic_step,
+ wavevector,
+ TimeWavePropagationSettings,
+)
+from .geometry import (
+ BLISensors,
+ DistributedTransducer,
+ Medium,
+ Sensors,
+ Sources,
+ TimeAxis,
+ TimeHarmonicSource,
+)
from jwave import acoustics as ac
from jwave import geometry as geometry
+from jwave import logger as logger
+from jwave import phantoms as phantoms
from jwave import signal_processing as signal_processing
from jwave import utils as utils
diff --git a/jwave/acoustics/__init__.py b/jwave/acoustics/__init__.py
index 60d9fd6..736873e 100644
--- a/jwave/acoustics/__init__.py
+++ b/jwave/acoustics/__init__.py
@@ -14,6 +14,31 @@
# License along with j-Wave. If not, see .
# nopycln: file
-from .operators import *
-from .time_harmonic import *
-from .time_varying import *
+from .conversion import db2neper
+from .operators import (
+ helmholtz,
+ laplacian_with_pml,
+ scale_source_helmholtz,
+ wavevector,
+)
+from .time_harmonic import (
+ angular_spectrum,
+ born_iteration,
+ born_series,
+ helmholtz_solver,
+ helmholtz_solver_verbose,
+ homogeneous_helmholtz_green,
+ rayleigh_integral,
+ scattering_potential
+)
+from .time_varying import (
+ mass_conservation_rhs,
+ momentum_conservation_rhs,
+ pressure_from_density,
+ simulate_wave_propagation,
+ wave_propagation_symplectic_step,
+ TimeWavePropagationSettings,
+)
+
+from . import spectral
+from . import pml
\ No newline at end of file
diff --git a/jwave/acoustics/time_harmonic.py b/jwave/acoustics/time_harmonic.py
index 47845c2..ca567e0 100644
--- a/jwave/acoustics/time_harmonic.py
+++ b/jwave/acoustics/time_harmonic.py
@@ -342,7 +342,7 @@ def body_fun(carry):
out_field = _cbs_unnorm_units(out_field, _conversion)
- return out_field, None
+ return out_field
@operator
@@ -377,7 +377,7 @@ def born_iteration(field: Field,
G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon)
V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon)
- return field - (1j / epsilon) * V2, params
+ return field - (1j / epsilon) * V2
@operator
@@ -401,7 +401,7 @@ def scattering_potential(field: Field,
k = k_sq - k0**2 - 1j * epsilon
out = field * k
- return out, params
+ return out
@operator
@@ -430,7 +430,7 @@ def homogeneous_helmholtz_green(field: FourierSeries,
u_fft = jnp.fft.fftn(u)
Gu_fft = g_fourier * u_fft
Gu = jnp.fft.ifftn(Gu_fft)
- return field.replace_params(Gu), params
+ return field.replace_params(Gu)
@operator
@@ -500,7 +500,7 @@ def direc_exp_term(x, y, z):
# Weights of the Rayleigh integral
weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)),
in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2])
- return jnp.sum(weights * pressure.on_grid) * area, None
+ return jnp.sum(weights * pressure.on_grid) * area
@operator
@@ -560,7 +560,7 @@ def helm_func(u):
)[0]
elif method == "bicgstab":
out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0]
- return -1j * omega * out, None
+ return -1j * omega * out
def helmholtz_solver_verbose(
diff --git a/jwave/acoustics/time_varying.py b/jwave/acoustics/time_varying.py
index 3477f10..e008b8c 100755
--- a/jwave/acoustics/time_varying.py
+++ b/jwave/acoustics/time_varying.py
@@ -82,8 +82,6 @@ def __init__(
self.smooth_initial = smooth_initial
-default_time_wave_prop_settings = TimeWavePropagationSettings()
-
def _shift_rho(rho0, direction, dx):
if isinstance(rho0, OnGrid):
@@ -382,7 +380,7 @@ def simulate_wave_propagation(
medium: Medium[OnGrid],
time_axis: TimeAxis,
*,
- settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
+ settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
@@ -533,7 +531,7 @@ def simulate_wave_propagation(
medium: Medium[FourierSeries],
time_axis: TimeAxis,
*,
- settings: TimeWavePropagationSettings = default_time_wave_prop_settings,
+ settings: TimeWavePropagationSettings = TimeWavePropagationSettings(),
sources=None,
sensors=None,
u0=None,
diff --git a/pyproject.toml b/pyproject.toml
index c404323..a8a0d65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,9 +108,12 @@ split_before_logical_operator = true
[tool.pytest.ini_options]
addopts = """\
- --doctest-modules \
+ --doctest-modules\
"""
+[tool.pytest_env]
+CUDA_VISIBLE_DEVICES = ""
+
[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',