From 5f21d2ba7fc7231626635060bcc5cc0759b738f0 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Fri, 30 Aug 2024 13:37:26 -0700
Subject: [PATCH] initial write

---
 .../cugraph_pyg/loader/link_loader.py         | 169 ++++++++++
 .../cugraph_pyg/loader/node_loader.py         |  10 +-
 .../cugraph_pyg/sampler/sampler.py            |  83 ++++-
 .../gnn/data_loading/dist_io/writer.py        |  22 ++
 .../cugraph/gnn/data_loading/dist_sampler.py  | 308 ++++++++++++++++--
 5 files changed, 560 insertions(+), 32 deletions(-)
 create mode 100644 python/cugraph-pyg/cugraph_pyg/loader/link_loader.py

diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
new file mode 100644
index 00000000000..21e92a817df
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+import cugraph_pyg
+from typing import Union, Tuple, Callable, Optional
+
+from cugraph.utilities.utils import import_optional
+
+torch_geometric = import_optional("torch_geometric")
+torch = import_optional("torch")
+
+
+class LinkLoader:
+    """
+    Duck-typed version of torch_geometric.loader.LinkLoader.
+    Loads samples from batches of input edges using a
+    `~cugraph_pyg.sampler.BaseSampler.sample_from_edges`
+    function.
+    """
+
+    def __init__(
+        self,
+        data: Union[
+            "torch_geometric.data.Data",
+            "torch_geometric.data.HeteroData",
+            Tuple[
+                "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore"
+            ],
+        ],
+        link_sampler: "cugraph_pyg.sampler.BaseSampler",
+        edge_label_index: "torch_geometric.typing.InputEdges" = None,
+        edge_label: "torch_geometric.typing.OptTensor" = None,
+        edge_label_time: "torch_geometric.typing.OptTensor" = None,
+        neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None,
+        neg_sampling_ratio: Optional[Union[int, float]] = None,
+        transform: Optional[Callable] = None,
+        transform_sampler_output: Optional[Callable] = None,
+        filter_per_worker: Optional[bool] = None,
+        custom_cls: Optional["torch_geometric.data.HeteroData"] = None,
+        input_id: "torch_geometric.typing.OptTensor" = None,
+        batch_size: int = 1,  # refers to number of edges in batch
+        shuffle: bool = False,
+        drop_last: bool = False,
+        **kwargs,
+    ):
+        """
+        Parameters
+        ----------
+        data: Data, HeteroData, or Tuple[FeatureStore, GraphStore]
+            See torch_geometric.loader.NodeLoader.
+        link_sampler: BaseSampler
+            See torch_geometric.loader.LinkLoader.
+        edge_label_index: InputEdges
+            See torch_geometric.loader.LinkLoader.
+        edge_label: OptTensor
+            See torch_geometric.loader.LinkLoader.
+        edge_label_time: OptTensor
+            See torch_geometric.loader.LinkLoader.
+        neg_sampling: Optional[NegativeSampling]
+            Type of negative sampling to perform, if desired.
+            See torch_geometric.loader.LinkLoader.
+        neg_sampling_ratio: Optional[Union[int, float]]
+            Negative sampling ratio. Affects how many negative
+            samples are generated.
+            See torch_geometric.loader.LinkLoader.
+        transform: Callable (optional, default=None)
+            This argument currently has no effect.
+        transform_sampler_output: Callable (optional, default=None)
+            This argument currently has no effect.
+        filter_per_worker: bool (optional, default=False)
+            This argument currently has no effect.
+        custom_cls: HeteroData
+            This argument currently has no effect. This loader will
+            always return a Data or HeteroData object.
+        input_id: OptTensor
+            See torch_geometric.loader.LinkLoader.
+
+        """
+        if not isinstance(data, (list, tuple)) or not isinstance(
+            data[1], cugraph_pyg.data.GraphStore
+        ):
+            # Will eventually automatically convert these objects to cuGraph objects.
+            raise NotImplementedError("Currently can't accept non-cugraph graphs")
+
+        if not isinstance(link_sampler, cugraph_pyg.sampler.BaseSampler):
+            raise NotImplementedError("Must provide a cuGraph sampler")
+
+        if edge_label_time is not None:
+            raise ValueError("Temporal sampling is currently unsupported")
+
+        if filter_per_worker:
+            warnings.warn("filter_per_worker is currently ignored")
+
+        if custom_cls is not None:
+            warnings.warn("custom_cls is currently ignored")
+
+        if transform is not None:
+            warnings.warn("transform is currently ignored.")
+
+        if transform_sampler_output is not None:
+            warnings.warn("transform_sampler_output is currently ignored.")
+
+        (
+            input_type,
+            edge_label_index,
+        ) = torch_geometric.loader.utils.get_edge_label_index(
+            data,
+            edge_label_index,
+        )
+
+        self.__input_data = torch_geometric.sampler.EdgeSamplerInput(
+            input_id=torch.arange(
+                edge_label_index.shape[-1], dtype=torch.int64, device="cuda"
+            )
+            if input_id is None
+            else input_id,
+            row=edge_label_index[0],
+            col=edge_label_index[1],
+            label=edge_label,
+            time=edge_label_time,
+            input_type=input_type,
+        )
+
+        self.__data = data
+
+        self.__link_sampler = link_sampler
+
+        self.__batch_size = batch_size
+        self.__shuffle = shuffle
+        self.__drop_last = drop_last
+
+    def __iter__(self):
+        if self.__shuffle:
+            perm = torch.randperm(self.__input_data.row.numel())
+        else:
+            perm = torch.arange(self.__input_data.row.numel())
+
+        if self.__drop_last:
+            d = perm.numel() % self.__batch_size
+            # Subtract from the slice end so d == 0 keeps every edge;
+            # perm[:-d] would return an empty tensor in that case.
+            perm = perm[: perm.numel() - d]
+
+        input_data = torch_geometric.sampler.EdgeSamplerInput(
+            input_id=self.__input_data.input_id[perm],
+            row=self.__input_data.row[perm],
+            col=self.__input_data.col[perm],
+            label=None
+            if self.__input_data.label is None
+            else self.__input_data.label[perm],
+            time=None
+            if self.__input_data.time is None
+            else self.__input_data.time[perm],
+            input_type=self.__input_data.input_type,
+        )
+
+        return cugraph_pyg.sampler.SampleIterator(
+            self.__data, self.__link_sampler.sample_from_edges(input_data)
+        )
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
index 49923783d6b..fe7d2eaeeef 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
@@ -110,8 +110,10 @@ def __init__(
             input_id,
         )
 
-        self.__input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
-            input_id=input_id,
+        self.__input_data = torch_geometric.sampler.NodeSamplerInput(
+            input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda")
+            if input_id is None
+            else input_id,
             node=input_nodes,
             time=None,
             input_type=input_type,
@@ -136,9 +138,7 @@ def __iter__(self):
             perm = perm[:-d]
 
         input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
-            input_id=None
-            if self.__input_data.input_id is None
-            else self.__input_data.input_id[perm],
+            input_id=self.__input_data.input_id[perm],
             node=self.__input_data.node[perm],
             time=None
             if self.__input_data.time is None
diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
index 36076ca412d..e7116470447 100644
--- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
+++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
@@ -101,10 +101,20 @@ def __next__(self):
             data.num_sampled_nodes = next_sample.num_sampled_nodes
             data.num_sampled_edges = next_sample.num_sampled_edges
 
-            data.input_id = data.batch
-            data.seed_time = None
+            data.input_id = next_sample.metadata[0]
             data.batch_size = data.input_id.size(0)
 
+            if len(next_sample.metadata) == 2:
+                data.seed_time = next_sample.metadata[1]
+            elif len(next_sample.metadata) == 4:
+                (
+                    data.edge_label_index,
+                    data.edge_label,
+                    data.seed_time,
+                ) = next_sample.metadata[1:]
+            else:
+                raise ValueError("Invalid metadata")
+
         elif isinstance(next_sample, torch_geometric.sampler.HeteroSamplerOutput):
             col = {}
             for edge_type, col_idx in next_sample.col:
@@ -175,6 +185,9 @@ def __next__(self):
                 self.__base_reader
             )
 
+            self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[
+                "input_offsets"
+            ][0].clone()
             self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[
                 "label_hop_offsets"
            ][0].clone()
@@ -266,6 +279,37 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
             [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
         )
 
+        input_index = raw_sample_data["input_index"][
+            raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
+                index + 1
+            ]
+        ]
+
+        edge_inverse = (
+            (
+                raw_sample_data["edge_inverse"][
+                    (raw_sample_data["input_offsets"][index] * 2) : (
+                        raw_sample_data["input_offsets"][index + 1] * 2
+                    )
+                ]
+            )
+            if "edge_inverse" in raw_sample_data
+            else None
+        )
+
+        if edge_inverse is None:
+            metadata = (
+                input_index,
+                None,  # TODO this will eventually include time
+            )
+        else:
+            metadata = (
+                input_index,
+                edge_inverse.view(2, -1),
+                None,
+                None,  # TODO this will eventually include time
+            )
+
         return torch_geometric.sampler.SamplerOutput(
             node=renumber_map.cpu(),
             row=minors,
@@ -274,6 +318,7 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
             batch=renumber_map[:num_seeds],
             num_sampled_nodes=num_sampled_nodes.cpu(),
             num_sampled_edges=num_sampled_edges.cpu(),
+            metadata=metadata,
         )
 
     def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
@@ -319,6 +364,12 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
             [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
         )
 
+        input_index = raw_sample_data["input_index"][
+            raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
+                index + 1
+            ]
+        ]
+
         return torch_geometric.sampler.SamplerOutput(
             node=renumber_map.cpu(),
             row=minors,
@@ -327,6 +378,10 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
             batch=renumber_map[:num_seeds],
             num_sampled_nodes=num_sampled_nodes,
             num_sampled_edges=num_sampled_edges,
+            metadata=(
+                input_index,
+                None,  # TODO this will eventually include time
+            ),
         )
 
     def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
@@ -358,7 +413,7 @@ def sample_from_nodes(
         ]
     ]:
         reader = self.__sampler.sample_from_nodes(
-            index.node, batch_size=self.__batch_size, **kwargs
+            index.node, batch_size=self.__batch_size, input_id=index.input_id, **kwargs
         )
 
         edge_attrs = self.__graph_store.get_all_edge_attrs()
@@ -385,4 +440,24 @@ def sample_from_edges(
             "torch_geometric.sampler.SamplerOutput",
         ]
     ]:
-        raise NotImplementedError("Edge sampling is currently unimplemented.")
+        if neg_sampling:
+            raise NotImplementedError("negative sampling is currently unsupported")
+
+        reader = self.__sampler.sample_from_edges(
+            torch.stack([index.row, index.col]),  # reverse of usual convention
+            input_id=index.input_id,
+            **kwargs,
+        )
+
+        edge_attrs = self.__graph_store.get_all_edge_attrs()
+        if (
+            len(edge_attrs) == 1
+            and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2]
+        ):
+            return HomogeneousSampleReader(reader)
+        else:
+            # TODO implement heterogeneous sampling
+            raise NotImplementedError(
+                "Sampling heterogeneous graphs is currently"
+                " unsupported in the non-dask API"
+            )
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
index 9062e54ef62..04d214fc846 100644
--- a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
+++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
@@ -109,6 +109,11 @@ def __write_minibatches_coo(self, minibatch_dict):
         batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
         start_batch_id = batch_id_array_p[0]
 
+        input_offsets_p = minibatch_dict["input_offsets"][
+            partition_start : (partition_end + 1)
+        ]
+        input_index_p = minibatch_dict["input_index"][input_offsets_p[0] : input_offsets_p[-1]]
+
         start_ix, end_ix = label_hop_offsets_array_p[[0, -1]]
         majors_array_p = minibatch_dict["majors"][start_ix:end_ix]
         minors_array_p = minibatch_dict["minors"][start_ix:end_ix]
@@ -151,6 +156,8 @@ def __write_minibatches_coo(self, minibatch_dict):
                 "edge_id": edge_id_array_p,
                 "edge_type": edge_type_array_p,
                 "renumber_map_offsets": renumber_map_offsets_array_p,
+                "input_index": input_index_p,
+                "input_offsets": input_offsets_p,
             }
         )
 
@@ -201,6 +208,18 @@ def __write_minibatches_csr(self, minibatch_dict):
         batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
         start_batch_id = batch_id_array_p[0]
 
+        input_offsets_p = minibatch_dict["input_offsets"][
+            partition_start : (partition_end + 1)
+        ]
+        input_index_p = minibatch_dict["input_index"][input_offsets_p[0] : input_offsets_p[-1]]
+        edge_inverse_p = (
+            minibatch_dict["edge_inverse"][
+                (input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2)
+            ]
+            if "edge_inverse" in minibatch_dict
+            else None
+        )
+
         # major offsets and minors
         (
             major_offsets_start_incl,
@@ -255,6 +274,9 @@ def __write_minibatches_csr(self, minibatch_dict):
                 "edge_id": edge_id_array_p,
                 "edge_type": edge_type_array_p,
                 "renumber_map_offsets": renumber_map_offsets_array_p,
+                "input_index": input_index_p,
+                "input_offsets": input_offsets_p,
+                "edge_inverse": edge_inverse_p,
             }
         )
 
diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
index 92ceb951d0c..147f58151c0 100644
--- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
+++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
@@ -229,7 +229,7 @@ def get_start_batch_offset(
     def __sample_from_nodes_func(
         self,
         call_id: int,
-        current_seeds: "torch.Tensor",
+        current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor"],
         batch_id_start: int,
         batch_size: int,
         batches_per_call: int,
@@ -238,6 +238,8 @@ def __sample_from_nodes_func(
     ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]:
         torch = import_optional("torch")
 
+        current_seeds, current_ix = current_seeds_and_ix
+
         current_batches = torch.arange(
             batch_id_start + call_id * batches_per_call,
             batch_id_start
             + (call_id + 1) * batches_per_call,
             device="cuda",
             dtype=torch.int32,
         )[
             : len(current_seeds)
         ]
 
+        # Use divmod to get the number of full batch_size batches and the
+        # size of the last batch.
+        num_full, last_count = divmod(len(current_seeds), batch_size)
+        input_offsets = torch.concatenate(
+            [
+                torch.tensor([0], device="cuda", dtype=torch.int64),
+                torch.full((num_full,), batch_size, device="cuda", dtype=torch.int64),
+                torch.tensor([last_count], device="cuda", dtype=torch.int64)
+                if last_count > 0
+                else torch.tensor([], device="cuda", dtype=torch.int64),
+            ]
+        ).cumsum(-1)
+
         minibatch_dict = self.sample_batches(
             seeds=current_seeds,
             batch_ids=current_batches,
             random_state=random_state,
             assume_equal_input_size=assume_equal_input_size,
         )
+        minibatch_dict["input_index"] = current_ix.cuda()
+        minibatch_dict["input_offsets"] = input_offsets
 
         if self.__writer is None:
             # rename renumber_map -> map to match unbuffered format
@@ -274,6 +291,41 @@ def __sample_from_nodes_func(
             self.__writer.write_minibatches(minibatch_dict)
             return None
 
+    def __get_call_groups(
+        self,
+        seeds: TensorType,
+        input_id: TensorType,
+        seeds_per_call: int,
+        assume_equal_input_size: bool = False,
+    ):
+        torch = import_optional("torch")
+
+        # Split the input seeds into call groups. Each call group
+        # corresponds to one sampling call. A call group contains
+        # many batches.
+        seeds_call_groups = torch.split(seeds, seeds_per_call, dim=-1)
+        index_call_groups = torch.split(input_id, seeds_per_call, dim=-1)
+
+        # Need to add empties to the list of call groups to handle the case
+        # where not all ranks have the same number of call groups. This
+        # prevents a hang since we need all ranks to make the same number
+        # of calls.
+        if not assume_equal_input_size:
+            num_call_groups = torch.tensor(
+                [len(seeds_call_groups)], device="cuda", dtype=torch.int32
+            )
+            torch.distributed.all_reduce(
+                num_call_groups, op=torch.distributed.ReduceOp.MAX
+            )
+            seeds_call_groups = list(seeds_call_groups) + (
+                [torch.tensor([], dtype=seeds.dtype, device="cuda")]
+                * (int(num_call_groups) - len(seeds_call_groups))
+            )
+            index_call_groups = list(index_call_groups) + (
+                [torch.tensor([], dtype=torch.int64, device=input_id.device)]
+                * (int(num_call_groups) - len(index_call_groups))
+            )
+
+        return seeds_call_groups, index_call_groups
+
     def sample_from_nodes(
         self,
         nodes: TensorType,
@@ -281,6 +333,7 @@ def sample_from_nodes(
         batch_size: int = 16,
         random_state: int = 62,
         assume_equal_input_size: bool = False,
+        input_id: Optional[TensorType] = None,
     ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]:
         """
         Performs node-based sampling. Accepts a list of seed nodes, and batch size.
@@ -297,40 +350,245 @@ def sample_from_nodes(
             The size of each batch.
         random_state: int
             The random seed to use for sampling.
+        assume_equal_input_size: bool
+            Whether the inputs across workers should be assumed to be equal in
+            dimension. Skips some checks if True.
+        input_id: Optional[TensorType]
+            Input ids corresponding to the original batch tensor, if it
+            was permuted prior to calling this function. If present,
+            will be saved with the samples.
         """
         torch = import_optional("torch")
 
         nodes = torch.as_tensor(nodes, device="cuda")
+        num_seeds = nodes.numel()
 
         batches_per_call = self._local_seeds_per_call // batch_size
         actual_seeds_per_call = batches_per_call * batch_size
 
-        # Split the input seeds into call groups. Each call group
-        # corresponds to one sampling call. A call group contains
-        # many batches.
-        num_seeds = len(nodes)
-        nodes_call_groups = torch.split(nodes, actual_seeds_per_call)
+        if input_id is None:
+            input_id = torch.arange(num_seeds, dtype=torch.int64, device="cpu")
 
         local_num_batches = int(ceil(num_seeds / batch_size))
         batch_id_start, input_size_is_equal = self.get_start_batch_offset(
             local_num_batches, assume_equal_input_size=assume_equal_input_size
         )
 
-        # Need to add empties to the list of call groups to handle the case
-        # where not all nodes have the same number of call groups. This
-        # prevents a hang since we need all ranks to make the same number
-        # of calls.
-        if not input_size_is_equal:
-            num_call_groups = torch.tensor(
-                [len(nodes_call_groups)], device="cuda", dtype=torch.int32
-            )
-            torch.distributed.all_reduce(
-                num_call_groups, op=torch.distributed.ReduceOp.MAX
+        nodes_call_groups, index_call_groups = self.__get_call_groups(
+            nodes,
+            input_id,
+            actual_seeds_per_call,
+            assume_equal_input_size=input_size_is_equal,
+        )
+
+        sample_args = (
+            batch_id_start,
+            batch_size,
+            batches_per_call,
+            random_state,
+            input_size_is_equal,
+        )
+
+        if self.__writer is None:
+            # Buffered sampling
+            return BufferedSampleReader(
+                zip(nodes_call_groups, index_call_groups),
+                self.__sample_from_nodes_func,
+                *sample_args,
             )
-            nodes_call_groups = list(nodes_call_groups) + (
-                [torch.tensor([], dtype=nodes.dtype, device="cuda")]
-                * (int(num_call_groups) - len(nodes_call_groups))
+        else:
+            # Unbuffered sampling
+            for i, current_seeds_and_ix in enumerate(
+                zip(nodes_call_groups, index_call_groups)
+            ):
+                self.__sample_from_nodes_func(
+                    i,
+                    current_seeds_and_ix,
+                    *sample_args,
+                )
+
+            # Return a reader that points to the stored samples
+            rank = torch.distributed.get_rank() if self.is_multi_gpu else None
+            return self.__writer.get_reader(rank)
+
+    def __sample_from_edges_func(
+        self,
+        call_id: int,
+        current_seeds_and_ix: Tuple["torch.Tensor", "torch.Tensor"],
+        batch_id_start: int,
+        batch_size: int,
+        batches_per_call: int,
+        random_state: int,
+        assume_equal_input_size: bool,
+    ) -> Union[None, Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]]:
+        torch = import_optional("torch")
+
+        current_seeds, current_ix = current_seeds_and_ix
+        num_seed_edges = current_ix.numel()
+
+        # The index gets stored as-is regardless of what makes it into
+        # the final batch and in what order.
+        # Use divmod to get the number of full batch_size batches and the
+        # size of the last batch.
+        num_whole_batches, last_count = divmod(num_seed_edges, batch_size)
+        input_offsets = torch.concatenate(
+            [
+                torch.tensor([0], device="cuda", dtype=torch.int64),
+                torch.full(
+                    (num_whole_batches,), batch_size, device="cuda", dtype=torch.int64
+                ),
+                torch.tensor([last_count], device="cuda", dtype=torch.int64)
+                if last_count > 0
+                else torch.tensor([], device="cuda", dtype=torch.int64),
+            ]
+        ).cumsum(-1)
+
+        current_seeds, leftover_seeds = (
+            current_seeds[:, : num_whole_batches * batch_size],
+            current_seeds[:, num_whole_batches * batch_size :],
+        )
+
+        # For input edges, we need to translate this into unique vertices
+        # for each batch.
+        # We start by reorganizing the seed and index tensors so we can
+        # determine the unique vertices. This results in the expected
+        # src-to-dst concatenation for each batch
+        current_seeds = torch.concat(
+            [
+                current_seeds[0].reshape((-1, batch_size)),
+                current_seeds[1].reshape((-1, batch_size)),
+            ],
+            axis=-1,
+        )
+
+        # The returned unique values must be sorted or else the inverse won't line up
+        # In the future this may be a good target for a C++ function
+        # Each element of u is a (unique, inverse) tuple from torch.unique
+        # TODO make sure this is compatible with negative sampling
+        u = [
+            torch.unique(
+                t,
+                return_inverse=True,
+                sorted=True,
+            )
+            for t in current_seeds
+        ]
+        current_seeds = torch.concat([a[0] for a in u])
+        current_inv = torch.concat([a[1] for a in u])
+        current_batches = torch.concat(
+            [
+                torch.full(
+                    (a[0].numel(),),
+                    i + batch_id_start + (call_id * batches_per_call),
+                    device="cuda",
+                    dtype=torch.int32,
+                )
+                for i, a in enumerate(u)
+            ]
+        )
+        del u
+
+        # Join with the leftovers
+        # TODO make sure this is compatible with negative sampling
+        leftover_seeds, leftover_inv = leftover_seeds.flatten().unique(
+            return_inverse=True,
+            sorted=True,
+        )
+        current_seeds = torch.concat([current_seeds, leftover_seeds])
+        current_inv = torch.concat([current_inv, leftover_inv])
+        current_batches = torch.concat(
+            [
+                current_batches,
+                torch.full(
+                    (leftover_seeds.numel(),),
+                    current_batches[-1] + 1,
+                    device="cuda",
+                    dtype=torch.int32,
+                ),
+            ]
+        )
+
+        minibatch_dict = self.sample_batches(
+            seeds=current_seeds,
+            batch_ids=current_batches,
+            random_state=random_state,
+            assume_equal_input_size=assume_equal_input_size,
+        )
+        minibatch_dict["input_index"] = current_ix.cuda()
+        minibatch_dict["input_offsets"] = input_offsets
+        minibatch_dict[
+            "edge_inverse"
+        ] = current_inv  # (2 * batch_size) entries per batch
+
+        if self.__writer is None:
+            # rename renumber_map -> map to match unbuffered format
+            minibatch_dict["map"] = minibatch_dict["renumber_map"]
+            del minibatch_dict["renumber_map"]
+            minibatch_dict = {
+                k: torch.as_tensor(v, device="cuda")
+                for k, v in minibatch_dict.items()
+                if v is not None
+            }
+
+            return iter([(minibatch_dict, current_batches[0], current_batches[-1])])
+        else:
+            self.__writer.write_minibatches(minibatch_dict)
+            return None
+
+    def sample_from_edges(
+        self,
+        edges: TensorType,
+        *,
+        batch_size: int = 16,
+        random_state: int = 62,
+        assume_equal_input_size: bool = False,
+        input_id: Optional[TensorType] = None,
+    ) -> Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]:
+        """
+        Performs sampling starting from seed edges.
+
+        Parameters
+        ----------
+        edges: TensorType
+            2 x (# edges) tensor of edges to sample from.
+            Standard src/dst format. This will be converted
+            to a list of seed nodes.
+        batch_size: int
+            The size of each batch.
+        random_state: int
+            The random seed to use for sampling.
+        assume_equal_input_size: bool
+            Whether this function should assume that inputs
+            are equal across ranks. Skips some potentially
+            slow steps if True.
+        input_id: Optional[TensorType]
+            Input ids corresponding to the original batch tensor, if it
+            was permuted prior to calling this function. If present,
+            will be saved with the samples.
+ """ + + torch = import_optional("torch") + + edges = torch.as_tensor(edges, device="cuda") + num_seed_edges = edges.shape[-1] + + batches_per_call = self._local_seeds_per_call // batch_size + actual_seed_edges_per_call = batches_per_call * batch_size + + if input_id is None: + input_id = torch.arange(len(edges), dtype=torch.int64, device="cpu") + + local_num_batches = int(ceil(num_seed_edges / batch_size)) + batch_id_start, input_size_is_equal = self.get_start_batch_offset( + local_num_batches, assume_equal_input_size=assume_equal_input_size + ) + + edges_call_groups, index_call_groups = self.__get_call_groups( + edges, + input_id, + actual_seed_edges_per_call, + assume_equal_input_size=input_size_is_equal, + ) sample_args = ( batch_id_start, @@ -343,14 +601,18 @@ def sample_from_nodes( if self.__writer is None: # Buffered sampling return BufferedSampleReader( - nodes_call_groups, self.__sample_from_nodes_func, *sample_args + zip(edges_call_groups, index_call_groups), + self.__sample_from_edges_func, + *sample_args, ) else: # Unbuffered sampling - for i, current_seeds in enumerate(nodes_call_groups): - self.__sample_from_nodes_func( + for i, current_seeds_and_ix in enumerate( + zip(edges_call_groups, index_call_groups) + ): + self.__sample_from_edges_func( i, - current_seeds, + current_seeds_and_ix, *sample_args, )