Skip to content

Commit

Permalink
fix loader bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Nov 6, 2023
1 parent 586451d commit f1ef95f
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def __init__(
if batch_size is None or batch_size < 1:
raise ValueError("Batch size must be >= 1")

self.__directory = tempfile.TemporaryDirectory(dir=directory)
self.__directory = (
tempfile.TemporaryDirectory() if directory is None
else directory
)

if isinstance(num_neighbors, dict):
raise ValueError("num_neighbors dict is currently unsupported!")
Expand All @@ -175,7 +178,7 @@ def __init__(

bulk_sampler = BulkSampler(
batch_size,
self.__directory.name,
self.__directory if isinstance(self.__directory, str) else self.__directory.name,
self.__graph_store._subgraph(edge_types),
fanout_vals=num_neighbors,
with_replacement=replace,
Expand Down Expand Up @@ -219,7 +222,11 @@ def __init__(
)

bulk_sampler.flush()
self.__input_files = iter(os.listdir(self.__directory.name))
self.__input_files = iter(os.listdir(
self.__directory
if isinstance(self.__directory, str)
else self.__directory.name
))

def __next__(self):
from time import perf_counter
Expand Down Expand Up @@ -437,11 +444,10 @@ def __next__(self):

# Account for CSR format in cuGraph vs. CSC format in PyG
if self.__coo and self.__graph_store.order == "CSC":
for node_type in out.edge_index_dict:
out[node_type].edge_index[0], out[node_type].edge_index[1] = (
out[node_type].edge_index[1],
out[node_type].edge_index[0],
)
for edge_type in out.edge_index_dict:
src = out[edge_type].edge_index[0]
dst = out[edge_type].edge_index[1]
out[edge_type].edge_index = torch.stack([dst,src])

out.set_value_dict("num_sampled_nodes", sampler_output.num_sampled_nodes)
out.set_value_dict("num_sampled_edges", sampler_output.num_sampled_edges)
Expand Down

0 comments on commit f1ef95f

Please sign in to comment.