
Commit

change to set wg properties at init
alexbarghi-nv committed Sep 26, 2023
1 parent a99b6f5 commit 9967da4
Showing 1 changed file with 89 additions and 32 deletions.
121 changes: 89 additions & 32 deletions python/cugraph/cugraph/gnn/feature_storage/feat_storage.py
@@ -24,16 +24,63 @@


class FeatureStore:
"""The feature-store class used to store feature data for GNNS"""
"""The feature-store class used to store feature data for GNNs"""

def __init__(
self,
backend: str = "numpy",
wg_comm: object = None,
wg_type: str = None,
wg_location: str = None,
):
"""
Constructs a new FeatureStore object
Parameters:
----------
backend: str ('numpy', 'torch', 'wholegraph')
Optional (default='numpy')
The name of the backend to use.
wg_comm: WholeMemoryCommunicator
Optional (default=automatic)
Only used with the 'wholegraph' backend.
The communicator to use to store features in WholeGraph.
wg_type: str ('distributed', 'continuous', 'chunked')
Optional (default='distributed')
Only used with the 'wholegraph' backend.
The memory format (distributed, continuous, or chunked) of
this FeatureStore. For more information see the WholeGraph
documentation.
wg_location: str ('cpu', 'cuda')
Optional (default='cuda')
Only used with the 'wholegraph' backend.
Where the data is stored (cpu or cuda).
Defaults to storing on the GPU (cuda).
"""

def __init__(self, backend="numpy"):
self.fd = defaultdict(dict)
if backend not in ["numpy", "torch", "wholegraph"]:
raise ValueError(
f"backend {backend} not supported. Supported backends are numpy, torch"
f"backend {backend} not supported. "
"Supported backends are numpy, torch, wholegraph"
)
self.backend = backend

if backend == "wholegraph":
self.__wg_comm = (
wg_comm if wg_comm is not None else wgth.get_local_node_communicator()
)
self.__wg_type = wg_type if wg_type is not None else "distributed"
self.__wg_location = wg_location if wg_location is not None else "cuda"

if self.__wg_type not in ["distributed", "chunked", "continuous"]:
raise ValueError(f"invalid memory format {self.__wg_type}")
if (self.__wg_location != "cuda") and (self.__wg_location != "cpu"):
raise ValueError(f"invalid location {self.__wg_location}")

def add_data(
self, feat_obj: Sequence, type_name: str, feat_name: str, **kwargs
) -> None:
@@ -52,7 +99,12 @@ def add_data(
None
"""
self.fd[feat_name][type_name] = self._cast_feat_obj_to_backend(
feat_obj, self.backend, **kwargs
feat_obj,
self.backend,
wg_comm=self.__wg_comm,
wg_type=self.__wg_type,
wg_location=self.__wg_location,
**kwargs,
)

def add_data_no_cast(self, feat_obj, type_name: str, feat_name: str) -> None:
@@ -125,39 +177,44 @@ def get_feature_list(self) -> list[str]:
def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs):
if backend == "numpy":
if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
return _cast_to_numpy_ar(feat_obj.values)
return _cast_to_numpy_ar(feat_obj.values, **kwargs)
else:
return _cast_to_numpy_ar(feat_obj)
return _cast_to_numpy_ar(feat_obj, **kwargs)
elif backend == "torch":
if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
return _cast_to_torch_tensor(feat_obj.values)
return _cast_to_torch_tensor(feat_obj.values, **kwargs)
else:
return _cast_to_torch_tensor(feat_obj)
return _cast_to_torch_tensor(feat_obj, **kwargs)
elif backend == "wholegraph":
wg_comm_obj = kwargs.get("wg_comm", wgth.get_local_node_communicator())
wg_type_str = kwargs.get("wg_type", "distributed")
wg_location_str = kwargs.get("wg_location", "cuda")
if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
th_tensor = _cast_to_torch_tensor(feat_obj.values)
else:
th_tensor = _cast_to_torch_tensor(feat_obj)
wg_embedding = wgth.create_embedding(
wg_comm_obj,
wg_type_str,
wg_location_str,
th_tensor.dtype,
th_tensor.shape,
)
(
local_wg_tensor,
local_ld_offset,
) = wg_embedding.get_embedding_tensor().get_local_tensor()
local_th_tensor = th_tensor[
local_ld_offset : local_ld_offset + local_wg_tensor.shape[0]
]
local_wg_tensor.copy_(local_th_tensor)
wg_comm_obj.barrier()
return wg_embedding
return _get_wg_embedding(feat_obj, **kwargs)


def _get_wg_embedding(feat_obj, wg_comm=None, wg_type=None, wg_location=None):
wg_comm_obj = wg_comm or wgth.get_local_node_communicator()
wg_type_str = wg_type or "distributed"
wg_location_str = wg_location or "cuda"

if isinstance(feat_obj, (cudf.DataFrame, pd.DataFrame)):
th_tensor = _cast_to_torch_tensor(feat_obj.values)
else:
th_tensor = _cast_to_torch_tensor(feat_obj)
wg_embedding = wgth.create_embedding(
wg_comm_obj,
wg_type_str,
wg_location_str,
th_tensor.dtype,
th_tensor.shape,
)
(
local_wg_tensor,
local_ld_offset,
) = wg_embedding.get_embedding_tensor().get_local_tensor()
local_th_tensor = th_tensor[
local_ld_offset : local_ld_offset + local_wg_tensor.shape[0]
]
local_wg_tensor.copy_(local_th_tensor)
wg_comm_obj.barrier()
return wg_embedding


def _cast_to_torch_tensor(ar):
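For context, a minimal usage sketch of the constructor after this change (not part of the commit): it assumes FeatureStore is importable as cugraph.gnn.FeatureStore and, for the WholeGraph case, that a multi-GPU WholeGraph environment has already been initialized elsewhere; everything beyond the parameter names and defaults shown in the diff is illustrative.

import numpy as np
from cugraph.gnn import FeatureStore

# Default backend: features are stored as plain numpy arrays.
fs = FeatureStore(backend="numpy")
fs.add_data(np.random.rand(100, 16), type_name="paper", feat_name="x")

# WholeGraph backend: the communicator, memory type, and location are now
# fixed once at construction time instead of being passed to each
# add_data() call. Passing wg_comm=None falls back to
# wgth.get_local_node_communicator().
fs_wg = FeatureStore(
    backend="wholegraph",
    wg_comm=None,           # or an explicit WholeMemoryCommunicator
    wg_type="distributed",  # 'distributed', 'chunked', or 'continuous'
    wg_location="cuda",     # or 'cpu' for host-resident features
)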

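The local copy inside _get_wg_embedding writes only this rank's slice of the source tensor into the WholeGraph embedding. Below is a standalone sketch of that slicing arithmetic, using plain PyTorch tensors in place of the WholeGraph objects; local_ld_offset and the local row count are made-up values standing in for what get_embedding_tensor().get_local_tensor() would return.

import torch

# Global feature matrix that every rank holds (or can produce).
full = torch.arange(12, dtype=torch.float32).reshape(6, 2)

# Stand-ins for the rank-local view: each rank owns a contiguous block of
# rows starting at local_ld_offset.
local_ld_offset = 2
local_wg_tensor = torch.empty(3, 2)

# Same indexing as in the diff: copy rows
# [local_ld_offset, local_ld_offset + local_wg_tensor.shape[0]).
local_wg_tensor.copy_(
    full[local_ld_offset : local_ld_offset + local_wg_tensor.shape[0]]
)
print(local_wg_tensor)  # rows 2, 3, and 4 of the global tensor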
