Skip to content

Commit

Permalink
Merge pull request #954 from angus-g/curvilinear-index-search
Browse files Browse the repository at this point in the history
Add a method to pre-seed curvilinear indices using kdtree
  • Loading branch information
erikvansebille authored Jan 28, 2021
2 parents 6c83aa0 + a3def9a commit 400abb4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions environment_py3_osx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ dependencies:
- pytest
- nbval
- scikit-learn
- pykdtree
1 change: 1 addition & 0 deletions environment_py3_win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ dependencies:
- ipykernel<5.0
- pytest
- nbval
- pykdtree
1 change: 1 addition & 0 deletions environment_py3p6_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ dependencies:
- pytest
- nbval
- scikit-learn
- pykdtree
36 changes: 36 additions & 0 deletions parcels/particlesets/particlesetsoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import xarray as xr

from parcels.grid import GridCode
from parcels.grid import CurvilinearGrid
from parcels.kernel import Kernel
from parcels.particle import JITParticle
from parcels.particlefile import ParticleFile
Expand All @@ -20,6 +21,11 @@
from mpi4py import MPI
except:
MPI = None
# == comment CK: prevents us from adding KDTree as 'mandatory' dependency == #
try:
from pykdtree.kdtree import KDTree
except:
KDTree = None

__all__ = ['ParticleSet', 'ParticleSetSOA']

Expand Down Expand Up @@ -187,6 +193,36 @@ def indexed_subset(self, indices):
return ParticleCollectionIteratorSOA(self._collection,
subset=indices)

def populate_indices(self):
"""Pre-populate guesses of particle xi/yi indices using a kdtree.
This is only intended for curvilinear grids, where the initial index search
may be quite expensive.
"""

if self.fieldset is None:
# we need to be attached to a fieldset to have a valid
# gridset to search for indices
return

if KDTree is None:
return
else:
for i, grid in enumerate(self.fieldset.gridset.grids):
if not isinstance(grid, CurvilinearGrid):
continue

tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1)
tree = KDTree(tree_data)
# stack all the particle positions for a single query
pts = np.stack((self._collection.data['lon'], self._collection.data['lat']), axis=-1)
# query datatype needs to match tree datatype
_, idx = tree.query(pts.astype(tree_data.dtype))
yi, xi = np.unravel_index(idx, grid.lon.shape)

self._collection.data['xi'][:, i] = xi
self._collection.data['yi'][:, i] = yi

@property
def error_particles(self):
"""Get an iterator over all particles that are in an error state.
Expand Down

0 comments on commit 400abb4

Please sign in to comment.