Commit b35ad1f: dist sampling io

alexbarghi-nv committed Apr 2, 2024
1 parent 9e393a0

Showing 6 changed files with 210 additions and 36 deletions.
105 changes: 105 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/examples/cugraph_dist_sampling.py
@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This example shows how to use cuGraph nccl-only comms, pylibcuGraph,
# and PyTorch DDP to run a multi-GPU sampling workflow. Most users of the
# GNN packages will not interact with cuGraph directly. This example
# is intended for users who want to extend cuGraph within a DDP workflow.
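#
# A typical invocation on a single multi-GPU node (assuming torch, ogb, and
# cugraph are installed) is simply `python cugraph_dist_sampling.py`;
# main() spawns one sampling process per visible GPU.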

import os

import tempfile
import numpy as np
import torch
import torch.multiprocessing as tmp
import torch.distributed as dist

import cudf

from cugraph.gnn import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
cugraph_comms_get_raft_handle,
DistSampleWriter,
UniformNeighborSampler
)

from pylibcugraph import MGGraph, ResourceHandle, GraphProperties

from ogb.nodeproppred import NodePropPredDataset


def init_pytorch(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
dist.init_process_group("nccl", rank=rank, world_size=world_size)


def sample(rank: int, world_size: int, uid, edgelist, directory):
init_pytorch(rank, world_size)

device = rank
cugraph_comms_init(rank, world_size, uid, device)

print(f"rank {rank} initialized cugraph")

src = cudf.Series(np.array_split(edgelist[0], world_size)[rank])
dst = cudf.Series(np.array_split(edgelist[1], world_size)[rank])

seeds = cudf.Series(np.arange(rank * 50, (rank + 1) * 50))
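# Wrap the RAFT handle created during comms init so pylibcugraph APIs can use it.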
handle = ResourceHandle(cugraph_comms_get_raft_handle().getHandle())

print("constructing graph")
G = MGGraph(
handle,
GraphProperties(is_multigraph=True, is_symmetric=False),
[src],
[dst],
)
print("graph constructed")

sample_writer = DistSampleWriter(
directory=directory,
batches_per_partition=2
)
sampler = UniformNeighborSampler(
G,
sample_writer,
fanout=[5, 5],
)

sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)

dist.barrier()
cugraph_comms_shutdown()
print(f"rank {rank} shut down cugraph")


def main():
world_size = torch.cuda.device_count()
uid = cugraph_comms_create_unique_id()

dataset = NodePropPredDataset("ogbn-products")
el = dataset[0][0]["edge_index"].astype("int64")

with tempfile.TemporaryDirectory() as directory:
tmp.spawn(
sample,
args=(world_size, uid, el, directory),  # write samples into the temporary directory
nprocs=world_size,
)


if __name__ == "__main__":
main()
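The writer leaves one parquet file per partition in the target directory. A minimal sketch of reading the output back, assuming the example is pointed at a persistent path such as "samples/" (hypothetical) rather than the temporary directory:

import glob

import cudf

# Each file holds up to `batches_per_partition` renumbered COO minibatches,
# named batch=<first batch id>-<last batch id>.parquet by DistSampleWriter.
for path in sorted(glob.glob("samples/batch=*.parquet")):
    df = cudf.read_parquet(path)
    print(path, len(df))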
5 changes: 5 additions & 0 deletions python/cugraph/cugraph/gnn/__init__.py
@@ -13,6 +13,11 @@

from .feature_storage.feat_storage import FeatureStore
from .data_loading.bulk_sampler import BulkSampler
from .data_loading.dist_sampler import (
DistSampler,
DistSampleWriter,
UniformNeighborSampler
)
from .comms.cugraph_nccl_comms import (
cugraph_comms_init,
cugraph_comms_shutdown,
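With these re-exports, the distributed sampling classes are importable from the package root:

from cugraph.gnn import DistSampler, DistSampleWriter, UniformNeighborSampler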
7 changes: 7 additions & 0 deletions python/cugraph/cugraph/gnn/comms/__init__.py
@@ -10,3 +10,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cugraph_nccl_comms import (
cugraph_comms_init,
cugraph_comms_shutdown,
cugraph_comms_create_unique_id,
cugraph_comms_get_raft_handle,
)
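The NCCL-comms bootstrap is now importable from cugraph.gnn.comms as well. A minimal sketch of the per-process lifecycle, mirroring the example script above (rank, world_size, and the device assignment come from whatever launcher is used):

from cugraph.gnn.comms import (
    cugraph_comms_create_unique_id,
    cugraph_comms_init,
    cugraph_comms_shutdown,
)

uid = cugraph_comms_create_unique_id()  # created once, shared with every rank


def worker(rank: int, world_size: int):
    cugraph_comms_init(rank, world_size, uid, rank)  # device = rank, as in the example
    # ... build graphs, run sampling ...
    cugraph_comms_shutdown()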
5 changes: 5 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/__init__.py
@@ -12,3 +12,8 @@
# limitations under the License.

from cugraph.gnn.data_loading.bulk_sampler import BulkSampler
from cugraph.gnn.data_loading.dist_sampler import (
DistSampler,
DistSampleWriter,
UniformNeighborSampler
)
2 changes: 1 addition & 1 deletion python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py
@@ -32,7 +32,7 @@ def create_df_from_disjoint_series(series_list: List[cudf.Series]):
return df


def create_df_from_disjoint_arrays(array_dict: Dict[cupy.array]):
def create_df_from_disjoint_arrays(array_dict: Dict[str, cupy.array]):
for k in list(array_dict.keys()):
array_dict[k] = cudf.Series(array_dict[k], name=k)

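The corrected annotation matches how the helper is called: keys are column names, values are device arrays. A short sketch under that reading (hypothetical column names):

import cupy

from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays

# Each key becomes a named cudf.Series and then a column of the resulting DataFrame.
df = create_df_from_disjoint_arrays({
    "majors": cupy.arange(4),
    "minors": cupy.arange(4),
})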
122 changes: 87 additions & 35 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
@@ -12,6 +12,7 @@
# limitations under the License.

import os
import warnings
from math import ceil

import pylibcugraph
@@ -21,7 +22,7 @@

from typing import Union, List
from cugraph.utilities import import_optional
from cugraph.gnn import cugraph_comms_get_raft_handle
from cugraph.gnn.comms import cugraph_comms_get_raft_handle

from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays

@@ -32,7 +33,10 @@


class DistSampleWriter:
def __init__(self, format: str, directory: str, batches_per_partition: int):
def __init__(self, directory: str, *, batches_per_partition: int = 256, format: str = "parquet"):
if format != "parquet":
raise ValueError("Invalid format (currently supported: 'parquet')")

self.__format = format
self.__directory = directory
self.__batches_per_partition = batches_per_partition
@@ -54,10 +58,12 @@ def __write_minibatches_coo(self, minibatch_dict):
has_edge_types = minibatch_dict['edge_type'] is not None
has_weights = minibatch_dict['weight'] is not None

if minibatch_dict['renumber_map'] is None:
raise ValueError("Distributed sampling without renumbering is not supported")

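# label_hop_offsets has (num_batches * num_hops + 1) entries, so the number
# of hops per batch can be recovered from its length.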
fanout_length = (len(minibatch_dict['label_hop_offsets']) - 1) // len(minibatch_dict['batch_id'])
rank_batch_offset = minibatch_dict['batch_id'][0]

for p in range(0, int(ceil(len(minibatch_dict['batch_id']) / self.__batches_per_partition))):
partition_start = p * (self.__batches_per_partition)
@@ -68,7 +74,7 @@
]

batch_id_array_p = minibatch_dict['batch_id'][partition_start:partition_end]
start_batch_id = batch_id_array_p[0]
start_batch_id = batch_id_array_p[0] - rank_batch_offset

start_ix, end_ix = label_hop_offsets_array_p[[0, -1]]
majors_array_p = minibatch_dict["majors"][start_ix:end_ix]
@@ -111,9 +117,16 @@ def __write_minibatches_coo(self, minibatch_dict):
)

end_batch_id = start_batch_id + len(batch_id_array_p) - 1
full_output_path = os.path.join(
self.__directory, f"batch={start_batch_id:010d}-{end_batch_id:010d}.parquet"
)
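# Prefixing file names with the writing rank keeps partitions from different
# ranks globally unique while still sorting by batch id.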
if 'rank' in minibatch_dict:
rank = minibatch_dict['rank']
full_output_path = os.path.join(
self.__directory, f"batch={rank:05d}{start_batch_id:08d}-{rank:05d}{end_batch_id:08d}.parquet"
)
else:
full_output_path = os.path.join(
self.__directory, f"batch={start_batch_id:010d}-{end_batch_id:010d}.parquet"
)


results_dataframe_p.to_parquet(
full_output_path, compression=None, index=False, force_nullable_schema=True
@@ -137,29 +150,45 @@ def __init__(
graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph],
writer: DistSampleWriter,
local_seeds_per_call: int = 32768,
rank: int = 0,
):
self.__graph = graph
self.__writer = writer
self.__local_seeds_per_call = local_seeds_per_call
self.__rank = rank

def sample_batches(
self, seeds: TensorType, batch_ids: TensorType, random_state: int = 0
):
raise NotImplementedError("Must be implemented by subclass")

def sample_from_nodes(self, nodes: TensorType, batch_size: int, random_state: int):
def sample_from_nodes(self, nodes: TensorType, *, batch_size: int = 16, random_state: int = 62):
batches_per_call = self._local_seeds_per_call // batch_size
actual_seeds_per_call = batches_per_call * batch_size
num_calls = int(ceil(len(nodes) / actual_seeds_per_call))

nodes = torch.split(torch.as_tensor(nodes, device="cuda"), num_calls)
num_seeds = len(nodes)
nodes = torch.split(torch.as_tensor(nodes, device="cuda"), actual_seeds_per_call)

if self.is_multi_gpu:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

t = torch.empty((world_size,), dtype=torch.int64, device='cuda')
local_size = torch.tensor([int(ceil(num_seeds / batch_size))], dtype=torch.int64, device='cuda')

torch.distributed.all_gather_into_tensor(t, local_size)
if (t != local_size).any() and rank == 0:
warnings.warn(
"Not all ranks received the same number of batches. "
"This might cause your training loop to hang due to uneven inputs."
)

# Exclusive prefix sum: the number of batches assigned to all lower ranks.
batch_id_start = t.cumsum(dim=0)[rank] - t[rank]
else:
batch_id_start = 0

for i, current_seeds in enumerate(nodes):
current_batches = torch.arange(
i * batches_per_call,
(i + 1) * batches_per_call,
batch_id_start + i * batches_per_call,
batch_id_start + (i + 1) * batches_per_call,
device="cuda",
dtype=torch.int32,
)
@@ -182,67 +211,90 @@ def is_multi_gpu(self):
@property
def _local_seeds_per_call(self):
return self.__local_seeds_per_call

@property
def rank(self):
return self.__rank
def _graph(self):
return self.__graph


class UniformNeighborSampler(DistSampler):
def __init__(
self,
graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph],
fanout: List[int],
prior_sources_behavior: str,
deduplicate_sources: bool,
compression: str,
compress_per_hop: bool,
writer: DistSampleWriter,
*,
local_seeds_per_call: int = 32768,
fanout: List[int] = [-1],
prior_sources_behavior: str = "exclude",
deduplicate_sources: bool = True,
compression: str = "COO",
compress_per_hop: bool = False,
with_replacement: bool = False,
):
super(graph)
super().__init__(graph, writer, local_seeds_per_call=local_seeds_per_call)
self.__fanout = fanout
self.__prior_sources_behavior = prior_sources_behavior
self.__deduplicate_sources = deduplicate_sources
self.__compress_per_hop = compress_per_hop
self.__compression = compression
self.__with_replacement = with_replacement

def sample_batches(
self, seeds: TensorType, batch_ids: TensorType, random_state: int = 0
):
# FIXME allow skipping of the synchronization logic with a boolean
# flag that assumes all ranks have the same number of batches.
if self.is_multi_gpu:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
handle = pylibcugraph.ResourceHandle(
cugraph_comms_get_raft_handle().getHandle()
)
label_to_output_comm_rank = torch.full(
(len(seeds),), self.rank, dtype=torch.int32, device="cuda"

local_label_list = torch.unique(batch_ids)
local_label_to_output_comm_rank = torch.full(
(len(local_label_list),), rank, dtype=torch.int32, device="cuda"
)
label_list = torch.unique(batch_ids)

num_batches = torch.tensor([len(local_label_list)], device='cuda', dtype=torch.int64)
torch.distributed.all_reduce(num_batches, op=torch.distributed.ReduceOp.SUM)

label_list = torch.empty((num_batches,), device='cuda', dtype=torch.int32)
w1 = torch.distributed.all_gather_into_tensor(label_list, local_label_list, async_op=True)

label_to_output_comm_rank = torch.empty((num_batches,), device='cuda', dtype=torch.int32)
w2 = torch.distributed.all_gather_into_tensor(label_to_output_comm_rank, local_label_to_output_comm_rank, async_op=True)

w1.wait()
w2.wait()

sampling_results_dict = pylibcugraph.uniform_neighbor_sample(
handle,
self.__graph,
start_list=seeds,
batch_id_list=batch_ids,
label_list=label_list,
label_to_output_comm_rank=label_to_output_comm_rank,
self._graph,
start_list=cupy.asarray(seeds),
batch_id_list=cupy.asarray(batch_ids),
label_list=cupy.asarray(label_list),
label_to_output_comm_rank=cupy.asarray(label_to_output_comm_rank),
h_fan_out=np.array(self.__fanout, dtype="int32"),
with_replacement=self.__with_replacement,
do_expensive_check=False,
with_edge_properties=True,
random_state=random_state,
random_state=random_state + rank,
prior_sources_behavior=self.__prior_sources_behavior,
deduplicate_sources=self.__deduplicate_sources,
return_hops=True,
renumber=self.__renumber,
renumber=True,
compression=self.__compression,
compress_per_hop=self.__compress_per_hop,
return_dict=True,
)
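# Tag the results with this rank so the writer can emit rank-unique file names.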
sampling_results_dict['rank'] = rank
else:
sampling_results_dict = pylibcugraph.uniform_neighbor_sample(
pylibcugraph.ResourceHandle(),
self._graph,
start_list=seeds,
batch_id_list=batch_ids,
start_list=cupy.asarray(seeds),
batch_id_list=cupy.asarray(batch_ids),
h_fan_out=np.array(self.__fanout, dtype="int32"),
with_replacement=self.__with_replacement,
do_expensive_check=False,
Expand All @@ -251,7 +303,7 @@ def sample_batches(
prior_sources_behavior=self.__prior_sources_behavior,
deduplicate_sources=self.__deduplicate_sources,
return_hops=True,
renumber=self.__renumber,
renumber=True,
compression=self.__compression,
compress_per_hop=self.__compress_per_hop,
return_dict=True,
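Taken together, the pieces compose as in the example script at the top of this commit. A condensed sketch of the single-GPU path, assuming G is an existing pylibcugraph.SGGraph and seeds is a cudf.Series of seed vertices:

from cugraph.gnn import DistSampleWriter, UniformNeighborSampler

writer = DistSampleWriter(directory="samples", batches_per_partition=2)
sampler = UniformNeighborSampler(G, writer, fanout=[5, 5])
sampler.sample_from_nodes(seeds, batch_size=16, random_state=62)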
