From 3a6b6b90e4e20c78485ddea712d294dd8651c882 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 28 Sep 2023 07:41:11 -0700 Subject: [PATCH] docstring --- .../cugraph_dgl/dataloading/dataloader.py | 4 +- .../dataloading/utils/sampling_helpers.py | 46 ++++++++++++++----- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py b/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py index b8241f489e5..0ea02bdef1b 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py @@ -93,7 +93,9 @@ def __init__( batch_size: int Batch size. sparse_format: str, default = "coo" - Sparse format of the sample graph. Choose between "csc" and "coo". + The sparse format of the emitted sampled graphs. Choose between "csc" + and "coo". When using "csc", the graphs are of type + cugraph_dgl.nn.SparseGraph. kwargs : dict Key-word arguments to be passed to the parent PyTorch :py:class:`torch.utils.data.DataLoader` class. Common arguments are: diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py b/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py index 26e33166d4e..3a16c6580d2 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -from typing import List, Tuple, Dict, Optional, Any +from typing import List, Tuple, Dict, Optional from collections import defaultdict import cudf import cupy @@ -417,16 +417,31 @@ def _process_sampled_df_csc( Convert a dataframe generated by BulkSampler to a dictionary of tensors, to facilitate MFG creation. The sampled graphs in the dataframe use CSC-format. 
+ Parameters + ---------- df: cudf.DataFrame - The CSR output by BulkSampler. - reverse_hop_id: bool, default=True + The output from BulkSampler compressed in CSC format. The dataframe + should be generated with `compression="CSR"` in BulkSampler, + since the sampling routine treats seed nodes as sources. + + reverse_hop_id: bool (default=True) Reverse hop id. - Returns: - tensor_dict[batch_id][hop_id] has three keys: - - src_ids: - - cdst_ids: - - mfg_size: + Returns + ------- + tensors_dict: dict + A nested dictionary keyed by batch id and hop id. + `tensors_dict[batch_id][hop_id]` holds "minors" and "major_offsets" + values for CSC MFGs. + + renumber_map_list: list + List of renumbering maps for looking up global indices of nodes. One + map for each batch. + + mfg_sizes: list + List of the number of nodes in each message passing layer. For the + k-th hop, mfg_sizes[k] and mfg_sizes[k+1] are the numbers of sources and + destinations, respectively. """ # dropna major_offsets = df.major_offsets.dropna().values @@ -495,9 +510,16 @@ def _create_homogeneous_sparse_graphs_from_csc( tensors_dict: Dict[int, Dict[int, Dict[str, torch.Tensor]]], renumber_map_list: List[torch.Tensor], mfg_sizes: List[int, int], -) -> Any: - """Create mini-batches of MFGs. The input argument are the outputs of - the function `_process_sampled_df_csc`.""" +) -> List[List[torch.Tensor, torch.Tensor, List[SparseGraph]]]: + """Create mini-batches of MFGs. The input arguments are the outputs of + the function `_process_sampled_df_csc`. + + Returns + ------- + output: list + A list of mini-batches. Each mini-batch is a list that consists of + `input_nodes` tensor, `output_nodes` tensor and a list of MFGs. 
+ """ n_batches, n_hops = len(mfg_sizes), len(mfg_sizes[0]) - 1 output = [] for b_id in range(n_batches): @@ -523,6 +545,8 @@ def _create_homogeneous_sparse_graphs_from_csc( def create_homogeneous_sampled_graphs_from_dataframe_csc(sampled_df: cudf.DataFrame): + """Public API to create mini-batches of MFGs using a dataframe output by + BulkSampler, where the sampled graph is compressed in CSC format.""" return _create_homogeneous_sparse_graphs_from_csc( *(_process_sampled_df_csc(sampled_df)) )