From f1ef95f520c5afe3c04c52d5d50c5e5d34621607 Mon Sep 17 00:00:00 2001 From: Alexandria Barghi Date: Mon, 6 Nov 2023 09:25:46 -0800 Subject: [PATCH] fix loader bugs --- .../cugraph_pyg/loader/cugraph_node_loader.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index 8552e7412e0..fb9f3f9c70a 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -159,7 +159,10 @@ def __init__( if batch_size is None or batch_size < 1: raise ValueError("Batch size must be >= 1") - self.__directory = tempfile.TemporaryDirectory(dir=directory) + self.__directory = ( + tempfile.TemporaryDirectory() if directory is None + else directory + ) if isinstance(num_neighbors, dict): raise ValueError("num_neighbors dict is currently unsupported!") @@ -175,7 +178,7 @@ def __init__( bulk_sampler = BulkSampler( batch_size, - self.__directory.name, + self.__directory if isinstance(self.__directory, str) else self.__directory.name, self.__graph_store._subgraph(edge_types), fanout_vals=num_neighbors, with_replacement=replace, @@ -219,7 +222,11 @@ def __init__( ) bulk_sampler.flush() - self.__input_files = iter(os.listdir(self.__directory.name)) + self.__input_files = iter(os.listdir( + self.__directory + if isinstance(self.__directory, str) + else self.__directory.name + )) def __next__(self): from time import perf_counter @@ -437,11 +444,10 @@ def __next__(self): # Account for CSR format in cuGraph vs. CSC format in PyG if self.__coo and self.__graph_store.order == "CSC": - for node_type in out.edge_index_dict: - out[node_type].edge_index[0], out[node_type].edge_index[1] = ( - out[node_type].edge_index[1], - out[node_type].edge_index[0], - ) + for edge_type in out.edge_index_dict: + src = out[edge_type].edge_index[0] + dst = out[edge_type].edge_index[1] + out[edge_type].edge_index = torch.stack([dst,src]) out.set_value_dict("num_sampled_nodes", sampler_output.num_sampled_nodes) out.set_value_dict("num_sampled_edges", sampler_output.num_sampled_edges)