Skip to content

Commit

Permalink
refactor: np.where(cond) -> np.asarray(cond).nonzero() (#2238)
Browse files Browse the repository at this point in the history
The docs for `np.where()` (https://numpy.org/doc/stable/reference/generated/numpy.where.html) suggest to prefer `nonzero()` over `where()` without `x` and `y` arguments. In the spirit of defensive programming I included `np.asarray(cond)` even where `cond` is already an array.

This PR also fixes a bug I introduced in the model splitter in #2124: while E711 (https://www.flake8rules.com/rules/E711.html) dictates comparisons to `None` should use identity rather than equality, this rule should not be applied to NumPy array selection conditions as it will change the semantics:

>>> a = np.array([None, None])
>>> a[a != None]
array([], dtype=object)
>>> a[a is not None]
array([[None, None]], dtype=object)

Unrelatedly, mark `test_mt3d.py::test_mfnwt_keat_uzf()` slow, it should not be included in smoke tests (this was causing the optional dependency CI tests to fail due to timeout). And clean up some unused imports in `conftest.py`.
  • Loading branch information
wpbonelli authored Jun 17, 2024
1 parent 18dfcb0 commit 59040d0
Show file tree
Hide file tree
Showing 35 changed files with 235 additions and 208 deletions.
4 changes: 1 addition & 3 deletions autotest/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import re
from importlib import metadata
from io import BytesIO, StringIO
from pathlib import Path
from platform import system
from typing import List, Optional
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pytest
from modflow_devtools.misc import is_in_ci

Expand Down
6 changes: 3 additions & 3 deletions autotest/test_lake_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_lake(function_tmpdir, example_data_path):
# mm.plot_array(bot_tm)

# determine a reasonable lake bottom
idx = np.where(lakes > -1)
idx = np.asarray(lakes > -1).nonzero()
lak_bot = bot_tm[idx].max() + 2.0

# interpolate top elevations
Expand Down Expand Up @@ -634,9 +634,9 @@ def test_embedded_lak_prudic_mixed(example_data_path):
lake_map[0, :, :] = lakibd[:, :] - 1

lakebed_leakance = np.zeros(shape2d, dtype=object)
idx = np.where(lake_map[0, :, :] == 0)
idx = np.asarray(lake_map[0, :, :] == 0).nonzero()
lakebed_leakance[idx] = "none"
idx = np.where(lake_map[0, :, :] == 1)
idx = np.asarray(lake_map[0, :, :] == 1).nonzero()
lakebed_leakance[idx] = 1.0
lakebed_leakance = lakebed_leakance.tolist()

Expand Down
1 change: 1 addition & 0 deletions autotest/test_mt3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def test_mf2000_zeroth(function_tmpdir, mf2kmt3d_model_path):
assert success, f"{mt.name} did not run"


@pytest.mark.slow
@flaky(max_runs=3)
@requires_exe("mfnwt", "mt3dms")
@excludes_platform(
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_sfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def interpolate_to_reaches(sfr):
sfr.get_slopes(minimum_slope=-100, maximum_slope=100)
reach_inds = 29
outreach = sfr.reach_data.outreach[reach_inds]
out_inds = np.where(sfr.reach_data.reachID == outreach)
out_inds = np.asarray(sfr.reach_data.reachID == outreach).nonzero()
assert (
sfr.reach_data.slope[reach_inds]
== (
Expand Down
4 changes: 2 additions & 2 deletions autotest/test_zonbud_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def test_compare2zonebudget(cbc_f, zon_f, zbud_f, rtol):
zb_arr = zba[zba["totim"] == time]
fp_arr = fpa[fpa["totim"] == time]
for name in fp_arr["name"]:
r1 = np.where(zb_arr["name"] == name)
r2 = np.where(fp_arr["name"] == name)
r1 = np.asarray(zb_arr["name"] == name).nonzero()
r2 = np.asarray(fp_arr["name"] == name).nonzero()
if r1[0].shape[0] < 1 or r2[0].shape[0] < 1:
continue
if r1[0].shape[0] != r2[0].shape[0]:
Expand Down
6 changes: 3 additions & 3 deletions flopy/discretization/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,15 @@ def saturated_thickness(self, array, mask=None):
bot = self.remove_confining_beds(bot)
array = self.remove_confining_beds(array)

idx = np.where((array < top) & (array > bot))
idx = np.asarray((array < top) & (array > bot)).nonzero()
thickness[idx] = array[idx] - bot[idx]
idx = np.where(array <= bot)
idx = np.asarray(array <= bot).nonzero()
thickness[idx] = 0.0
if mask is not None:
if isinstance(mask, (float, int)):
mask = [float(mask)]
for mask_value in mask:
thickness[np.where(array == mask_value)] = np.nan
thickness[np.asarray(array == mask_value).nonzero()] = np.nan
return thickness

def saturated_thick(self, array, mask=None):
Expand Down
4 changes: 2 additions & 2 deletions flopy/discretization/structuredgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
"x, y point given is outside of the model area"
)
else:
col = np.where(xcomp)[0][-1]
col = np.asarray(xcomp).nonzero()[0][-1]

ycomp = y < ye
if np.all(ycomp) or not np.any(ycomp):
Expand All @@ -941,7 +941,7 @@ def intersect(self, x, y, z=None, local=False, forgive=False):
"x, y point given is outside of the model area"
)
else:
row = np.where(ycomp)[0][-1]
row = np.asarray(ycomp).nonzero()[0][-1]
if np.any(np.isnan([row, col])):
row = col = np.nan
if z is not None:
Expand Down
2 changes: 1 addition & 1 deletion flopy/export/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def difference(

d_data[np.isnan(d_data)] = FILLVALUE
if mask_zero_diff:
d_data[np.where(d_data == 0.0)] = FILLVALUE
d_data[np.asarray(d_data == 0.0).nonzero()] = FILLVALUE

var = new_net.create_variable(
vname, attrs, s_var.dtype, dimensions=s_var.dimensions
Expand Down
2 changes: 1 addition & 1 deletion flopy/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _add_output_nc_variable(
logger.log(f"creating array for {var_name}")

for mask_val in mask_vals:
array[np.where(array == mask_val)] = np.nan
array[np.asarray(array == mask_val).nonzero()] = np.nan
mx, mn = np.nanmax(array), np.nanmin(array)
array[np.isnan(array)] = netcdf.FILLVALUE

Expand Down
6 changes: 4 additions & 2 deletions flopy/export/vtk.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,9 @@ def _build_hfbs(self, pkg):

pts = []
for v in v1:
ix = np.where((v2.T[0] == v[0]) & (v2.T[1] == v[1]))
ix = np.asarray(
(v2.T[0] == v[0]) & (v2.T[1] == v[1])
).nonzero()
if len(ix[0]) > 0 and len(pts) < 2:
pts.append(v2[ix[0][0]])

Expand Down Expand Up @@ -652,7 +654,7 @@ def _build_point_scalar_array(self, array):
ps_array[pt] = array[value["idx"][ix]]
else:
ps_graph = self._point_scalar_numpy_graph.copy()
idxs = np.where(np.isnan(array))
idxs = np.asarray(np.isnan(array)).nonzero()
not_graphed = np.isin(ps_graph, idxs[0])
ps_graph[not_graphed] = -1
ps_array = np.where(ps_graph >= 0, array[ps_graph], np.nan)
Expand Down
6 changes: 4 additions & 2 deletions flopy/mf6/utils/lakpak_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_lak_connections(modelgrid, lake_map, idomain=None, bedleak=None):
unique = np.unique(lake_map)

# exclude lakes with lake numbers less than 0
idx = np.where(unique > -1)
idx = np.asarray(unique > -1).nonzero()
unique = unique[idx]

dx, dy = None, None
Expand Down Expand Up @@ -199,7 +199,9 @@ def get_lak_connections(modelgrid, lake_map, idomain=None, bedleak=None):

# reset idomain for lake
if iconn > 0:
idx = np.where((lake_map == lake_number) & (idomain > 0))
idx = np.asarray(
(lake_map == lake_number) & (idomain > 0)
).nonzero()
idomain[idx] = 0

return idomain, connection_dict, connectiondata
Expand Down
Loading

0 comments on commit 59040d0

Please sign in to comment.