Skip to content

Commit

Permalink
refactoring ray_runner and let it create ray ObjRefs in __partitioned__
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Jul 22, 2021
1 parent 993458b commit 24ca13f
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 121 deletions.
151 changes: 84 additions & 67 deletions heat/cw4heat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,26 @@
impl_str = "impl"
dndarray_str = "impl.DNDarray"

_distributor = None
_comm = None
_fini = None
_runner = None


class _partRef:
    """
    Handle used in __partitioned__. Identifies one chunk of a distributed array.

    Parameters
    ----------
    id_ :
        Identifier of the array the chunk belongs to; stored as ``id``.
    rank_ :
        Rank associated with the chunk; stored as ``rank``.
    """

    def __init__(self, id_, rank_):
        self.id = id_
        self.rank = rank_

    def __repr__(self):
        # Readable output when a partition dictionary is printed or logged,
        # instead of the default "<... object at 0x...>".
        return f"{type(self).__name__}(id={self.id!r}, rank={self.rank!r})"


def _getPartForRef(pref):
    """
    Return actual partition data for given _partRef.

    The data is fetched by asking the distributor of the module-level
    ``_runner`` for the ``"larray"`` attribute of the referenced partition.

    Parameters
    ----------
    pref :
        Partition handle; passed unchanged to ``distributor.getPart``.
    """
    # FIXME Ray
    return _runner.distributor.getPart(pref, "larray")


def _setComm(c):
Expand All @@ -89,28 +106,38 @@ def init(doStart=True, ctxt=False):
For now we assume all ranks (controller and workers) are started through mpirun,
workers will never leave distributor.start() and so this function.
"""
global _distributor
global _comm
global _fini
global _runner

if _distributor is not None:
if _runner is not None:
return

_launcher = getenv("CW4H_LAUNCHER", default="mpi").lower()

# atexit.register(fini)
if _launcher == "ray":
assert ctxt is False, "Controller-worker context is useless with ray launcher."
from .ray_runner import init as ray_init, fini as ray_fini
from .ray_runner import init as ray_init

_comm, _distributor, _futures = ray_init(_setComm)
_distributor.start(initImpl=_setComm)
_fini = ray_fini
_runner = ray_init(_setComm)
_runner.distributor.start(initImpl=_setComm)
elif _launcher == "mpi":
_comm = MPI.COMM_WORLD
_distributor = Distributor(_comm)

class MPIRunner:
    """
    Runner for the MPI launcher: bundles the communicator, the distributor
    and the callables used to build the ``__partitioned__`` dictionary.
    """

    def __init__(self, dist, comm):
        self.comm = comm
        self.distributor = dist
        # Getter that resolves a _partRef into the actual partition data.
        self.get = _getPartForRef

    def publish(self, id, distributor):
        """
        Return one ``(rank, _partRef(id, rank))`` pair per rank of the communicator.

        ``distributor`` is not used here; it is accepted so the call
        signature stays the same as before.
        """
        return [(i, _partRef(id, i)) for i in range(self.comm.size)]

    def fini(self):
        """
        Nothing to shut down for this runner.
        """
        pass

c = MPI.COMM_WORLD
_runner = MPIRunner(Distributor(c), c)
if doStart:
_distributor.start(initImpl=_setComm)
_runner.distributor.start(initImpl=_setComm)
else:
raise Exception(f"unknown launcher {_launcher}. CW4H_LAUNCHER must be 'mpi', or 'ray'.")

Expand All @@ -120,9 +147,10 @@ def fini():
Finalize/shutdown distribution engine. Automatically called at exit.
When called on controller, workers will sys.exit from init().
"""
_distributor.fini()
if _fini:
_fini()
global _runner
_runner.distributor.fini()
if _runner:
_runner.fini()


class cw4h:
Expand All @@ -143,7 +171,7 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
if _comm.rank == 0:
if _runner.comm.rank == 0:
fini()

def controller(self):
Expand All @@ -152,10 +180,10 @@ def controller(self):
the code block protected as controller.
Non-root workers will not finish until self gets deleted.
"""
if _comm.rank == 0:
if _runner.comm.rank == 0:
return True
else:
_distributor.start(doExit=False, initImpl=_setComm)
_runner.distributor.start(doExit=False, initImpl=_setComm)
return False


Expand Down Expand Up @@ -195,7 +223,9 @@ def _submit(name, args, kwargs, unwrap="*", numout=1):
"""
scalar_args = tuple(x for x in args if not isinstance(x, DDParray))
deps = [x._handle.getId() for x in args if isinstance(x, DDParray)]
return _distributor.submitPP(_Task(name, scalar_args, kwargs, unwrap=unwrap), deps, numout)
return _runner.distributor.submitPP(
_Task(name, scalar_args, kwargs, unwrap=unwrap), deps, numout
)


def _submitProperty(name, self):
Expand All @@ -204,7 +234,7 @@ def _submitProperty(name, self):
"""
t = _PropertyTask(name)
try:
res = _distributor.submitPP(t, [self._handle.getId()])
res = _runner.distributor.submitPP(t, [self._handle.getId()])
except Exception:
assert False
return res
Expand All @@ -216,14 +246,6 @@ def _setitem_normalized(self, value, key):
self.__setitem__(key, value)


def _getPartForRef(pref):
"""
Return actual partition data for given partRef.
"""
# FIXME Ray
return _distributor.getPart(pref, "larray")


#######################################################################
# Our array is just a wrapper. Actual array is stored as a handle to
# allow delayed execution.
Expand Down Expand Up @@ -252,7 +274,7 @@ def __init__(self, handle):
# Return heat native array.
# With delayed execution, triggers computation as needed and blocks until array is available.
# """
# return _distributor.get(self._handle)
# return _runner.distributor.get(self._handle)

def __getitem__(self, key):
"""
Expand All @@ -275,48 +297,23 @@ def T(self):
"""
return DDParray(_submitProperty("T", self))

#######################################################################
# Now we add methods/properties through the standard process.
#######################################################################

# dynamically generate class methods from list of methods in array-API
# we simply make lambdas which submit appropriate Tasks
# FIXME: aa_inplace_operators,others?
fixme_afuncs = ["squeeze", "astype", "balance", "resplit"]
for method in aa_methods_a + aa_reflected_operators + fixme_afuncs:
if method not in ["__getitem__", "__setitem__"] and hasattr(dndarray, method):
exec(
f"{method} = lambda self, *args, **kwargs: DDParray(_submit('{dndarray_str}.{method}', (self, *args), kwargs))"
)

for method in aa_methods_s + ["__str__"]:
if hasattr(dndarray, method):
exec(
f"{method} = lambda self, *args, **kwargs: _distributor.get(_submit('{dndarray_str}.{method}', (self, *args), kwargs))"
)

class partRef:
"""
Handle used in __partitioned__. Identifies one chunk of a distributed array.
"""

def __init__(self, id_, rank_):
self.id = id_
self.rank = rank_

# @property
@property
def __partitioned__(self):
"""
Return partitioning meta data.
"""
parts = _distributor.get(
global _runner

parts = _runner.distributor.get(
_submit(f"{dndarray_str}.create_partition_interface", (self, True), {})
)
# Provide all data as handle/reference
for _, p in parts["partitions"].items():
p["data"] = self.partRef(self._handle._id, p["location"])
futures = _runner.publish(self._handle._id, _runner.distributor)
for i, p in enumerate(parts["partitions"].values()):
p["location"] = futures[i][0]
p["data"] = futures[i][1]
# set getter
parts["get"] = _getPartForRef
parts["get"] = _runner.get
# remove SPMD local key
del parts["locals"]
return parts
Expand All @@ -327,13 +324,33 @@ def __getattr__(self, attr):
Caches attributes from workers, so we communicate only once.
"""
if self._attributes is None:
self._attributes = _distributor.get(
self._attributes = _runner.distributor.get(
_submit(
"(lambda a: {x: getattr(a, x) for x in aa_attributes if x != 'T'})", (self,), {}
)
)
return self._attributes[attr]

#######################################################################
# Now we add methods/properties through the standard process.
#######################################################################

# dynamically generate class methods from list of methods in array-API
# we simply make lambdas which submit appropriate Tasks
# FIXME: aa_inplace_operators,others?
fixme_afuncs = ["squeeze", "astype", "balance", "resplit", "reshape"]
for method in aa_methods_a + aa_reflected_operators + fixme_afuncs:
if method not in ["__getitem__", "__setitem__"] and hasattr(dndarray, method):
exec(
f"{method} = lambda self, *args, **kwargs: DDParray(_submit('{dndarray_str}.{method}', (self, *args), kwargs))"
)

for method in aa_methods_s + ["__str__"]:
if hasattr(dndarray, method):
exec(
f"{method} = lambda self, *args, **kwargs: _runner.distributor.get(_submit('{dndarray_str}.{method}', (self, *args), kwargs))"
)


#######################################################################
# first define top-level functions through the standard process.
Expand All @@ -344,7 +361,7 @@ def __getattr__(self, attr):
# (lists taken from list of methods in array-API)
# Again, we simply make lambdas which submit appropriate Tasks

fixme_funcs = ["load_csv", "array", "triu"]
fixme_funcs = ["load_csv", "array", "triu", "copy", "repeat"]
for func in aa_tlfuncs + fixme_funcs:
if func == "meshgrid":
exec(
Expand Down
17 changes: 17 additions & 0 deletions heat/cw4heat/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
GO = 2
GET = 3
GETPART = 4
PUBPART = 5


class _TaskQueue:
Expand Down Expand Up @@ -124,6 +125,7 @@ def start(self, doExit=True, initImpl=None):
if self._comm.rank == 0:
return True
else:
print("Entering worker loop", flush=True)
done = False
header = None
while not done:
Expand All @@ -142,6 +144,10 @@ def start(self, doExit=True, initImpl=None):
val = _RemoteTask.getVal(header[2])
attr = getattr(val, header[3])
self._comm.send(attr, dest=0, tag=GETPART)
elif header[0] == PUBPART:
val = _RemoteTask.getVal(header[1])
attr = header[3](getattr(val, header[2]))
self._comm.gather(attr, root=0)
elif header[0] == END:
done = True
self._comm.Barrier()
Expand Down Expand Up @@ -201,6 +207,16 @@ def getPart(self, handle, attr):
val = self._comm.recv(source=handle.rank, tag=GETPART)
return val

def publishParts(self, id, attr, publish):
    """
    Publish array's attribute for each partition and gather handles on root.

    Must be called on rank 0. Broadcasts a PUBPART header to all ranks,
    applies ``publish`` to attribute ``attr`` of the local value registered
    under ``id`` and returns the result of gathering on root 0.

    Parameters
    ----------
    id :
        Identifier passed to ``_RemoteTask.getVal`` to look up the value.
    attr :
        Name of the attribute to read from that value.
    publish :
        Callable applied to the attribute; it is also sent in the header.
    """
    assert self._comm.rank == 0
    header = [PUBPART, id, attr, publish]
    # The root already has the header, so the bcast result is not needed.
    self._comm.bcast(header, 0)
    val = publish(getattr(_RemoteTask.getVal(id), attr))
    return self._comm.gather(val, root=0)

def submitPP(self, task, deps, numout=1):
"""
Submit a process-parallel task and return a handle/future.
Expand Down Expand Up @@ -277,6 +293,7 @@ def go(self):
"""
Actually run the task.
"""
# print(self._task._func)
deps = [_RemoteTask.s_pms[i] for i in self._depIds]
res = self._task.run(deps)
if self._nOut == 1:
Expand Down
17 changes: 17 additions & 0 deletions heat/cw4heat/examples/t1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Example script: builds two arrays with split=0, multiplies them, and prints
# the partition interface of `a` together with the data behind each partition.
# NOTE(review): leading indentation is missing in this listing, so the extent
# of the `for` body below cannot be told from here — confirm against the file.
# NOTE(review): `pickle` is not referenced anywhere in the visible script.
import pickle
import heat.cw4heat as ht

# Start the distribution engine before creating any arrays.
ht.init()

a = ht.arange(8, split=0)
b = ht.ones(8, split=0)
c = a @ b
# assert hasattr(c, "__partitioned__")
print(type(c))
# Partitioning meta data of `a`; "partitions" and "get" are read from it below.
p = a.__partitioned__()
print(a.shape, a, p)
# For every partition, resolve its "data" handle through the provided getter.
for k, v in p["partitions"].items():
print(33)
print(k, p["get"](v["data"]))
print("kkkkkk")
# Shut the distribution engine down.
ht.fini()
9 changes: 5 additions & 4 deletions heat/cw4heat/examples/tcw4h.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
a = ht.arange(8, split=0)
b = ht.ones(8, split=0)
c = a @ b
assert hasattr(c, "__partitioned__")
p = a.__partitioned__()
# assert hasattr(c, "__partitioned__")
print(type(c))
p = c.__partitioned__()
print(c.shape, c, p)
for k, v in p["partitions"].items():
print(k, p["get"](v["data"]))
Expand All @@ -23,8 +24,8 @@
a = ht.arange(8, split=0)
b = ht.ones(8, split=0)
c = a @ b
assert hasattr(c, "__partitioned__")
p = a.__partitioned__()
# assert hasattr(c, "__partitioned__")
p = c.__partitioned__()
print(c.shape, c, p)
for k, v in p["partitions"].items():
print(k, p["get"](v["data"]))
Expand Down
Loading

0 comments on commit 24ca13f

Please sign in to comment.