Skip to content

Commit

Permalink
enable csc loader
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Sep 27, 2023
1 parent 7f838ae commit 6531e14
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
2 changes: 2 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ def __iter__(self):
if self.sparse_format == "csc":
kwargs["compression"] = "CSR"
kwargs["compress_per_hop"] = True
# The following kwargs will be deprecated in uniform sampler.
kwargs["use_legacy_names"] = False
kwargs["include_hop_column"] = False

else:
kwargs["deduplicate_sources"] = False
Expand Down
23 changes: 17 additions & 6 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cugraph_dgl.dataloading.utils.sampling_helpers import (
create_homogeneous_sampled_graphs_from_dataframe,
create_heterogeneous_sampled_graphs_from_dataframe,
create_homogeneous_sampled_graphs_from_dataframe_csc,
)


Expand Down Expand Up @@ -62,10 +63,20 @@ def __getitem__(self, idx: int):

fn, batch_offset = self._batch_to_fn_d[idx]
if fn != self._current_batch_fn:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df, edge_dir=self.edge_dir, return_type=self._return_type
)
if self.sparse_format == "csc":
df = _load_sampled_file(dataset_obj=self, fn=fn, skip_rename=True)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe_csc(df)
)
else:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df,
edge_dir=self.edge_dir,
return_type=self._return_type,
)
)
current_offset = idx - batch_offset
return self._current_batches[current_offset]

Expand Down Expand Up @@ -152,9 +163,9 @@ def set_input_files(
)


def _load_sampled_file(dataset_obj, fn):
def _load_sampled_file(dataset_obj, fn, skip_rename=False):
df = cudf.read_parquet(os.path.join(fn))
if dataset_obj.edge_dir == "in":
if dataset_obj.edge_dir == "in" and not skip_rename:
df.rename(
columns={"sources": "destinations", "destinations": "sources"},
inplace=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,5 +522,7 @@ def _create_homogeneous_sparse_graphs_from_csc(
return output


def create_homogeneous_sampled_graphs_from_dataframe_csc(df):
return _create_homogeneous_sparse_graphs_from_csc(*(_process_sampled_df_csc(df)))
def create_homogeneous_sampled_graphs_from_dataframe_csc(sampled_df: cudf.DataFrame):
return _create_homogeneous_sparse_graphs_from_csc(
*(_process_sampled_df_csc(sampled_df))
)

0 comments on commit 6531e14

Please sign in to comment.