diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py index 4ec513cbf9b..ecc51006995 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/neighbor_sampler.py @@ -197,10 +197,8 @@ def sample( if g.is_homogeneous: indices = torch.concat(list(indices)) - ds.sample_from_nodes(indices.long(), batch_size=batch_size) - return HomogeneousSampleReader( - ds.get_reader(), self.output_format, self.edge_dir - ) + reader = ds.sample_from_nodes(indices.long(), batch_size=batch_size) + return HomogeneousSampleReader(reader, self.output_format, self.edge_dir) raise ValueError( "Sampling heterogeneous graphs is currently" diff --git a/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py b/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py index 731ec1b8d6f..7ea608e7e53 100644 --- a/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py +++ b/python/cugraph-dgl/cugraph_dgl/dataloading/sampler.py @@ -20,7 +20,6 @@ create_homogeneous_sampled_graphs_from_tensors_csc, ) -from cugraph.gnn import DistSampleReader from cugraph.utilities.utils import import_optional @@ -33,14 +32,18 @@ class SampleReader: Iterator that processes results from the cuGraph distributed sampler. """ - def __init__(self, base_reader: DistSampleReader, output_format: str = "dgl.Block"): + def __init__( + self, + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], + output_format: str = "dgl.Block", + ): """ Constructs a new SampleReader. Parameters ---------- - base_reader: DistSampleReader - The reader responsible for loading saved samples produced by + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + The iterator responsible for loading saved samples produced by the cuGraph distributed sampler. 
""" self.__output_format = output_format @@ -83,7 +86,7 @@ class HomogeneousSampleReader(SampleReader): def __init__( self, - base_reader: DistSampleReader, + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], output_format: str = "dgl.Block", edge_dir="in", ): @@ -92,7 +95,7 @@ def __init__( Parameters ---------- - base_reader: DistSampleReader + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] The reader responsible for loading saved samples produced by the cuGraph distributed sampler. output_format: str diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py index 7002d7ebded..127ca809d91 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py @@ -185,6 +185,8 @@ def run_train( wall_clock_start, tempdir=None, num_layers=3, + in_memory=False, + seeds_per_call=-1, ): optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) @@ -196,20 +198,23 @@ def run_train( from cugraph_pyg.loader import NeighborLoader ix_train = split_idx["train"].cuda() - train_path = os.path.join(tempdir, f"train_{global_rank}") - os.mkdir(train_path) + train_path = None if in_memory else os.path.join(tempdir, f"train_{global_rank}") + if train_path: + os.mkdir(train_path) train_loader = NeighborLoader( data, input_nodes=ix_train, directory=train_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) ix_test = split_idx["test"].cuda() - test_path = os.path.join(tempdir, f"test_{global_rank}") - os.mkdir(test_path) + test_path = None if in_memory else os.path.join(tempdir, f"test_{global_rank}") + if test_path: + os.mkdir(test_path) test_loader = NeighborLoader( data, input_nodes=ix_test, @@ -221,14 +226,16 @@ def run_train( ) ix_valid = split_idx["valid"].cuda() - valid_path = os.path.join(tempdir, f"valid_{global_rank}") - 
os.mkdir(valid_path) + valid_path = None if in_memory else os.path.join(tempdir, f"valid_{global_rank}") + if valid_path: + os.mkdir(valid_path) valid_loader = NeighborLoader( data, input_nodes=ix_valid, directory=valid_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) @@ -347,6 +354,9 @@ def parse_args(): parser.add_argument("--skip_partition", action="store_true") parser.add_argument("--wg_mem_type", type=str, default="distributed") + parser.add_argument("--in_memory", action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) + return parser.parse_args() @@ -429,6 +439,8 @@ def parse_args(): wall_clock_start, tempdir, args.num_layers, + args.in_memory, + args.seeds_per_call, ) else: warnings.warn("This script should be run with 'torchrun`. Exiting.") diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py index 09d874bd87d..0f9c39bf04d 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py @@ -91,10 +91,20 @@ def test(loader: NeighborLoader, val_steps: Optional[int] = None): def create_loader( - data, num_neighbors, input_nodes, replace, batch_size, samples_dir, stage_name + data, + num_neighbors, + input_nodes, + replace, + batch_size, + samples_dir, + stage_name, + local_seeds_per_call, ): - directory = os.path.join(samples_dir, stage_name) - os.mkdir(directory) + if samples_dir is not None: + directory = os.path.join(samples_dir, stage_name) + os.mkdir(directory) + else: + directory = None return NeighborLoader( data, num_neighbors=num_neighbors, @@ -102,6 +112,7 @@ def create_loader( replace=replace, batch_size=batch_size, directory=directory, + local_seeds_per_call=local_seeds_per_call, ) @@ -147,6 +158,8 @@ def parse_args(): parser.add_argument("--tempdir_root", type=str, default=None) 
parser.add_argument("--dataset_root", type=str, default="dataset") parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--in_memory", action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) return parser.parse_args() @@ -170,7 +183,10 @@ def parse_args(): "num_neighbors": [args.fan_out] * args.num_layers, "replace": False, "batch_size": args.batch_size, - "samples_dir": samples_dir, + "samples_dir": None if args.in_memory else samples_dir, + "local_seeds_per_call": None + if args.seeds_per_call <= 0 + else args.seeds_per_call, } train_loader = create_loader( diff --git a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py index b1bb0240e71..73efbc92a24 100644 --- a/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py +++ b/python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py @@ -86,6 +86,8 @@ def run_train( wall_clock_start, tempdir=None, num_layers=3, + in_memory=False, + seeds_per_call=-1, ): init_pytorch_worker( @@ -119,20 +121,23 @@ def run_train( dist.barrier() ix_train = torch.tensor_split(split_idx["train"], world_size)[rank].cuda() - train_path = os.path.join(tempdir, f"train_{rank}") - os.mkdir(train_path) + train_path = None if in_memory else os.path.join(tempdir, f"train_{rank}") + if train_path: + os.mkdir(train_path) train_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_train, directory=train_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) ix_test = torch.tensor_split(split_idx["test"], world_size)[rank].cuda() - test_path = os.path.join(tempdir, f"test_{rank}") - os.mkdir(test_path) + test_path = None if in_memory else os.path.join(tempdir, f"test_{rank}") + if test_path: + os.mkdir(test_path) test_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_test, @@ -144,14 +149,16 @@ def run_train( 
) ix_valid = torch.tensor_split(split_idx["valid"], world_size)[rank].cuda() - valid_path = os.path.join(tempdir, f"valid_{rank}") - os.mkdir(valid_path) + valid_path = None if in_memory else os.path.join(tempdir, f"valid_{rank}") + if valid_path: + os.mkdir(valid_path) valid_loader = NeighborLoader( (feature_store, graph_store), input_nodes=ix_valid, directory=valid_path, shuffle=True, drop_last=True, + local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None, **kwargs, ) @@ -269,6 +276,8 @@ def run_train( parser.add_argument("--tempdir_root", type=str, default=None) parser.add_argument("--dataset_root", type=str, default="dataset") parser.add_argument("--dataset", type=str, default="ogbn-products") + parser.add_argument("--in_memory", action="store_true", default=False) + parser.add_argument("--seeds_per_call", type=int, default=-1) parser.add_argument( "--n_devices", @@ -322,6 +331,8 @@ def run_train( wall_clock_start, tempdir, args.num_layers, + args.in_memory, + args.seeds_per_call, ), nprocs=world_size, join=True, diff --git a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py index 7f12bbb3fe6..1da2c6dc381 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py @@ -12,7 +12,6 @@ # limitations under the License. import warnings -import tempfile from typing import Union, Tuple, Optional, Callable, List, Dict @@ -123,14 +122,14 @@ def __init__( The number of input nodes per output minibatch. See torch.utils.dataloader. directory: str (optional, default=None) - The directory where samples will be temporarily stored. - It is recommend that this be set by the user, usually - setting it to a tempfile.TemporaryDirectory with a context + The directory where samples will be temporarily stored, + if spilling samples to disk. If None, this loader + will perform buffered in-memory sampling. 
+ If writing to disk, setting this argument + to a tempfile.TemporaryDirectory with a context manager is a good option but depending on the filesystem, you may want to choose an alternative location with fast I/O intead. - If not set, this will create a TemporaryDirectory that will - persist until this object is garbage collected. See cugraph.gnn.DistSampleWriter. batches_per_partition: int (optional, default=256) The number of batches per partition if writing samples to @@ -182,20 +181,19 @@ def __init__( # Will eventually automatically convert these objects to cuGraph objects. raise NotImplementedError("Currently can't accept non-cugraph graphs") - if directory is None: - warnings.warn("Setting a directory to store samples is recommended.") - self._tempdir = tempfile.TemporaryDirectory() - directory = self._tempdir.name - if compression is None: compression = "CSR" elif compression not in ["CSR", "COO"]: raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')") - writer = DistSampleWriter( - directory=directory, - batches_per_partition=batches_per_partition, - format=format, + writer = ( + None + if directory is None + else DistSampleWriter( + directory=directory, + batches_per_partition=batches_per_partition, + format=format, + ) ) feature_store, graph_store = data diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index 268e9ffebbd..36076ca412d 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -14,7 +14,7 @@ from typing import Optional, Iterator, Union, Dict, Tuple from cugraph.utilities.utils import import_optional -from cugraph.gnn import DistSampler, DistSampleReader +from cugraph.gnn import DistSampler from .sampler_utils import filter_cugraph_pyg_store @@ -152,13 +152,15 @@ class SampleReader: Iterator that processes results from the cuGraph distributed sampler. 
""" - def __init__(self, base_reader: DistSampleReader): + def __init__( + self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + ): """ Constructs a new SampleReader. Parameters ---------- - base_reader: DistSampleReader + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] The reader responsible for loading saved samples produced by the cuGraph distributed sampler. """ @@ -202,14 +204,16 @@ class HomogeneousSampleReader(SampleReader): produced by the cuGraph distributed sampler. """ - def __init__(self, base_reader: DistSampleReader): + def __init__( + self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + ): """ Constructs a new HomogeneousSampleReader Parameters ---------- - base_reader: DistSampleReader - The reader responsible for loading saved samples produced by + base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]] + The iterator responsible for loading saved samples produced by the cuGraph distributed sampler. """ super().__init__(base_reader) @@ -353,7 +357,7 @@ def sample_from_nodes( "torch_geometric.sampler.SamplerOutput", ] ]: - self.__sampler.sample_from_nodes( + reader = self.__sampler.sample_from_nodes( index.node, batch_size=self.__batch_size, **kwargs ) @@ -362,7 +366,7 @@ def sample_from_nodes( len(edge_attrs) == 1 and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2] ): - return HomogeneousSampleReader(self.__sampler.get_reader()) + return HomogeneousSampleReader(reader) else: # TODO implement heterogeneous sampling raise NotImplementedError( diff --git a/python/cugraph/cugraph/gnn/data_loading/__init__.py b/python/cugraph/cugraph/gnn/data_loading/__init__.py index 9e2c81ec749..25f58be88aa 100644 --- a/python/cugraph/cugraph/gnn/data_loading/__init__.py +++ b/python/cugraph/cugraph/gnn/data_loading/__init__.py @@ -14,9 +14,12 @@ from cugraph.gnn.data_loading.bulk_sampler import BulkSampler from cugraph.gnn.data_loading.dist_sampler import ( DistSampler, + NeighborSampler, 
+) +from cugraph.gnn.data_loading.dist_io import ( DistSampleWriter, DistSampleReader, - NeighborSampler, + BufferedSampleReader, ) diff --git a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py index 6abbd82647b..222fb49a836 100644 --- a/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py +++ b/python/cugraph/cugraph/gnn/data_loading/bulk_sampler_io.py @@ -33,10 +33,12 @@ def create_df_from_disjoint_series(series_list: List[cudf.Series]): def create_df_from_disjoint_arrays(array_dict: Dict[str, cupy.array]): + series_dict = {} for k in list(array_dict.keys()): - array_dict[k] = cudf.Series(array_dict[k], name=k) + if array_dict[k] is not None: + series_dict[k] = cudf.Series(array_dict[k], name=k) - return create_df_from_disjoint_series(list(array_dict.values())) + return create_df_from_disjoint_series(list(series_dict.values())) def _write_samples_to_parquet_csr( diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py new file mode 100644 index 00000000000..29bb5489be2 --- /dev/null +++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+from .reader import BufferedSampleReader, DistSampleReader
+from .writer import DistSampleWriter
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py
new file mode 100644
index 00000000000..69f909e7a8d
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import re
+
+import cudf
+
+from typing import Callable, Iterator, Tuple, Dict, Optional
+
+from cugraph.utilities.utils import import_optional, MissingModule
+
+# Prevent PyTorch from being imported and causing an OOM error
+torch = MissingModule("torch")
+
+
+class DistSampleReader:
+    """
+    Iterator over the parquet sample files stored in a directory.
+
+    Each step reads one file and returns its non-empty columns as
+    tensors, together with the first and last batch id (inclusive).
+    """
+
+    def __init__(
+        self,
+        directory: str,
+        *,
+        format: str = "parquet",
+        rank: Optional[int] = None,
+        filelist=None,
+    ):
+        """
+        Parameters
+        ----------
+        directory: str (required)
+            The directory containing the sample files.
+        format: str (optional, default='parquet')
+            The file format of the sample files. Currently,
+            only parquet format is supported.
+        rank: int (optional, default=None)
+            If given, a directory listing is filtered to the files
+            whose name carries this rank, and the batch count is
+            reduced to the minimum across all ranks.
+        filelist: (optional, default=None)
+            Files to read instead of listing the directory. Entries
+            are indexed like the regex matches built below: [0] is
+            the file name, [2] and [4] the first and last batch id.
+        """
+        torch = import_optional("torch")
+
+        self.__format = format
+        self.__directory = directory
+
+        if format != "parquet":
+            raise ValueError("Invalid format (currently supported: 'parquet')")
+
+        if filelist is None:
+            files = os.listdir(directory)
+            ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet")
+            filematch = [ex.match(f) for f in files]
+            filematch = [f for f in filematch if f]
+
+            if rank is not None:
+                filematch = [f for f in filematch if int(f[1]) == rank]
+
+            filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True)
+
+            self.__files = filematch
+        else:
+            self.__files = list(filelist)
+
+        # Derive the batch count from the final file list so that an
+        # explicitly passed filelist is counted as well.
+        batch_count = sum(int(f[4]) - int(f[2]) + 1 for f in self.__files)
+
+        if rank is None:
+            self.__batch_count = batch_count
+        else:
+            # TODO maybe remove this in favor of warning users that they are
+            # probably going to cause a hang, instead of attempting to resolve
+            # the hang for them by dropping batches.
+            batch_count = torch.tensor([batch_count], device="cuda")
+            torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN)
+            self.__batch_count = int(batch_count)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Tuple[Dict[str, "torch.Tensor"], int, int]:
+        torch = import_optional("torch")
+
+        if len(self.__files) > 0:
+            f = self.__files.pop()
+            fname = f[0]
+            start_inclusive = int(f[2])
+            end_inclusive = int(f[4])
+
+            if (end_inclusive - start_inclusive + 1) > self.__batch_count:
+                end_inclusive = start_inclusive + self.__batch_count - 1
+                self.__batch_count = 0
+            else:
+                self.__batch_count -= end_inclusive - start_inclusive + 1
+
+            df = cudf.read_parquet(os.path.join(self.__directory, fname))
+            tensors = {}
+            for col in list(df.columns):
+                s = df[col].dropna()
+                if len(s) > 0:
+                    tensors[col] = torch.as_tensor(s, device="cuda")
+                df.drop(col, axis=1, inplace=True)
+
+            return tensors, start_inclusive, end_inclusive
+
+        raise StopIteration
+
+
+class BufferedSampleReader:
+    """
+    Iterator that produces samples lazily, one call group at a time.
+
+    When the current reader is exhausted, sample_fn is invoked on the
+    next group of nodes and iteration continues over its output.
+    """
+
+    def __init__(
+        self,
+        nodes_call_groups: list["torch.Tensor"],
+        sample_fn: Callable[..., Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]],
+        *args,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        nodes_call_groups: list[torch.Tensor]
+            The groups of nodes, one group per call to sample_fn.
+        sample_fn: Callable
+            Called as sample_fn(call_id, nodes, *args, **kwargs);
+            must return an iterator over the resulting samples.
+        *args, **kwargs
+            Extra arguments forwarded to every sample_fn call.
+        """
+        self.__sample_args = args
+        self.__sample_kwargs = kwargs
+
+        self.__nodes_call_groups = iter(nodes_call_groups)
+        self.__sample_fn = sample_fn
+        self.__current_call_id = 0
+        self.__current_reader = None
+
+    def __next__(self) -> Tuple[Dict[str, "torch.Tensor"], int, int]:
+        new_reader = False
+
+        if self.__current_reader is None:
+            new_reader = True
+        else:
+            try:
+                out = next(self.__current_reader)
+            except StopIteration:
+                new_reader = True
+
+        if new_reader:
+            # Will trigger StopIteration if there are no more call groups
+            self.__current_reader = self.__sample_fn(
+                self.__current_call_id,
+                next(self.__nodes_call_groups),
+                *self.__sample_args,
+                **self.__sample_kwargs,
+            )
+
+            self.__current_call_id += 1
+            out = next(self.__current_reader)
+
+        return out
+
+    def __iter__(self) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]:
+        return self
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
new file mode 100644
index 00000000000..9062e54ef62
--- /dev/null
+++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
@@ -0,0 +1,287 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from math import ceil
+
+
+import cupy
+
+from cugraph.utilities.utils import MissingModule
+from cugraph.gnn.data_loading.dist_io.reader import DistSampleReader
+
+from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays
+
+from typing import Iterator, Tuple, Dict
+
+torch = MissingModule("torch")
+
+
+class DistSampleWriter:
+    def __init__(
+        self,
+        directory: str,
+        *,
+        batches_per_partition: int = 256,
+        format: str = "parquet",
+    ):
+        """
+        Parameters
+        ----------
+        directory: str (required)
+            The directory where samples will be written. This
+            writer can only write to disk.
+ batches_per_partition: int (optional, default=256) + The number of batches to write in a single file. + format: str (optional, default='parquet') + The file format of the output files containing the + sampled minibatches. Currently, only parquet format + is supported. + """ + if format != "parquet": + raise ValueError("Invalid format (currently supported: 'parquet')") + + self.__format = format + self.__directory = directory + self.__batches_per_partition = batches_per_partition + + @property + def _format(self): + return self.__format + + @property + def _directory(self): + return self.__directory + + @property + def _batches_per_partition(self): + return self.__batches_per_partition + + def get_reader( + self, rank: int + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: + """ + Returns an iterator over sampled data. + """ + + # currently only disk reading is supported + return DistSampleReader(self._directory, format=self._format, rank=rank) + + def __write_minibatches_coo(self, minibatch_dict): + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. 
+ if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] + ) + + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] + + start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] + majors_array_p = minibatch_dict["majors"][start_ix:end_ix] + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "majors": majors_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = 
minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + + def __write_minibatches_csr(self, minibatch_dict): + has_edge_ids = minibatch_dict["edge_id"] is not None + has_edge_types = minibatch_dict["edge_type"] is not None + has_weights = minibatch_dict["weight"] is not None + + if minibatch_dict["renumber_map"] is None: + raise ValueError( + "Distributed sampling without renumbering is not supported" + ) + + # Quit if there are no batches to write. + if len(minibatch_dict["batch_id"]) == 0: + return + + fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( + minibatch_dict["batch_id"] + ) + + for p in range( + 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) + ): + partition_start = p * (self.__batches_per_partition) + partition_end = (p + 1) * (self.__batches_per_partition) + + label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ + partition_start * fanout_length : partition_end * fanout_length + 1 + ] + + batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] + start_batch_id = batch_id_array_p[0] + + # major offsets and minors + ( + major_offsets_start_incl, + major_offsets_end_incl, + ) = label_hop_offsets_array_p[[0, -1]] + + start_ix, end_ix = minibatch_dict["major_offsets"][ + [major_offsets_start_incl, major_offsets_end_incl] + ] + + major_offsets_array_p = minibatch_dict["major_offsets"][ + major_offsets_start_incl : major_offsets_end_incl + 1 + ] + + minors_array_p = minibatch_dict["minors"][start_ix:end_ix] + edge_id_array_p = ( + minibatch_dict["edge_id"][start_ix:end_ix] + if has_edge_ids + else cupy.array([], dtype="int64") + ) + edge_type_array_p = ( + minibatch_dict["edge_type"][start_ix:end_ix] + 
if has_edge_types + else cupy.array([], dtype="int32") + ) + weight_array_p = ( + minibatch_dict["weight"][start_ix:end_ix] + if has_weights + else cupy.array([], dtype="float32") + ) + + # create the renumber map offsets + renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ + partition_start : partition_end + 1 + ] + + renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ + [0, -1] + ] + + renumber_map_array_p = minibatch_dict["renumber_map"][ + renumber_map_start_ix:renumber_map_end_ix + ] + + results_dataframe_p = create_df_from_disjoint_arrays( + { + "major_offsets": major_offsets_array_p, + "minors": minors_array_p, + "map": renumber_map_array_p, + "label_hop_offsets": label_hop_offsets_array_p, + "weight": weight_array_p, + "edge_id": edge_id_array_p, + "edge_type": edge_type_array_p, + "renumber_map_offsets": renumber_map_offsets_array_p, + } + ) + + end_batch_id = start_batch_id + len(batch_id_array_p) - 1 + rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 + + full_output_path = os.path.join( + self.__directory, + f"batch={rank:05d}.{start_batch_id:08d}-" + f"{rank:05d}.{end_batch_id:08d}.parquet", + ) + + results_dataframe_p.to_parquet( + full_output_path, + compression=None, + index=False, + force_nullable_schema=True, + ) + + def write_minibatches(self, minibatch_dict): + if (minibatch_dict["majors"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_coo(minibatch_dict) + elif (minibatch_dict["major_offsets"] is not None) and ( + minibatch_dict["minors"] is not None + ): + self.__write_minibatches_csr(minibatch_dict) + else: + raise ValueError("invalid columns") diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index 52ffd8fadfd..57a8f1bedfb 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -11,8 +11,6 
@@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re import warnings from math import ceil from functools import reduce @@ -27,348 +25,19 @@ from cugraph.utilities.utils import import_optional, MissingModule from cugraph.gnn.comms import cugraph_comms_get_raft_handle -from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays + +from cugraph.gnn.data_loading.dist_io import BufferedSampleReader +from cugraph.gnn.data_loading.dist_io import DistSampleWriter torch = MissingModule("torch") TensorType = Union["torch.Tensor", cupy.ndarray, cudf.Series] -class DistSampleReader: - def __init__( - self, - directory: str, - *, - format: str = "parquet", - rank: Optional[int] = None, - filelist=None, - ): - torch = import_optional("torch") - - self.__format = format - self.__directory = directory - - if format != "parquet": - raise ValueError("Invalid format (currently supported: 'parquet')") - - if filelist is None: - files = os.listdir(directory) - ex = re.compile(r"batch\=([0-9]+)\.([0-9]+)\-([0-9]+)\.([0-9]+)\.parquet") - filematch = [ex.match(f) for f in files] - filematch = [f for f in filematch if f] - - if rank is not None: - filematch = [f for f in filematch if int(f[1]) == rank] - - batch_count = sum([int(f[4]) - int(f[2]) + 1 for f in filematch]) - filematch = sorted(filematch, key=lambda f: int(f[2]), reverse=True) - - self.__files = filematch - else: - self.__files = list(filelist) - - if rank is None: - self.__batch_count = batch_count - else: - batch_count = torch.tensor([batch_count], device="cuda") - torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN) - self.__batch_count = int(batch_count) - - def __iter__(self): - return self - - def __next__(self): - torch = import_optional("torch") - - if len(self.__files) > 0: - f = self.__files.pop() - fname = f[0] - start_inclusive = int(f[2]) - end_inclusive = int(f[4]) - - if (end_inclusive - 
start_inclusive + 1) > self.__batch_count: - end_inclusive = start_inclusive + self.__batch_count - 1 - self.__batch_count = 0 - else: - self.__batch_count -= end_inclusive - start_inclusive + 1 - - df = cudf.read_parquet(os.path.join(self.__directory, fname)) - tensors = {} - for col in list(df.columns): - s = df[col].dropna() - if len(s) > 0: - tensors[col] = torch.as_tensor(s, device="cuda") - df.drop(col, axis=1, inplace=True) - - return tensors, start_inclusive, end_inclusive - - raise StopIteration - - -class DistSampleWriter: - def __init__( - self, - directory: str, - *, - batches_per_partition: int = 256, - format: str = "parquet", - ): - """ - Parameters - ---------- - directory: str (required) - The directory where samples will be written. This - writer can only write to disk. - batches_per_partition: int (optional, default=256) - The number of batches to write in a single file. - format: str (optional, default='parquet') - The file format of the output files containing the - sampled minibatches. Currently, only parquet format - is supported. - """ - if format != "parquet": - raise ValueError("Invalid format (currently supported: 'parquet')") - - self.__format = format - self.__directory = directory - self.__batches_per_partition = batches_per_partition - - @property - def _format(self): - return self.__format - - @property - def _directory(self): - return self.__directory - - @property - def _batches_per_partition(self): - return self.__batches_per_partition - - def get_reader( - self, rank: int - ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: - """ - Returns an iterator over sampled data. 
- """ - - # currently only disk reading is supported - return DistSampleReader(self._directory, format=self._format, rank=rank) - - def __write_minibatches_coo(self, minibatch_dict): - has_edge_ids = minibatch_dict["edge_id"] is not None - has_edge_types = minibatch_dict["edge_type"] is not None - has_weights = minibatch_dict["weight"] is not None - - if minibatch_dict["renumber_map"] is None: - raise ValueError( - "Distributed sampling without renumbering is not supported" - ) - - # Quit if there are no batches to write. - if len(minibatch_dict["batch_id"]) == 0: - return - - fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( - minibatch_dict["batch_id"] - ) - rank_batch_offset = minibatch_dict["batch_id"][0] - - for p in range( - 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) - ): - partition_start = p * (self.__batches_per_partition) - partition_end = (p + 1) * (self.__batches_per_partition) - - label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ - partition_start * fanout_length : partition_end * fanout_length + 1 - ] - - batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] - start_batch_id = batch_id_array_p[0] - rank_batch_offset - - start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] - majors_array_p = minibatch_dict["majors"][start_ix:end_ix] - minors_array_p = minibatch_dict["minors"][start_ix:end_ix] - edge_id_array_p = ( - minibatch_dict["edge_id"][start_ix:end_ix] - if has_edge_ids - else cupy.array([], dtype="int64") - ) - edge_type_array_p = ( - minibatch_dict["edge_type"][start_ix:end_ix] - if has_edge_types - else cupy.array([], dtype="int32") - ) - weight_array_p = ( - minibatch_dict["weight"][start_ix:end_ix] - if has_weights - else cupy.array([], dtype="float32") - ) - - # create the renumber map offsets - renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ - partition_start : partition_end + 1 - ] - - renumber_map_start_ix, 
renumber_map_end_ix = renumber_map_offsets_array_p[ - [0, -1] - ] - - renumber_map_array_p = minibatch_dict["renumber_map"][ - renumber_map_start_ix:renumber_map_end_ix - ] - - results_dataframe_p = create_df_from_disjoint_arrays( - { - "majors": majors_array_p, - "minors": minors_array_p, - "map": renumber_map_array_p, - "label_hop_offsets": label_hop_offsets_array_p, - "weight": weight_array_p, - "edge_id": edge_id_array_p, - "edge_type": edge_type_array_p, - "renumber_map_offsets": renumber_map_offsets_array_p, - } - ) - - end_batch_id = start_batch_id + len(batch_id_array_p) - 1 - rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 - - full_output_path = os.path.join( - self.__directory, - f"batch={rank:05d}.{start_batch_id:08d}-" - f"{rank:05d}.{end_batch_id:08d}.parquet", - ) - - results_dataframe_p.to_parquet( - full_output_path, - compression=None, - index=False, - force_nullable_schema=True, - ) - - def __write_minibatches_csr(self, minibatch_dict): - has_edge_ids = minibatch_dict["edge_id"] is not None - has_edge_types = minibatch_dict["edge_type"] is not None - has_weights = minibatch_dict["weight"] is not None - - if minibatch_dict["renumber_map"] is None: - raise ValueError( - "Distributed sampling without renumbering is not supported" - ) - - # Quit if there are no batches to write. 
- if len(minibatch_dict["batch_id"]) == 0: - return - - fanout_length = (len(minibatch_dict["label_hop_offsets"]) - 1) // len( - minibatch_dict["batch_id"] - ) - - for p in range( - 0, int(ceil(len(minibatch_dict["batch_id"]) / self.__batches_per_partition)) - ): - partition_start = p * (self.__batches_per_partition) - partition_end = (p + 1) * (self.__batches_per_partition) - - label_hop_offsets_array_p = minibatch_dict["label_hop_offsets"][ - partition_start * fanout_length : partition_end * fanout_length + 1 - ] - - batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end] - start_batch_id = batch_id_array_p[0] - - # major offsets and minors - ( - major_offsets_start_incl, - major_offsets_end_incl, - ) = label_hop_offsets_array_p[[0, -1]] - - start_ix, end_ix = minibatch_dict["major_offsets"][ - [major_offsets_start_incl, major_offsets_end_incl] - ] - - major_offsets_array_p = minibatch_dict["major_offsets"][ - major_offsets_start_incl : major_offsets_end_incl + 1 - ] - - minors_array_p = minibatch_dict["minors"][start_ix:end_ix] - edge_id_array_p = ( - minibatch_dict["edge_id"][start_ix:end_ix] - if has_edge_ids - else cupy.array([], dtype="int64") - ) - edge_type_array_p = ( - minibatch_dict["edge_type"][start_ix:end_ix] - if has_edge_types - else cupy.array([], dtype="int32") - ) - weight_array_p = ( - minibatch_dict["weight"][start_ix:end_ix] - if has_weights - else cupy.array([], dtype="float32") - ) - - # create the renumber map offsets - renumber_map_offsets_array_p = minibatch_dict["renumber_map_offsets"][ - partition_start : partition_end + 1 - ] - - renumber_map_start_ix, renumber_map_end_ix = renumber_map_offsets_array_p[ - [0, -1] - ] - - renumber_map_array_p = minibatch_dict["renumber_map"][ - renumber_map_start_ix:renumber_map_end_ix - ] - - results_dataframe_p = create_df_from_disjoint_arrays( - { - "major_offsets": major_offsets_array_p, - "minors": minors_array_p, - "map": renumber_map_array_p, - "label_hop_offsets": 
label_hop_offsets_array_p, - "weight": weight_array_p, - "edge_id": edge_id_array_p, - "edge_type": edge_type_array_p, - "renumber_map_offsets": renumber_map_offsets_array_p, - } - ) - - end_batch_id = start_batch_id + len(batch_id_array_p) - 1 - rank = minibatch_dict["rank"] if "rank" in minibatch_dict else 0 - - full_output_path = os.path.join( - self.__directory, - f"batch={rank:05d}.{start_batch_id:08d}-" - f"{rank:05d}.{end_batch_id:08d}.parquet", - ) - - results_dataframe_p.to_parquet( - full_output_path, - compression=None, - index=False, - force_nullable_schema=True, - ) - - def write_minibatches(self, minibatch_dict): - if (minibatch_dict["majors"] is not None) and ( - minibatch_dict["minors"] is not None - ): - self.__write_minibatches_coo(minibatch_dict) - elif (minibatch_dict["major_offsets"] is not None) and ( - minibatch_dict["minors"] is not None - ): - self.__write_minibatches_csr(minibatch_dict) - else: - raise ValueError("invalid columns") - - class DistSampler: def __init__( self, graph: Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph], - writer: DistSampleWriter, + writer: Optional[DistSampleWriter], local_seeds_per_call: int, retain_original_seeds: bool = False, ): @@ -379,7 +48,8 @@ def __init__( The pylibcugraph graph object that will be sampled. writer: DistSampleWriter (required) The writer responsible for writing samples to disk - or, in the future, device or host memory. + or; if None, then samples will be written to memory + instead. local_seeds_per_call: int The number of seeds on this rank this sampler will process in a single sampling call. Batches will @@ -402,14 +72,6 @@ def __init__( self.__handle = None self.__retain_original_seeds = retain_original_seeds - def get_reader(self) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: - """ - Returns an iterator over sampled data. 
- """ - torch = import_optional("torch") - rank = torch.distributed.get_rank() if self.is_multi_gpu else None - return self.__writer.get_reader(rank) - def sample_batches( self, seeds: TensorType, @@ -564,6 +226,54 @@ def get_start_batch_offset( else: return 0, input_size_is_equal + def __sample_from_nodes_func( + self, + call_id: int, + current_seeds: "torch.Tensor", + batch_id_start: int, + batch_size: int, + batches_per_call: int, + random_state: int, + assume_equal_input_size: bool, + ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]: + torch = import_optional("torch") + + current_batches = torch.arange( + batch_id_start + call_id * batches_per_call, + batch_id_start + + call_id * batches_per_call + + int(ceil(len(current_seeds))) + + 1, + device="cuda", + dtype=torch.int32, + ) + + current_batches = current_batches.repeat_interleave(batch_size)[ + : len(current_seeds) + ] + + minibatch_dict = self.sample_batches( + seeds=current_seeds, + batch_ids=current_batches, + random_state=random_state, + assume_equal_input_size=assume_equal_input_size, + ) + + if self.__writer is None: + # rename renumber_map -> map to match unbuffered format + minibatch_dict["map"] = minibatch_dict["renumber_map"] + del minibatch_dict["renumber_map"] + minibatch_dict = { + k: torch.as_tensor(v, device="cuda") + for k, v in minibatch_dict.items() + if v is not None + } + + return iter([(minibatch_dict, current_batches[0], current_batches[-1])]) + else: + self.__writer.write_minibatches(minibatch_dict) + return None + def sample_from_nodes( self, nodes: TensorType, @@ -571,7 +281,7 @@ def sample_from_nodes( batch_size: int = 16, random_state: int = 62, assume_equal_input_size: bool = False, - ): + ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]: """ Performs node-based sampling. Accepts a list of seed nodes, and batch size. 
Splits the seed list into batches, then divides the batches into call groups @@ -622,29 +332,31 @@ def sample_from_nodes( * (int(num_call_groups) - len(nodes_call_groups)) ) - # Make a call to sample_batches for each call group - for i, current_seeds in enumerate(nodes_call_groups): - current_batches = torch.arange( - batch_id_start + i * batches_per_call, - batch_id_start - + i * batches_per_call - + int(ceil(len(current_seeds))) - + 1, - device="cuda", - dtype=torch.int32, - ) - - current_batches = current_batches.repeat_interleave(batch_size)[ - : len(current_seeds) - ] + sample_args = ( + batch_id_start, + batch_size, + batches_per_call, + random_state, + input_size_is_equal, + ) - minibatch_dict = self.sample_batches( - seeds=current_seeds, - batch_ids=current_batches, - random_state=random_state, - assume_equal_input_size=input_size_is_equal, + if self.__writer is None: + # Buffered sampling + return BufferedSampleReader( + nodes_call_groups, self.__sample_from_nodes_func, *sample_args ) - self.__writer.write_minibatches(minibatch_dict) + else: + # Unbuffered sampling + for i, current_seeds in enumerate(nodes_call_groups): + self.__sample_from_nodes_func( + i, + current_seeds, + *sample_args, + ) + + # Return a reader that points to the stored samples + rank = torch.distributed.get_rank() if self.is_multi_gpu else None + return self.__writer.get_reader(rank) @property def is_multi_gpu(self): @@ -709,6 +421,8 @@ def __init__( # sampling. So setting the function here is safe. In the future, # if libcugraph allows setting a new attribute, this API might # change. + # TODO allow func to be a call to a future remote sampling API + # if the provided graph is in another process (rapidsai/cugraph#4623). 
self.__func = ( pylibcugraph.biased_neighbor_sample if biased diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py index 70b20e7baec..64db0232fb1 100644 --- a/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler.py @@ -20,6 +20,7 @@ from cugraph.datasets import karate from cugraph.gnn import UniformNeighborSampler, DistSampleWriter +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays from pylibcugraph import SGGraph, ResourceHandle, GraphProperties @@ -41,7 +42,7 @@ @pytest.fixture -def karate_graph(): +def karate_graph() -> SGGraph: el = karate.get_edgelist().reset_index().rename(columns={"index": "eid"}) G = SGGraph( ResourceHandle(), @@ -101,3 +102,60 @@ def test_dist_sampler_simple( assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] shutil.rmtree(samples_path) + + +@pytest.mark.sg +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("seeds_per_call", [4, 5, 10]) +@pytest.mark.parametrize("compression", ["COO", "CSR"]) +def test_dist_sampler_buffered_in_memory( + scratch_dir: str, karate_graph: SGGraph, seeds_per_call: int, compression: str +): + G = karate_graph + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory") + create_directory_with_overwrite(samples_path) + + seeds = cupy.arange(10, dtype="int64") + + unbuffered_sampler = UniformNeighborSampler( + G, + writer=DistSampleWriter(samples_path), + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + buffered_sampler = UniformNeighborSampler( + G, + writer=None, + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + unbuffered_results = unbuffered_sampler.sample_from_nodes( + seeds, + batch_size=4, + ) + + unbuffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in 
unbuffered_results + ] + + buffered_results = buffered_sampler.sample_from_nodes(seeds, batch_size=4) + buffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in buffered_results + ] + + assert len(buffered_results) == len(unbuffered_results) + + for k in range(len(buffered_results)): + br, bs, be = buffered_results[k] + ur, us, ue = unbuffered_results[k] + + assert bs == us + assert be == ue + + for col in ur.columns: + assert (br[col].dropna() == ur[col].dropna()).all() + + shutil.rmtree(samples_path) diff --git a/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py index a1c32938994..5bb541d6cf3 100644 --- a/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py +++ b/python/cugraph/cugraph/tests/sampling/test_dist_sampler_mg.py @@ -18,6 +18,8 @@ import cupy import cudf +from typing import Any + from cugraph.datasets import karate from cugraph.gnn import ( UniformNeighborSampler, @@ -27,6 +29,7 @@ cugraph_comms_init, cugraph_comms_shutdown, ) +from cugraph.gnn.data_loading.bulk_sampler_io import create_df_from_disjoint_arrays from pylibcugraph import MGGraph, ResourceHandle, GraphProperties from cugraph.utilities.utils import ( @@ -235,3 +238,80 @@ def test_dist_sampler_uneven(scratch_dir, batch_size, fanout, seeds_per_call): assert original_el.dst.iloc[edge_id.iloc[i]] == dst.iloc[i] shutil.rmtree(samples_path) + + +def run_test_dist_sampler_buffered_in_memory( + rank: int, + world_size: int, + uid: Any, + samples_path: str, + seeds_per_call: int, + compression: str, +): + init_pytorch(rank, world_size) + cugraph_comms_init(rank, world_size, uid, device=rank) + + G = karate_mg_graph(rank, world_size) + + num_seeds = 8 + seeds = cupy.random.randint(0, 34, num_seeds, dtype="int64") + + unbuffered_sampler = UniformNeighborSampler( + G, + writer=DistSampleWriter(samples_path), + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + 
buffered_sampler = UniformNeighborSampler( + G, + writer=None, + local_seeds_per_call=seeds_per_call, + compression=compression, + ) + + unbuffered_results = unbuffered_sampler.sample_from_nodes( + seeds, + batch_size=4, + ) + + unbuffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in unbuffered_results + ] + + buffered_results = buffered_sampler.sample_from_nodes(seeds, batch_size=4) + buffered_results = [ + (create_df_from_disjoint_arrays(r[0]), r[1], r[2]) for r in buffered_results + ] + + assert len(buffered_results) == len(unbuffered_results) + + for k in range(len(buffered_results)): + br, bs, be = buffered_results[k] + ur, us, ue = unbuffered_results[k] + + assert bs == us + assert be == ue + + for col in ur.columns: + assert (br[col].dropna() == ur[col].dropna()).all() + + +@pytest.mark.mg +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +@pytest.mark.parametrize("seeds_per_call", [4, 5, 10]) +@pytest.mark.parametrize("compression", ["COO", "CSR"]) +def test_dist_sampler_buffered_in_memory(scratch_dir, seeds_per_call, compression): + uid = cugraph_comms_create_unique_id() + + samples_path = os.path.join(scratch_dir, "test_bulk_sampler_buffered_in_memory_mg") + create_directory_with_overwrite(samples_path) + + world_size = torch.cuda.device_count() + torch.multiprocessing.spawn( + run_test_dist_sampler_buffered_in_memory, + args=(world_size, uid, samples_path, seeds_per_call, compression), + nprocs=world_size, + ) + + shutil.rmtree(samples_path)