[WIP] first draft of a controller-worker wrapper for heat #823

Open · wants to merge 24 commits into base: main
Changes from 20 of 24 commits

Commits
a9b2059  first draft of a controller-worker wrapper for heat (fschlimb, Jun 25, 2021)
6343988  auto-init and auto-fini (fschlimb, Jun 25, 2021)
7c780cc  using dealyed execution; using double-quotes (fschlimb, Jul 5, 2021)
04919dc  making code flake8-, black-, and pydocstyle-compliant (fschlimb, Jul 5, 2021)
533194a  serving picky black (fschlimb, Jul 5, 2021)
ee414d9  first cut for supporting ray actors. Controller no longer a worker. (fschlimb, Jul 7, 2021)
c0c8c16  adding __partitioned__ (fschlimb, Jul 9, 2021)
dbc3056  demoing cw/region (MPI backend) (fschlimb, Jul 9, 2021)
e8a7011  refactoring ray_runner and let it create ray ObjRefs in __partitioned__ (fschlimb, Jul 22, 2021)
0ea1705  making location a list (fschlimb, Jul 22, 2021)
53fc158  fixes (fschlimb, Aug 18, 2021)
40f7f44  adding reset() (fschlimb, Aug 18, 2021)
e7e4398  using clear() (fschlimb, Aug 18, 2021)
d659c2e  adding dot (fschlimb, Aug 20, 2021)
1b0a4bf  quick workaround to have __localop in cw4heat (fschlimb, Sep 1, 2021)
06be72e  fixed GC (fschlimb, Sep 1, 2021)
df6a193  fixing GC issues (fschlimb, Sep 1, 2021)
9a68af9  quick hack to have random.normal (fschlimb, Sep 2, 2021)
b9c9315  allow barrier after go (fschlimb, Sep 21, 2021)
af939d5  allow spmd mode in cw4heat (fschlimb, Sep 22, 2021)
89eaf07  Merge branch 'main' into Enhancement/cw4heat (ClaudiaComito, Jun 1, 2022)
b7afb57  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 1, 2022)
68dcc20  Add type hints to `create_partition_interface` (ClaudiaComito, Jun 1, 2022)
41f6bfa  Merge branch 'main' into Enhancement/cw4heat (coquelin77, Jun 13, 2022)
92 changes: 92 additions & 0 deletions heat/core/dndarray.py
@@ -611,6 +611,98 @@ def create_lshape_map(self, force_check: bool = True) -> torch.Tensor:
        self.__lshape_map = lshape_map
        return lshape_map

    def create_partition_interface(self, no_data=False):
        """
        Create a partition interface in line with the DPPY proposal. This is subject to change.
        The intention of this is to facilitate the use of a general format for referencing
        distributed datasets.

        An example of the output and shape is shown below.

        __partitioned__ = {
            'shape': (27, 3, 2),
            'partition_tiling': (4, 1, 1),
            'partitions': {
                (0, 0, 0): {
                    'start': (0, 0, 0),
                    'shape': (7, 3, 2),
                    'data': tensor([...], dtype=torch.int32),
                    'location': 0,
                    'dtype': torch.int32,
                    'device': 'cpu'
                },
                (1, 0, 0): {
                    'start': (7, 0, 0),
                    'shape': (7, 3, 2),
                    'data': None,
                    'location': 1,
                    'dtype': torch.int32,
                    'device': 'cpu'
                },
                (2, 0, 0): {
                    'start': (14, 0, 0),
                    'shape': (7, 3, 2),
                    'data': None,
                    'location': 2,
                    'dtype': torch.int32,
                    'device': 'cpu'
                },
                (3, 0, 0): {
                    'start': (21, 0, 0),
                    'shape': (6, 3, 2),
                    'data': None,
                    'location': 3,
                    'dtype': torch.int32,
                    'device': 'cpu'
                }
            },
            'locals': [(rank, 0, 0)],
        }

        Returns
        -------
        dictionary containing the partition interface as shown above.
        """
        lshape_map = self.create_lshape_map()
        start_idx_map = torch.zeros_like(lshape_map)

        part_tiling = [1] * self.ndim
        lcls = [0] * self.ndim

        z = torch.tensor([0], device=self.device.torch_device, dtype=self.dtype.torch_type())
        if self.split is not None:
            # start index of each process' local slab along the split dimension
            starts = torch.cat((z, torch.cumsum(lshape_map[:, self.split], dim=0)[:-1]), dim=0)
            lcls[self.split] = self.comm.rank
            part_tiling[self.split] = self.comm.size
        else:
            starts = torch.zeros(self.ndim, dtype=torch.int, device=self.device.torch_device)

        start_idx_map[:, self.split] = starts

        partitions = {}
        base_key = [0] * self.ndim
        for r in range(self.comm.size):
            if self.split is not None:
                base_key[self.split] = r
                # only the local partition carries data; remote pieces are referenced by 'location'
                dat = None if no_data or r != self.comm.rank else self.larray
            else:
                dat = self.larray

            partitions[tuple(base_key)] = {
                "start": tuple(start_idx_map[r].tolist()),
                "shape": tuple(lshape_map[r].tolist()),
                "data": dat,
                "location": r,
                "dtype": self.dtype.torch_type(),
                "device": self.device.torch_device,
            }

        partition_dict = {
            "shape": self.gshape,
            "partition_tiling": tuple(part_tiling),
            "partitions": partitions,
            "locals": [tuple(lcls)],
        }
        return partition_dict

    def __float__(self) -> DNDarray:
        """
        Float scalar casting.
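For orientation, below is a minimal sketch of how a consumer could read the dictionary returned by the new method. The array construction via ht.zeros with dtype ht.int32 and the printed fields are illustrative assumptions, not part of this pull request; under MPI each process would run the same script and see only its own local partition in "locals".

# Minimal consumer sketch (illustrative; assumes heat with this PR applied).
import heat as ht

# a split DNDarray comparable to the docstring example above
x = ht.zeros((27, 3, 2), dtype=ht.int32, split=0)
parts = x.create_partition_interface()

print("global shape:", parts["shape"])
print("partition tiling:", parts["partition_tiling"])

# "locals" lists the partition keys owned by this process; only those entries
# carry a tensor in "data", all other entries merely describe where the remote
# piece lives ("location") and how it is shaped.
for key in parts["locals"]:
    p = parts["partitions"][key]
    print(f"rank {p['location']}: start={p['start']} shape={p['shape']} device={p['device']}")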