Skip to content

Commit

Permalink
Merge branch 'branch-23.10' into cugraph-pyg-mfg
Browse files Browse the repository at this point in the history
  • Loading branch information
BradReesWork authored Oct 4, 2023
2 parents 4438f82 + 5ce3ee1 commit 4826b01
Show file tree
Hide file tree
Showing 15 changed files with 401 additions and 41 deletions.
2 changes: 1 addition & 1 deletion python/cugraph-dgl/cugraph_dgl/dataloading/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from cugraph_dgl.dataloading.dataset import (
HomogenousBulkSamplerDataset,
HetrogenousBulkSamplerDataset,
HeterogenousBulkSamplerDataset,
)
from cugraph_dgl.dataloading.neighbor_sampler import NeighborSampler
from cugraph_dgl.dataloading.dataloader import DataLoader
49 changes: 35 additions & 14 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dask.distributed import default_client, Event
from cugraph_dgl.dataloading import (
HomogenousBulkSamplerDataset,
HetrogenousBulkSamplerDataset,
HeterogenousBulkSamplerDataset,
)
from cugraph_dgl.dataloading.utils.extract_graph_helpers import (
create_cugraph_graph_from_edges_dict,
Expand All @@ -47,19 +47,20 @@ def __init__(
graph_sampler: cugraph_dgl.dataloading.NeighborSampler,
sampling_output_dir: str,
batches_per_partition: int = 50,
seeds_per_call: int = 400_000,
seeds_per_call: int = 200_000,
device: torch.device = None,
use_ddp: bool = False,
ddp_seed: int = 0,
batch_size: int = 1024,
drop_last: bool = False,
shuffle: bool = False,
sparse_format: str = "coo",
**kwargs,
):
"""
Constructor for CuGraphStorage:
-------------------------------
graph : CuGraphStorage
graph : CuGraphStorage
The graph.
indices : Tensor or dict[ntype, Tensor]
The set of indices. It can either be a tensor of
Expand Down Expand Up @@ -89,7 +90,12 @@ def __init__(
The seed for shuffling the dataset in
:class:`torch.utils.data.distributed.DistributedSampler`.
Only effective when :attr:`use_ddp` is True.
batch_size: int,
batch_size: int
Batch size.
sparse_format: str, default = "coo"
The sparse format of the emitted sampled graphs. Choose between "csc"
and "coo". When using "csc", the graphs are of type
cugraph_dgl.nn.SparseGraph.
kwargs : dict
Key-word arguments to be passed to the parent PyTorch
:py:class:`torch.utils.data.DataLoader` class. Common arguments are:
Expand Down Expand Up @@ -123,6 +129,12 @@ def __init__(
... for input_nodes, output_nodes, blocks in dataloader:
...
"""
if sparse_format not in ["coo", "csc"]:
raise ValueError(
f"sparse_format must be one of 'coo', 'csc', "
f"but got {sparse_format}."
)
self.sparse_format = sparse_format

self.ddp_seed = ddp_seed
self.use_ddp = use_ddp
Expand Down Expand Up @@ -156,11 +168,12 @@ def __init__(
self.cugraph_dgl_dataset = HomogenousBulkSamplerDataset(
total_number_of_nodes=graph.total_number_of_nodes,
edge_dir=self.graph_sampler.edge_dir,
sparse_format=sparse_format,
)
else:
etype_id_to_etype_str_dict = {v: k for k, v in graph._etype_id_dict.items()}

self.cugraph_dgl_dataset = HetrogenousBulkSamplerDataset(
self.cugraph_dgl_dataset = HeterogenousBulkSamplerDataset(
num_nodes_dict=graph.num_nodes_dict,
etype_id_dict=etype_id_to_etype_str_dict,
etype_offset_dict=graph._etype_offset_d,
Expand Down Expand Up @@ -210,14 +223,23 @@ def __iter__(self):
output_dir = os.path.join(
self._sampling_output_dir, "epoch_" + str(self.epoch_number)
)
kwargs = {}
if isinstance(self.cugraph_dgl_dataset, HomogenousBulkSamplerDataset):
deduplicate_sources = True
prior_sources_behavior = "carryover"
renumber = True
kwargs["deduplicate_sources"] = True
kwargs["prior_sources_behavior"] = "carryover"
kwargs["renumber"] = True

if self.sparse_format == "csc":
kwargs["compression"] = "CSR"
kwargs["compress_per_hop"] = True
# The following kwargs will be deprecated in uniform sampler.
kwargs["use_legacy_names"] = False
kwargs["include_hop_column"] = False

else:
deduplicate_sources = False
prior_sources_behavior = None
renumber = False
kwargs["deduplicate_sources"] = False
kwargs["prior_sources_behavior"] = None
kwargs["renumber"] = False

bs = BulkSampler(
output_path=output_dir,
Expand All @@ -227,10 +249,9 @@ def __iter__(self):
seeds_per_call=self._seeds_per_call,
fanout_vals=self.graph_sampler._reversed_fanout_vals,
with_replacement=self.graph_sampler.replace,
deduplicate_sources=deduplicate_sources,
prior_sources_behavior=prior_sources_behavior,
renumber=renumber,
**kwargs,
)

if self.shuffle:
self.tensorized_indices_ds.shuffle()

Expand Down
37 changes: 25 additions & 12 deletions python/cugraph-dgl/cugraph_dgl/dataloading/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cugraph_dgl.dataloading.utils.sampling_helpers import (
create_homogeneous_sampled_graphs_from_dataframe,
create_heterogeneous_sampled_graphs_from_dataframe,
create_homogeneous_sampled_graphs_from_dataframe_csc,
)


Expand All @@ -33,17 +34,19 @@ def __init__(
total_number_of_nodes: int,
edge_dir: str,
return_type: str = "dgl.Block",
sparse_format: str = "coo",
):
if return_type not in ["dgl.Block", "cugraph_dgl.nn.SparseGraph"]:
raise ValueError(
"return_type must be either 'dgl.Block' or \
'cugraph_dgl.nn.SparseGraph' "
"return_type must be either 'dgl.Block' or "
"'cugraph_dgl.nn.SparseGraph'."
)
# TODO: Deprecate `total_number_of_nodes`
# as it is no longer needed
# in the next release
self.total_number_of_nodes = total_number_of_nodes
self.edge_dir = edge_dir
self.sparse_format = sparse_format
self._current_batch_fn = None
self._input_files = None
self._return_type = return_type
Expand All @@ -60,10 +63,20 @@ def __getitem__(self, idx: int):

fn, batch_offset = self._batch_to_fn_d[idx]
if fn != self._current_batch_fn:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df, edge_dir=self.edge_dir, return_type=self._return_type
)
if self.sparse_format == "csc":
df = _load_sampled_file(dataset_obj=self, fn=fn, skip_rename=True)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe_csc(df)
)
else:
df = _load_sampled_file(dataset_obj=self, fn=fn)
self._current_batches = (
create_homogeneous_sampled_graphs_from_dataframe(
sampled_df=df,
edge_dir=self.edge_dir,
return_type=self._return_type,
)
)
current_offset = idx - batch_offset
return self._current_batches[current_offset]

Expand All @@ -87,7 +100,7 @@ def set_input_files(
)


class HetrogenousBulkSamplerDataset(torch.utils.data.Dataset):
class HeterogenousBulkSamplerDataset(torch.utils.data.Dataset):
def __init__(
self,
num_nodes_dict: Dict[str, int],
Expand Down Expand Up @@ -141,18 +154,18 @@ def set_input_files(
----------
input_directory: str
input_directory which contains all the files that will be
loaded by HetrogenousBulkSamplerDataset
loaded by HeterogenousBulkSamplerDataset
input_file_paths: List[str]
File names that will be loaded by the HetrogenousBulkSamplerDataset
File names that will be loaded by the HeterogenousBulkSamplerDataset
"""
_set_input_files(
self, input_directory=input_directory, input_file_paths=input_file_paths
)


def _load_sampled_file(dataset_obj, fn):
def _load_sampled_file(dataset_obj, fn, skip_rename=False):
df = cudf.read_parquet(os.path.join(fn))
if dataset_obj.edge_dir == "in":
if dataset_obj.edge_dir == "in" and not skip_rename:
df.rename(
columns={"sources": "destinations", "destinations": "sources"},
inplace=True,
Expand Down Expand Up @@ -181,7 +194,7 @@ def get_batch_to_fn_d(files):


def _set_input_files(
dataset_obj: Union[HomogenousBulkSamplerDataset, HetrogenousBulkSamplerDataset],
dataset_obj: Union[HomogenousBulkSamplerDataset, HeterogenousBulkSamplerDataset],
input_directory: Optional[str] = None,
input_file_paths: Optional[List[str]] = None,
) -> None:
Expand Down
155 changes: 154 additions & 1 deletion python/cugraph-dgl/cugraph_dgl/dataloading/utils/sampling_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Tuple, Dict, Optional
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
import cudf
import cupy
from cugraph.utilities.utils import import_optional
from cugraph_dgl.nn import SparseGraph

dgl = import_optional("dgl")
torch = import_optional("torch")
Expand Down Expand Up @@ -401,3 +403,154 @@ def create_heterogenous_dgl_block_from_tensors_dict(
block = dgl.to_block(sampled_graph, dst_nodes=seed_nodes, src_nodes=src_d)
block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
return block


def _process_sampled_df_csc(
df: cudf.DataFrame,
reverse_hop_id: bool = True,
) -> Tuple[
Dict[int, Dict[int, Dict[str, torch.Tensor]]],
List[torch.Tensor],
List[List[int, int]],
]:
"""
Convert a dataframe generated by BulkSampler to a dictionary of tensors, to
facilitate MFG creation. The sampled graphs in the dataframe use CSC-format.
Parameters
----------
df: cudf.DataFrame
The output from BulkSampler compressed in CSC format. The dataframe
should be generated with `compression="CSR"` in BulkSampler,
since the sampling routine treats seed nodes as sources.
reverse_hop_id: bool (default=True)
Reverse hop id.
Returns
-------
tensors_dict: dict
A nested dictionary keyed by batch id and hop id.
`tensor_dict[batch_id][hop_id]` holds "minors" and "major_offsets"
values for CSC MFGs.
renumber_map_list: list
List of renumbering maps for looking up global indices of nodes. One
map for each batch.
mfg_sizes: list
List of the number of nodes in each message passing layer. For the
k-th hop, mfg_sizes[k] and mfg_sizes[k+1] is the number of sources and
destinations, respectively.
"""
# dropna
major_offsets = df.major_offsets.dropna().values
label_hop_offsets = df.label_hop_offsets.dropna().values
renumber_map_offsets = df.renumber_map_offsets.dropna().values
renumber_map = df.map.dropna().values
minors = df.minors.dropna().values

n_batches = renumber_map_offsets.size - 1
n_hops = int((label_hop_offsets.size - 1) / n_batches)

# make global offsets local
major_offsets -= major_offsets[0]
label_hop_offsets -= label_hop_offsets[0]
renumber_map_offsets -= renumber_map_offsets[0]

# get the sizes of each adjacency matrix (for MFGs)
mfg_sizes = (label_hop_offsets[1:] - label_hop_offsets[:-1]).reshape(
(n_batches, n_hops)
)
n_nodes = renumber_map_offsets[1:] - renumber_map_offsets[:-1]
mfg_sizes = cupy.hstack((mfg_sizes, n_nodes.reshape(n_batches, -1)))
if reverse_hop_id:
mfg_sizes = mfg_sizes[:, ::-1]

tensors_dict = {}
renumber_map_list = []
for batch_id in range(n_batches):
batch_dict = {}

for hop_id in range(n_hops):
hop_dict = {}
idx = batch_id * n_hops + hop_id # idx in label_hop_offsets
major_offsets_start = label_hop_offsets[idx].item()
major_offsets_end = label_hop_offsets[idx + 1].item()
minors_start = major_offsets[major_offsets_start].item()
minors_end = major_offsets[major_offsets_end].item()
# Note: minors and major_offsets from BulkSampler are of type int32
# and int64 respectively. Since pylibcugraphops binding code doesn't
# support distinct node and edge index type, we simply casting both
# to int32 for now.
hop_dict["minors"] = torch.as_tensor(
minors[minors_start:minors_end], device="cuda"
).int()
hop_dict["major_offsets"] = torch.as_tensor(
major_offsets[major_offsets_start : major_offsets_end + 1]
- major_offsets[major_offsets_start],
device="cuda",
).int()
if reverse_hop_id:
batch_dict[n_hops - 1 - hop_id] = hop_dict
else:
batch_dict[hop_id] = hop_dict

tensors_dict[batch_id] = batch_dict

renumber_map_list.append(
torch.as_tensor(
renumber_map[
renumber_map_offsets[batch_id] : renumber_map_offsets[batch_id + 1]
],
device="cuda",
)
)

return tensors_dict, renumber_map_list, mfg_sizes.tolist()


def _create_homogeneous_sparse_graphs_from_csc(
tensors_dict: Dict[int, Dict[int, Dict[str, torch.Tensor]]],
renumber_map_list: List[torch.Tensor],
mfg_sizes: List[int, int],
) -> List[List[torch.Tensor, torch.Tensor, List[SparseGraph]]]:
"""Create mini-batches of MFGs. The input arguments are the outputs of
the function `_process_sampled_df_csc`.
Returns
-------
output: list
A list of mini-batches. Each mini-batch is a list that consists of
`input_nodes` tensor, `output_nodes` tensor and a list of MFGs.
"""
n_batches, n_hops = len(mfg_sizes), len(mfg_sizes[0]) - 1
output = []
for b_id in range(n_batches):
output_batch = []
output_batch.append(renumber_map_list[b_id])
output_batch.append(renumber_map_list[b_id][: mfg_sizes[b_id][-1]])
mfgs = [
SparseGraph(
size=(mfg_sizes[b_id][h_id], mfg_sizes[b_id][h_id + 1]),
src_ids=tensors_dict[b_id][h_id]["minors"],
cdst_ids=tensors_dict[b_id][h_id]["major_offsets"],
formats=["csc"],
reduce_memory=True,
)
for h_id in range(n_hops)
]

output_batch.append(mfgs)

output.append(output_batch)

return output


def create_homogeneous_sampled_graphs_from_dataframe_csc(sampled_df: cudf.DataFrame):
"""Public API to create mini-batches of MFGs using a dataframe output by
BulkSampler, where the sampled graph is compressed in CSC format."""
return _create_homogeneous_sparse_graphs_from_csc(
*(_process_sampled_df_csc(sampled_df))
)
Loading

0 comments on commit 4826b01

Please sign in to comment.