-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmpi.py
74 lines (60 loc) · 2.46 KB
/
mpi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy
from mpi4py import MPI
class InterComm( object ):
    '''Generic wrapper around an MPI intercommunicator to spawned child processes.

    This side always acts as the root of intercommunicator collectives
    (``root=MPI.ROOT``). All buffers are transferred as raw bytes
    (``MPI.BYTE``), so the remote processes must post matching dtypes and
    element counts.
    '''

    def __init__( self, cmd, **kwargs ):
        '''Spawn ``cmd`` via ``MPI.COMM_SELF.Spawn`` and keep the intercommunicator.

        Extra keyword arguments are forwarded to ``Spawn`` unchanged.
        '''
        self.__comm = MPI.COMM_SELF.Spawn( cmd, **kwargs )
        # number of processes in the remote (spawned) group
        self.nprocs = self.__comm.remote_size

    @staticmethod
    def _byte_layout( arrays ):
        '''Return ``(counts, displs)`` for packing ``arrays`` back to back.

        ``counts[i]`` is the byte size of ``arrays[i]`` and ``displs[i]`` its byte
        offset in the packed buffer. Both are lists of Python ints. This avoids
        the float64 offsets that ``numpy.cumsum`` of an empty list would produce
        for a single array.
        '''
        counts = [ int( array.nbytes ) for array in arrays ]
        displs = []
        offset = 0
        for count in counts:
            displs.append( offset )
            offset += count
        return counts, displs

    def isconnected( self ):
        '''Return True while the intercommunicator handle is not null.'''
        return bool( self.__comm )

    def bcast( self, data, dtype=None ):
        '''Broadcast ``data`` from this (root) side to all remote processes.

        If ``dtype`` is given, ``data`` is first converted to a contiguous array
        of that dtype. Otherwise ``data`` must already be a numpy array.
        '''
        if dtype is not None:
            data = numpy.ascontiguousarray( data, dtype )
        else:
            assert isinstance( data, numpy.ndarray )
            # the byte buffer handed to MPI must be contiguous
            data = numpy.ascontiguousarray( data )
        self.__comm.Bcast( [ data, MPI.BYTE ], root=MPI.ROOT )

    def gather( self, dtype ):
        '''Receive one ``dtype`` element from each remote process, in rank order.'''
        array = numpy.empty( self.nprocs, dtype=dtype )
        self.__comm.Gather( None, [ array, MPI.BYTE ], root=MPI.ROOT )
        return array

    def gather_equal( self, dtype ):
        '''Gather one element per remote process, assert all agree, return it.'''
        array = self.gather( dtype )
        assert numpy.all( array[1:] == array[0] )
        return array[0]

    def scatter( self, array, dtype ):
        '''Send ``array`` converted to ``dtype``, split in equal byte chunks over ranks.'''
        array = numpy.ascontiguousarray( array, dtype=dtype )
        self.__comm.Scatter( [ array, MPI.BYTE ], None, root=MPI.ROOT )

    def scatterv( self, arrays, dtype ):
        '''Send ``arrays[i]`` (converted to ``dtype``) to remote process ``i``.'''
        arrays = [ numpy.asarray( array, dtype=dtype ) for array in arrays ]
        counts, displs = self._byte_layout( arrays )
        # ravel so arrays of any shape/rank can be packed; the bytes come out in
        # C order, which matches each array's nbytes counted above
        data = numpy.concatenate( [ array.ravel() for array in arrays ] )
        self.__comm.Scatterv( [ data, counts, displs, MPI.BYTE ], None, root=MPI.ROOT )

    def gatherv( self, lengths, dtype ):
        '''Receive ``lengths[i]`` elements of ``dtype`` from remote process ``i``.

        Returns a list of per-rank arrays. They are views into one contiguous
        receive buffer.
        '''
        data = numpy.empty( sum( lengths ), dtype=dtype )
        arrays = [ data[end-n:end] for end, n in zip( numpy.cumsum( lengths ), lengths ) ]
        counts, displs = self._byte_layout( arrays )
        self.__comm.Gatherv( None, [ data, counts, displs, MPI.BYTE ], root=MPI.ROOT )
        return arrays

    def send( self, rank, array, dtype ):
        '''Point-to-point send of ``array`` (as ``dtype``) to remote ``rank`` with tag 0.'''
        array = numpy.ascontiguousarray( array, dtype )
        self.__comm.Send( [ array, MPI.BYTE ], rank, tag=0 )

    def verify( self ):
        '''Collect error messages reported by the remote processes.

        Currently disabled: always raises ``NotImplementedError``. The code after
        the ``raise`` is unreachable. It is kept as a record of the intended
        protocol: per-rank message lengths are gathered as ``int``, then each
        non-empty message is received from its rank with tag 10.
        '''
        raise NotImplementedError
        strlen = self.gather( int )
        good = (strlen == 0)
        if good.all():
            return 0
        msgs = [ 'In libmatrix: %d errors occurred:' % (~good).sum() ]
        for i, s in enumerate( strlen ):
            if s:
                x = numpy.empty( s, dtype='c' )
                self.__comm.Recv( [x,MPI.CHAR], i, tag=10 )
                msgs.append( '[%d] %s' % ( i, x.tobytes().decode( errors='replace' ) ) )
        raise Exception( '\n '.join( msgs ) )

    def disconnect( self ):
        '''Disconnect from the remote processes if still connected.'''
        if self.isconnected():
            self.__comm.Disconnect()

    def __del__( self ):
        try:
            connected = self.isconnected()
        except AttributeError:
            # __init__ raised before the communicator was stored; nothing to release
            return
        # MPI calls after MPI_Finalize (e.g. during interpreter shutdown) are erroneous
        if connected and not MPI.Is_finalized():
            self.disconnect()