diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py index e7116470447..e9714bd0316 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -370,6 +370,31 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): ] ] + edge_inverse = ( + ( + raw_sample_data["edge_inverse"][ + (raw_sample_data["input_offsets"][index] * 2) : ( + raw_sample_data["input_offsets"][index + 1] * 2 + ) + ] + ) + if "edge_inverse" in raw_sample_data + else None + ) + + if edge_inverse is None: + metadata = ( + input_index, + None, # TODO this will eventually include time + ) + else: + metadata = ( + input_index, + edge_inverse.view(2, -1), + None, + None, # TODO this will eventually include time + ) + return torch_geometric.sampler.SamplerOutput( node=renumber_map.cpu(), row=minors, @@ -378,10 +403,7 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): batch=renumber_map[:num_seeds], num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, - metadata=( - input_index, - None, # TODO this will eventually include time - ), + metadata=metadata, ) def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py index 04d214fc846..73d9630be73 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py @@ -113,6 +113,13 @@ def __write_minibatches_coo(self, minibatch_dict): partition_start : (partition_end + 1) ] input_index_p = minibatch_dict[input_offsets_p[0] : input_offsets_p[-1]] + edge_inverse_p = ( + minibatch_dict["edge_inverse"][ + (input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2) + ] + if "edge_inverse" in minibatch_dict + else None + ) start_ix, end_ix = label_hop_offsets_array_p[[0, -1]] majors_array_p = minibatch_dict["majors"][start_ix:end_ix] @@ -158,6 +165,7 @@ def __write_minibatches_coo(self, minibatch_dict): "renumber_map_offsets": renumber_map_offsets_array_p, "input_index": input_index_p, "input_offsets": input_offsets_p, + "edge_inverse": edge_inverse_p, } ) diff --git a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py index 147f58151c0..4036b840a96 100644 --- a/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py +++ b/python/cugraph/cugraph/gnn/data_loading/dist_sampler.py @@ -464,17 +464,31 @@ def __sample_from_edges_func( # The returned unique values must be sorted or else the inverse won't line up # In the future this may be a good target for a C++ function # Each element is a tuple of (unique, index, inverse) - # TODO make sure this is compatible with negative sampling - u = [ - torch.unique( + # The seeds must be presorted with a stable sort prior to calling + # unique_consecutive in order to support negative sampling. This is + # because if we put positive edges after negative ones, then we may + # inadvertently turn a true positive into a false negative. + y = ( + torch.sort( t, - return_inverse=True, - sorted=True, + stable=True, ) for t in current_seeds + ) + z = ((v, torch.sort(i)[1]) for v, i in y) + + u = [ + ( + torch.unique_consecutive( + t, + return_inverse=True, + ), + i, + ) + for t, i in z ] - current_seeds = torch.concat([a[0] for a in u]) - current_inv = torch.concat([a[1] for a in u]) + current_seeds = torch.concat([a[0] for a, _ in u]) + current_inv = torch.concat([a[1][i] for a, i in u]) current_batches = torch.concat( [ torch.full( @@ -489,11 +503,14 @@ def __sample_from_edges_func( del u # Join with the leftovers - # TODO make sure this is compatible with negative sampling - leftover_seeds, leftover_inv = leftover_seeds.flatten().unique( - return_inverse=True, - sorted=True, + leftover_seeds, lyi = torch.sort( + leftover_seeds.flatten(), + stable=True, ) + lz = torch.sort(lyi)[1] + leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True) + leftover_inv = lui[lz] + current_seeds = torch.concat([current_seeds, leftover_seeds]) current_inv = torch.concat([current_inv, leftover_inv]) current_batches = torch.concat( @@ -507,6 +524,9 @@ def __sample_from_edges_func( ), ] ) + del leftover_seeds + del lz + del lui minibatch_dict = self.sample_batches( seeds=current_seeds,