Skip to content

Commit

Permalink
link pred
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 3, 2024
1 parent 5f21d2b commit 6266e45
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 15 deletions.
30 changes: 26 additions & 4 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,31 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
]
]

edge_inverse = (
(
raw_sample_data["edge_inverse"][
(raw_sample_data["input_offsets"][index] * 2) : (
raw_sample_data["input_offsets"][index + 1] * 2
)
]
)
if "edge_inverse" in raw_sample_data
else None
)

if edge_inverse is None:
metadata = (
input_index,
None, # TODO this will eventually include time
)
else:
metadata = (
input_index,
edge_inverse.view(2, -1),
None,
None, # TODO this will eventually include time
)

return torch_geometric.sampler.SamplerOutput(
node=renumber_map.cpu(),
row=minors,
Expand All @@ -378,10 +403,7 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
batch=renumber_map[:num_seeds],
num_sampled_nodes=num_sampled_nodes,
num_sampled_edges=num_sampled_edges,
metadata=(
input_index,
None, # TODO this will eventually include time
),
metadata=metadata,
)

def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
Expand Down
8 changes: 8 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ def __write_minibatches_coo(self, minibatch_dict):
partition_start : (partition_end + 1)
]
# Select this partition's seed indices.  ``minibatch_dict`` is a mapping
# keyed by column name (see the "edge_inverse" and "majors" lookups
# nearby), so the column has to be fetched by key before the offset range
# is applied; slicing the dict itself raises instead of returning data.
# NOTE(review): the key is taken to be "input_index", the same name this
# slice is written out under further down -- confirm against the sampler
# output that populates ``minibatch_dict``.
input_index_p = minibatch_dict["input_index"][
    input_offsets_p[0] : input_offsets_p[-1]
]
edge_inverse_p = (
minibatch_dict["edge_inverse"][
(input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2)
]
if "edge_inverse" in minibatch_dict
else None
)

start_ix, end_ix = label_hop_offsets_array_p[[0, -1]]
majors_array_p = minibatch_dict["majors"][start_ix:end_ix]
Expand Down Expand Up @@ -158,6 +165,7 @@ def __write_minibatches_coo(self, minibatch_dict):
"renumber_map_offsets": renumber_map_offsets_array_p,
"input_index": input_index_p,
"input_offsets": input_offsets_p,
"edge_inverse": edge_inverse_p,
}
)

Expand Down
42 changes: 31 additions & 11 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,17 +464,31 @@ def __sample_from_edges_func(
# The returned unique values must be sorted or else the inverse won't line up
# In the future this may be a good target for a C++ function
# Each element is a tuple of (unique, index, inverse)
# TODO make sure this is compatible with negative sampling
u = [
torch.unique(
# The seeds must be presorted with a stable sort prior to calling
# unique_consecutive in order to support negative sampling. This is
# because if we put positive edges after negative ones, then we may
# inadvertently turn a true positive into a false negative.
y = (
torch.sort(
t,
return_inverse=True,
sorted=True,
stable=True,
)
for t in current_seeds
)
z = ((v, torch.sort(i)[1]) for v, i in y)

u = [
(
torch.unique_consecutive(
t,
return_inverse=True,
),
i,
)
for t, i in z
]
current_seeds = torch.concat([a[0] for a in u])
current_inv = torch.concat([a[1] for a in u])
current_seeds = torch.concat([a[0] for a, _ in u])
current_inv = torch.concat([a[1][i] for a, i in u])
current_batches = torch.concat(
[
torch.full(
Expand All @@ -489,11 +503,14 @@ def __sample_from_edges_func(
del u

# Join with the leftovers
# TODO make sure this is compatible with negative sampling
leftover_seeds, leftover_inv = leftover_seeds.flatten().unique(
return_inverse=True,
sorted=True,
leftover_seeds, lyi = torch.sort(
leftover_seeds.flatten(),
stable=True,
)
lz = torch.sort(lyi)[1]
leftover_seeds, lui = leftover_seeds.unique_consecutive(return_inverse=True)
leftover_inv = lui[lz]

current_seeds = torch.concat([current_seeds, leftover_seeds])
current_inv = torch.concat([current_inv, leftover_inv])
current_batches = torch.concat(
Expand All @@ -507,6 +524,9 @@ def __sample_from_edges_func(
),
]
)
del leftover_seeds
del lz
del lui

minibatch_dict = self.sample_batches(
seeds=current_seeds,
Expand Down

0 comments on commit 6266e45

Please sign in to comment.