Skip to content

Commit

Permalink
update grid
Browse files Browse the repository at this point in the history
  • Loading branch information
dotmet committed Sep 14, 2023
1 parent db033a0 commit e8ef2e0
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 6 deletions.
100 changes: 94 additions & 6 deletions mpcmd/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from warnings import WarningMessage
from .tools.make_gsd_snapshot import make_snapshot
from .geometry.grid import Grid3D

from numba import prange, jit, njit

Expand Down Expand Up @@ -465,8 +466,94 @@ def __init__(self, mpcd_sys=None):

self.sorted_data = {'grids':None, 'particles':None, 'vcm':None}

self.Grid3d = Grid3D(grid_length=1, geometry=self.geometry)

def velocity_field(self, plane='xoy', loc=None, grid_length=1, contain_solute=False, transpose=False, show=True):
def velocity_field(self, plane='xoy', dim=None, loc=None, agg_method=np.nanmean,
grid_length=1, contain_solute=False, show=True,
transpose=False, stream_plot=False, **kwargs):
'''
Show the velocity field in given plane.
Parameters
----------
plane: str
The plane to visualize.
dim: array like
The alias of plane which must have length with 3
loc: float
The location of plane.
agg_method: function
The method to aggregate the velocity in each grid.
grid_length: float
The length of grid.
contain_solute: bool
Whether to contain solute particles.
transpose: bool
Whether to transpose the velocity field.
show: bool
Whether to show the velocity field.
stream_plot: bool
Whether to use stream plot to show the velocity field.
kwargs: dict
kwargs for plot.
Returns
-------
(alist, blist): tuple
The locations of grids.
(valist, vblist): tuple
The velocity field.
'''
if plane == 'xoy':
dim = [0, 1, 2]
elif plane == 'yoz':
dim = [1, 2, 0]
elif plane == 'xoz':
dim = [0, 2, 1]

grid = self.Grid3d
if self.Grid3d.grid_length != grid_length:
grid = Grid3D(geometry=self.geometry, grid_length=grid_length)
grid.center_zero()

posi, velo = self.fluid.position, self.fluid.velocity
if contain_solute:
posi = np.vstack([posi, self.solute.position])
velo = np.vstack([velo, self.solute.velocity])

if loc is not None:
select = np.where((posi[:,dim[2]]<loc+grid_length/2) &
(posi[:,dim[2]]>=loc-grid_length/2))
posi = posi[select]
velo = velo[select]

grid_centers, posi_ids = grid.scatter_points(posi, dim=dim[:2])
alist, blist = grid_centers[:,0], grid_centers[:,1]
velo_res = np.array([agg_method(velo[ids], axis=0) for ids in posi_ids])
valist, vblist = velo_res[:,dim[0]], velo_res[:,dim[1]]
if show:
print(f'Show velocity field in {plane} plane ...')
if not stream_plot:
if transpose:
plt.quiver(blist, alist, vblist, valist, **kwargs)
else:
plt.quiver(alist, blist, valist, vblist, **kwargs)
plt.show()
else:
als = np.unique(alist)
bls = np.unique(blist)

x = alist.reshape(-1, len(bls)).T
y = blist.reshape(len(als), -1).T
u = valist.reshape(-1, len(bls)).T
v = vblist.reshape(len(als), -1).T
if transpose:
x, y, u, v = y.T, x.T, v.T, u.T
plt.streamplot(x, y, u, v, **kwargs)
return (alist, blist), (valist, vblist)


def old_velocity_field(self, plane='xoy', loc=None, grid_length=1, contain_solute=False, transpose=False, show=True):
'''
Show the velocity field in given plane.
Expand Down Expand Up @@ -521,7 +608,7 @@ def velocity_field(self, plane='xoy', loc=None, grid_length=1, contain_solute=Fa
else:
return alist, blist, valist, vblist

def velocity_distribution(self, bins=100, axis='x'):
def velocity_distribution(self, axis='x', bins=100):
'''
Show the velocity distribution.
Expand Down Expand Up @@ -549,7 +636,7 @@ def velocity_distribution(self, bins=100, axis='x'):
plt.hist(vs, bins)
plt.show()

def velocity_profile(self, plane='xoz', plane_loc=None, loc_cross=0.0, grid_length=1, dim=1):
def velocity_profile(self, plane='xoz', plane_loc=None, agg_method=np.nansum, loc_cross=0.0, grid_length=1, dim=1):
'''
Get the velocity profile in given plane.
Expand All @@ -573,17 +660,18 @@ def velocity_profile(self, plane='xoz', plane_loc=None, loc_cross=0.0, grid_leng
vs: numpy.ndarray
The velocity profile.
'''
alist, blist, valist, vblist=self.velocity_field(plane, plane_loc, grid_length, show=False)
(alist, blist), (valist, vblist)=self.velocity_field(plane=plane, loc=plane_loc, agg_method=agg_method,
grid_length=grid_length, show=False)
als, nas = np.unique(alist, return_counts=True)
bls, nbs = np.unique(blist, return_counts=True)
if dim==0:
locs = bls
vs = valist.reshape(nas[0], nbs[0])
vs = np.mean(vs, axis=1)
vs = np.nanmean(vs, axis=1)
elif dim==1:
locs = als
vs = vblist.reshape(nbs[0], nas[0])
vs = np.mean(vs, axis=1)
vs = np.nanmean(vs, axis=1)
plt.plot(locs, vs)
plt.show()
return locs, vs
Expand Down
161 changes: 161 additions & 0 deletions mpcmd/geometry/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from sklearn.neighbors import KDTree
import numpy as np
import copy

class Grid3D:

box = None
grid_length = None
grid_size = None
grid_num = None
xs = None
ys = None
zs = None
kdtree = None
kdtreex = None
kdtreey = None
kdtreez = None
center = (0, 0, 0)

box_err_msg = ' is not a valid box, the supported shape is (3,) or (6,) or (3,2):\
\n\t(3,) -> [xl, yl, zl] \n\t(6,) -> [xlo, xhi, ylo, yhi, zlo, zhi]\
\n\t(3,2) -> [[xlo, xhi], [ylo, yhi], [zlo, zhi]]. And lo < hi.'

def __init__(self, box=None, grid_length=1, geometry=None, center=(0,0,0)):

origin_box = copy.deepcopy(box)
if box is not None:
if isinstance(box, list) or isinstance(box, np.ndarray):
box = np.array(box)
if len(box) != 3 or len(box) != 6:
raise ValueError(f'{origin_box}'+self.box_err_msg)
elif len(box) == 3:
if box.ndim==2:
self.box = box
elif box.ndim==1:
self.box = np.array([0, 0, 0, box[0], box[1], box[2]])
else:
raise ValueError(f'{origin_box}'+self.box_err_msg)
elif len(box) == 6:
self.box = box.reshape(3,2)
if np.any(self.box[:,0] > self.box[:,1]):
raise ValueError(f'{origin_box}'+self.box_err_msg)
else:
raise TypeError(f'{origin_box}'+self.box_err_msg)
elif geometry is not None:
try:
self.box = np.array(geometry.bounding_box).reshape(3,2) + np.tile([-1, 1], 3).reshape(3,2)
center = geometry.shift_vec
except AttributeError:
raise AttributeError('geometry must have attribute bounding_box')
else:
raise ValueError('box or geometry must be provided')

self.box = self.box - np.mean(self.box, axis=1).reshape(3,1) + \
np.array(center).reshape(3,1)
self.center = np.array(center)
self.grid_length = grid_length
self.grid_size = np.ceil((self.box[:,1] - self.box[:,0]) / grid_length).astype(int)
self.grid_num = np.prod(self.grid_size)
self.xs = np.linspace(self.box[0,0], self.box[0,1], self.grid_size[0]+1)
self.ys = np.linspace(self.box[1,0], self.box[1,1], self.grid_size[1]+1)
self.zs = np.linspace(self.box[2,0], self.box[2,1], self.grid_size[2]+1)

def __add__(self, vector):
if not isinstance(vector, np.ndarray) and not isinstance(vector, list):
raise TypeError('vector must be a list or numpy.ndarray')
elif len(vector) != 3:
raise ValueError('vector must be a 3D vector')
return Grid3D(box=self.box, grid_length=self.grid_length, center=self.center+vector)

def __get_centers(self):
xcs = (self.xs[:-1] + self.xs[1:]) / 2
ycs = (self.ys[:-1] + self.ys[1:]) / 2
zcs = (self.zs[:-1] + self.zs[1:]) / 2
return xcs, ycs, zcs

def __parse_dim(self, dim):
if dim is None:
return np.arange(3)
elif isinstance(dim, int):
dim = [dim]
elif not isinstance(dim, list) and not isinstance(dim, np.ndarray):
raise TypeError('dim must be a list or numpy.ndarray or an integer')
if not all([d in [0, 1, 2] for d in dim]):
raise ValueError('dim must be 0, 1 or 2, or a list (array) of them')
else:
return np.sort(dim)

def get_grid_centers(self, dim=None):
xcs, ycs, zcs = self.__get_centers()
dim = self.__parse_dim(dim)
if len(dim) == 1:
return np.array([xcs, ycs, zcs][dim]).reshape(-1,1)
elif len(dim) == 2:
c1s, c2s = np.array([xcs, ycs, zcs])[dim]
return np.vstack([np.repeat(c1s, len(c2s)), np.tile(c2s, len(c1s))]).T
elif len(dim) == 3:
gxcs = np.repeat(xcs, len(ycs)*len(zcs))
gycs = np.tile(np.repeat(ycs, len(zcs)), len(xcs))
gzcs = np.tile(zcs, len(xcs)*len(ycs))
return np.vstack([gxcs, gycs, gzcs]).T

def get_grid_bounds(self, dim=None):
dim = self.__parse_dim(dim)
xl, yl, zl = self.xs[:-1], self.ys[:-1], self.zs[:-1]
xh, yh, zh = self.xs[1:], self.ys[1:], self.zs[1:]
if len(dim) == 1:
return np.array([[xl, xh], [yl, yh], [zl, zh]][dim]).T
elif len(dim) == 2:
xls, xhs = np.array([[xl, xh], [yl, yh], [zl, zh]][dim[0]])
yls, yhs = np.array([[xl, xh], [yl, yh], [zl, zh]][dim[1]])
return np.vstack([np.repeat(xls, len(yls)), np.repeat(xhs, len(yhs)), \
np.tile(yls, len(xls)), np.tile(yhs, len(xhs))]).T
elif len(dim) == 3:
xls = np.repeat(xl, len(yl)*len(zl))
yls = np.tile(np.repeat(yl, len(zl)), len(xl))
zls = np.tile(zl, len(xl)*len(yl))
xhs = np.repeat(xh, len(yh)*len(zh))
yhs = np.tile(np.repeat(yh, len(zh)), len(xh))
zhs = np.tile(zh, len(xh)*len(yh))
return np.vstack([xls, xhs, yls, yhs, zls, zhs]).T

def gen_kdtree(self, dim=None):
self.kdtree = KDTree(self.get_grid_centers(dim=dim), metric='chebyshev')
return self.kdtree

def get_grid_index(self, pos):
if self.kdtree is None:
self.gen_kdtree()
return self.kdtree.query(pos, return_distance=False).flatten()

def shift(self, vector, inplace=True):
if not inplace:
return Grid3D(box=self.box, grid_length=self.grid_length, center=self.center+vector)
self.center += np.array(vector)
self.box = self.box + np.array(vector).reshape(3,1)
self.xs = self.xs + vector[0]
self.ys = self.ys + vector[1]
self.zs = self.zs + vector[2]
# self.gen_kdtree()

def kdtree_1d(self, dim=0):
return self.gen_kdtree(dim)

def kdtree_2d(self, no_dim=0):
dim = [0, 1, 2]
dim.remove(no_dim)
return self.gen_kdtree(dim)

def center_zero(self):
self.shift(-self.center)

def scatter_points(self, posi, dim=[0, 1, 2]):
dim = self.__parse_dim(dim)
tree = None
if len(dim) == 1:
tree = KDTree(posi[:,dim].reshape(-1,1), metric='chebyshev')
elif len(dim) >= 2:
tree = KDTree(posi[:,dim], metric='chebyshev')
gcenters = self.get_grid_centers(dim=dim)
return gcenters, tree.query_radius(gcenters, return_distance=False, r=self.grid_length/2)

0 comments on commit e8ef2e0

Please sign in to comment.