Skip to content

Commit

Permalink
initial write in cugraph-pyg
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 20, 2024
1 parent 61fb2d3 commit ecf4230
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 4 deletions.
23 changes: 23 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,29 @@ def __init__(
input_type=input_type,
)

# Edge label check from torch_geometric.loader.LinkLoader
if (
neg_sampling is not None
and neg_sampling.is_binary()
and edge_label is not None
and edge_label.min() == 0
):
edge_label = edge_label + 1

if (
neg_sampling is not None
and neg_sampling.is_triplet()
and edge_label is not None
):
raise ValueError(
"'edge_label' needs to be undefined for "
"'triplet'-based negative sampling. Please use "
"`src_index`, `dst_pos_index` and "
"`neg_pos_index` of the returned mini-batch "
"instead to differentiate between positive and "
"negative samples."
)

self.__data = data

self.__link_sampler = link_sampler
Expand Down
19 changes: 16 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cugraph.utilities.utils import import_optional
from cugraph.gnn import DistSampler

from .sampler_utils import filter_cugraph_pyg_store
from .sampler_utils import filter_cugraph_pyg_store, neg_sample

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")
Expand Down Expand Up @@ -467,11 +467,24 @@ def sample_from_edges(
"torch_geometric.sampler.SamplerOutput",
]
]:
src = index.row
dst = index.col
if neg_sampling:
raise NotImplementedError("negative sampling is currently unsupported")
# TODO handle temporal sampling (node_time)
src_neg, dst_neg = neg_sample(
self.__graph_store,
index.row,
index.col,
neg_sampling,
None, # src_time,
None, # src_node_time,
)
if neg_sampling.is_binary():
src = torch.cat([src, src_neg], dim=0)
dst = torch.cat([dst, dst_neg], dim=0)

reader = self.__sampler.sample_from_edges(
torch.stack([index.row, index.col]), # reverse of usual convention
torch.stack([src, dst]), # reverse of usual convention
input_id=index.input_id,
batch_size=self.__batch_size,
**kwargs,
Expand Down
46 changes: 45 additions & 1 deletion python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@

from typing import Sequence, Dict, Tuple

from cugraph_pyg.data import DaskGraphStore
from math import ceil

from cugraph_pyg.data import GraphStore, DaskGraphStore

from cugraph.utilities.utils import import_optional
import cudf
import cupy
import pylibcugraph

dask_cudf = import_optional("dask_cudf")
torch_geometric = import_optional("torch_geometric")
Expand Down Expand Up @@ -429,3 +433,43 @@ def filter_cugraph_pyg_store(
data[attr.attr_name] = tensors[i]

return data


def neg_sample(
    graph_store: GraphStore,
    seed_src: "torch.Tensor",
    seed_dst: "torch.Tensor",
    neg_sampling: "torch_geometric.sampler.NegativeSampling",
    time: "torch.Tensor",
    node_time: "torch.Tensor",
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    """
    Draw negative (source, destination) vertex pairs for a batch of seed edges.

    Parameters
    ----------
    graph_store: GraphStore
        Store whose ``_resource_handle`` and ``_graph`` attributes are
        passed to pylibcugraph.
    seed_src: torch.Tensor
        Source vertices of the seed edges.  Only its element count is
        used, to decide how many negatives to draw.
    seed_dst: torch.Tensor
        Destination vertices of the seed edges.  Currently unused.
    neg_sampling: torch_geometric.sampler.NegativeSampling
        Negative sampling configuration.  Its ``amount``, ``src_weight``
        and ``dst_weight`` attributes are read.
    time: torch.Tensor
        Currently unused.
    node_time: torch.Tensor
        Must be None; temporal negative sampling is not implemented.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The ``"sources"`` and ``"destinations"`` entries of the
        pylibcugraph result, as tensors on the ``"cuda"`` device.

    Raises
    ------
    NotImplementedError
        If ``node_time`` is not None.
    """
    src_weight = neg_sampling.src_weight
    dst_weight = neg_sampling.dst_weight

    # ceil() rounds a fractional ``amount`` up to a whole number of samples.
    num_neg = int(ceil(neg_sampling.amount * seed_src.numel()))

    if node_time is not None:
        raise NotImplementedError(
            "Temporal negative sampling is currently unimplemented in cuGraph-PyG"
        )

    if src_weight is None and dst_weight is None:
        # Unweighted: no explicit vertex list is passed.
        vertices = None
    else:
        # Size the vertex list from whichever weight was provided.  Reading
        # src_weight unconditionally fails with AttributeError when only
        # dst_weight is set.
        # Presumably each weight has one entry per vertex id, starting at
        # 0 -- TODO confirm against the pylibcugraph contract.
        reference_weight = src_weight if src_weight is not None else dst_weight
        vertices = cupy.arange(reference_weight.numel(), dtype="int64")

    # NOTE(review): released pylibcugraph versions appear to expose this
    # routine as ``negative_sampling``; confirm the name for the targeted
    # version.
    result_dict = pylibcugraph.negative_sample(
        graph_store._resource_handle,
        graph_store._graph,
        num_neg,
        vertices=vertices,
        src_bias=None if src_weight is None else cupy.asarray(src_weight),
        dst_bias=None if dst_weight is None else cupy.asarray(dst_weight),
        remove_duplicates=False,
        remove_false_negatives=False,
        exact_number_of_samples=True,
        do_expensive_check=False,
    )

    return (
        torch.as_tensor(result_dict["sources"], device="cuda"),
        torch.as_tensor(result_dict["destinations"], device="cuda"),
    )

0 comments on commit ecf4230

Please sign in to comment.