Skip to content

Commit

Permalink
fix bad kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 27, 2023
1 parent 9967da4 commit 78254c1
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/cugraph/cugraph/gnn/feature_storage/feat_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def __init__(
)
self.backend = backend

self.__wg_comm = None
self.__wg_type = None
self.__wg_location = None

if backend == "wholegraph":
self.__wg_comm = (
wg_comm if wg_comm is not None else wgth.get_local_node_communicator()
Expand Down Expand Up @@ -189,7 +193,7 @@ def _cast_feat_obj_to_backend(feat_obj, backend: str, **kwargs):
return _get_wg_embedding(feat_obj, **kwargs)


def _get_wg_embedding(feat_obj, wg_comm=None, wg_type=None, wg_location=None):
def _get_wg_embedding(feat_obj, wg_comm=None, wg_type=None, wg_location=None, **kwargs):
wg_comm_obj = wg_comm or wgth.get_local_node_communicator()
wg_type_str = wg_type or "distributed"
wg_location_str = wg_location or "cuda"
Expand Down Expand Up @@ -217,7 +221,7 @@ def _get_wg_embedding(feat_obj, wg_comm=None, wg_type=None, wg_location=None):
return wg_embedding


def _cast_to_torch_tensor(ar):
def _cast_to_torch_tensor(ar, **kwargs):
if isinstance(ar, cp.ndarray):
ar = torch.as_tensor(ar, device="cuda")
elif isinstance(ar, np.ndarray):
Expand All @@ -227,7 +231,7 @@ def _cast_to_torch_tensor(ar):
return ar


def _cast_to_numpy_ar(ar):
def _cast_to_numpy_ar(ar, **kwargs):
if isinstance(ar, cp.ndarray):
ar = ar.get()
elif type(ar).__name__ == "Tensor":
Expand Down

0 comments on commit 78254c1

Please sign in to comment.