Skip to content

Commit

Permalink
Change 'get' to 'make'
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 1, 2024
1 parent 56c6133 commit 50ae48e
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 15 deletions.
10 changes: 5 additions & 5 deletions exponax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@
from ._spectral import derivative
from ._utils import (
build_ic_set,
get_grid,
make_grid,
repeat,
rollout,
stack_sub_trajectories,
wrap_bc,
)
from ._viz import get_animation, get_grouped_animation
from ._viz import make_animation, make_grouped_animation

__all__ = [
"ForcedStepper",
"normalized",
"poisson",
"RepeatedStepper",
"derivative",
"get_grid",
"get_animation",
"get_grouped_animation",
"make_grid",
"make_animation",
"make_grouped_animation",
"rollout",
"repeat",
"stack_sub_trajectories",
Expand Down
2 changes: 1 addition & 1 deletion exponax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jaxtyping import Array, Float, PRNGKeyArray, PyTree


def get_grid(
def make_grid(
num_spatial_dims: int,
domain_extent: float,
num_points: int,
Expand Down
4 changes: 2 additions & 2 deletions exponax/_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from matplotlib.animation import FuncAnimation


def get_animation(trj, *, vlim=(-1, 1)):
def make_animation(trj, *, vlim=(-1, 1)):
fig, ax = plt.subplots()
im = ax.imshow(
trj[0].squeeze().T, vmin=vlim[0], vmax=vlim[1], cmap="RdBu_r", origin="lower"
Expand All @@ -26,7 +26,7 @@ def animate(i):
return ani


def get_grouped_animation(
def make_grouped_animation(
trj, *, vlim=(-1, 1), grid=(3, 3), figsize=(10, 10), titles=None
):
"""
Expand Down
4 changes: 2 additions & 2 deletions exponax/ic/_base_ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import equinox as eqx
from jaxtyping import Array, Float, PRNGKeyArray

from .._utils import get_grid
from .._utils import make_grid


class BaseIC(eqx.Module, ABC):
Expand Down Expand Up @@ -60,7 +60,7 @@ def __call__(
- `u`: The initial condition evaluated at the grid points.
"""
ic_fun = self.gen_ic_fun(num_points, key=key)
grid = get_grid(
grid = make_grid(
self.num_spatial_dims,
self.domain_extent,
num_points,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_builtin_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_nonlinear_normalized_stepper():
diffusivity = 0.1
convection_scale = 1.0

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = jnp.sin(2 * jnp.pi * grid / domain_extent) + 0.3

regular_burgers_stepper = ex.stepper.Burgers(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_wrap_bc(num_spatial_dims):
domain_extent = 3.0
num_points = 10

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
full_grid = ex.get_grid(num_spatial_dims, domain_extent, num_points, full=True)
grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
full_grid = ex.make_grid(num_spatial_dims, domain_extent, num_points, full=True)

u = jnp.sin(2 * jnp.pi * grid[0:1] / domain_extent)
full_u = jnp.sin(2 * jnp.pi * full_grid[0:1] / domain_extent)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_advection_1d():
4 * 2 * jnp.pi * (x - velocity * t) / domain_extent
)

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

Expand Down Expand Up @@ -49,7 +49,7 @@ def test_diffusion_1d():
-((4 * 2 * jnp.pi / domain_extent) ** 2) * diffusivity * t
) * jnp.sin(4 * 2 * jnp.pi * x / domain_extent)

grid = ex.get_grid(num_spatial_dims, domain_extent, num_points)
grid = ex.make_grid(num_spatial_dims, domain_extent, num_points)
u_0 = analytical_solution(0.0, grid)
u_1 = analytical_solution(dt, grid)

Expand Down

0 comments on commit 50ae48e

Please sign in to comment.