diff --git a/jwave/__init__.py b/jwave/__init__.py index 2c2d8f0..72fdd6c 100755 --- a/jwave/__init__.py +++ b/jwave/__init__.py @@ -14,9 +14,53 @@ # License along with j-Wave. If not, see . # nopycln: file -from jaxdf.discretization import * +from jaxdf import ( + operator, + Continuous, + Domain, + FiniteDifferences, + FourierSeries, + Field, + Linear, + OnGrid +) + +from .acoustics import ( + angular_spectrum, + born_iteration, + born_series, + db2neper, + helmholtz_solver_verbose, + helmholtz_solver, + helmholtz, + homogeneous_helmholtz_green, + laplacian_with_pml, + mass_conservation_rhs, + momentum_conservation_rhs, + pml, + pressure_from_density, + rayleigh_integral, + scale_source_helmholtz, + scattering_potential, + simulate_wave_propagation, + spectral, + wave_propagation_symplectic_step, + wavevector, + TimeWavePropagationSettings, +) +from .geometry import ( + BLISensors, + DistributedTransducer, + Medium, + Sensors, + Sources, + TimeAxis, + TimeHarmonicSource, +) from jwave import acoustics as ac from jwave import geometry as geometry +from jwave import logger as logger +from jwave import phantoms as phantoms from jwave import signal_processing as signal_processing from jwave import utils as utils diff --git a/jwave/acoustics/__init__.py b/jwave/acoustics/__init__.py index 60d9fd6..736873e 100644 --- a/jwave/acoustics/__init__.py +++ b/jwave/acoustics/__init__.py @@ -14,6 +14,31 @@ # License along with j-Wave. If not, see . # nopycln: file -from .operators import * -from .time_harmonic import * -from .time_varying import * +from .conversion import db2neper +from .operators import ( + helmholtz, + laplacian_with_pml, + scale_source_helmholtz, + wavevector, +) +from .time_harmonic import ( + angular_spectrum, + born_iteration, + born_series, + helmholtz_solver, + helmholtz_solver_verbose, + homogeneous_helmholtz_green, + rayleigh_integral, + scattering_potential +) +from .time_varying import ( + mass_conservation_rhs, + momentum_conservation_rhs, + pressure_from_density, + simulate_wave_propagation, + wave_propagation_symplectic_step, + TimeWavePropagationSettings, +) + +from . import spectral +from . import pml \ No newline at end of file diff --git a/jwave/acoustics/time_harmonic.py b/jwave/acoustics/time_harmonic.py index 47845c2..ca567e0 100644 --- a/jwave/acoustics/time_harmonic.py +++ b/jwave/acoustics/time_harmonic.py @@ -342,7 +342,7 @@ def body_fun(carry): out_field = _cbs_unnorm_units(out_field, _conversion) - return out_field, None + return out_field @operator @@ -377,7 +377,7 @@ def born_iteration(field: Field, G = homogeneous_helmholtz_green(V1 + src, k0=k0, epsilon=epsilon) V2 = scattering_potential(field - G, k_sq, k0=k0, epsilon=epsilon) - return field - (1j / epsilon) * V2, params + return field - (1j / epsilon) * V2 @operator @@ -401,7 +401,7 @@ def scattering_potential(field: Field, k = k_sq - k0**2 - 1j * epsilon out = field * k - return out, params + return out @operator @@ -430,7 +430,7 @@ def homogeneous_helmholtz_green(field: FourierSeries, u_fft = jnp.fft.fftn(u) Gu_fft = g_fourier * u_fft Gu = jnp.fft.ifftn(Gu_fft) - return field.replace_params(Gu), params + return field.replace_params(Gu) @operator @@ -500,7 +500,7 @@ def direc_exp_term(x, y, z): # Weights of the Rayleigh integral weights = jax.vmap(jax.vmap(direc_exp_term, in_axes=(0, 0, 0)), in_axes=(0, 0, 0))(R[..., 0], R[..., 1], R[..., 2]) - return jnp.sum(weights * pressure.on_grid) * area, None + return jnp.sum(weights * pressure.on_grid) * area @operator @@ -560,7 +560,7 @@ def helm_func(u): )[0] elif method == "bicgstab": out = bicgstab(helm_func, source, guess, tol=tol, maxiter=maxiter)[0] - return -1j * omega * out, None + return -1j * omega * out def helmholtz_solver_verbose( diff --git a/jwave/acoustics/time_varying.py b/jwave/acoustics/time_varying.py index 3477f10..e008b8c 100755 --- a/jwave/acoustics/time_varying.py +++ b/jwave/acoustics/time_varying.py @@ -82,8 +82,6 @@ def __init__( self.smooth_initial = smooth_initial -default_time_wave_prop_settings = TimeWavePropagationSettings() - def _shift_rho(rho0, direction, dx): if isinstance(rho0, OnGrid): @@ -382,7 +380,7 @@ def simulate_wave_propagation( medium: Medium[OnGrid], time_axis: TimeAxis, *, - settings: TimeWavePropagationSettings = default_time_wave_prop_settings, + settings: TimeWavePropagationSettings = TimeWavePropagationSettings(), sources=None, sensors=None, u0=None, @@ -533,7 +531,7 @@ def simulate_wave_propagation( medium: Medium[FourierSeries], time_axis: TimeAxis, *, - settings: TimeWavePropagationSettings = default_time_wave_prop_settings, + settings: TimeWavePropagationSettings = TimeWavePropagationSettings(), sources=None, sensors=None, u0=None, diff --git a/pyproject.toml b/pyproject.toml index c404323..a8a0d65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,9 +108,12 @@ split_before_logical_operator = true [tool.pytest.ini_options] addopts = """\ - --doctest-modules \ + --doctest-modules\ """ +[tool.pytest_env] +CUDA_VISIBLE_DEVICES = "" + [tool.coverage.report] exclude_lines = [ 'if TYPE_CHECKING:',