Skip to content

Commit

Permalink
initial write
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 30, 2024
1 parent 6553464 commit 5f21d2b
Show file tree
Hide file tree
Showing 5 changed files with 560 additions and 32 deletions.
169 changes: 169 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import cugraph_pyg
from typing import Union, Tuple, Callable, Optional

from cugraph.utilities.utils import import_optional

torch_geometric = import_optional("torch_geometric")
torch = import_optional("torch")


class LinkLoader:
    """
    Duck-typed version of torch_geometric.loader.LinkLoader.
    Loads samples from batches of input edges using a
    `~cugraph_pyg.sampler.BaseSampler.sample_from_edges`
    function.
    """

    def __init__(
        self,
        data: Union[
            "torch_geometric.data.Data",
            "torch_geometric.data.HeteroData",
            Tuple[
                "torch_geometric.data.FeatureStore", "torch_geometric.data.GraphStore"
            ],
        ],
        link_sampler: "cugraph_pyg.sampler.BaseSampler",
        edge_label_index: "torch_geometric.typing.InputEdges" = None,
        edge_label: "torch_geometric.typing.OptTensor" = None,
        edge_label_time: "torch_geometric.typing.OptTensor" = None,
        neg_sampling: Optional["torch_geometric.sampler.NegativeSampling"] = None,
        neg_sampling_ratio: Optional[Union[int, float]] = None,
        transform: Optional[Callable] = None,
        transform_sampler_output: Optional[Callable] = None,
        filter_per_worker: Optional[bool] = None,
        custom_cls: Optional["torch_geometric.data.HeteroData"] = None,
        input_id: "torch_geometric.typing.OptTensor" = None,
        batch_size: int = 1,  # refers to number of edges in batch
        shuffle: bool = False,
        drop_last: bool = False,
        **kwargs,
    ):
        """
        Parameters
        ----------
        data: Data, HeteroData, or Tuple[FeatureStore, GraphStore]
            See torch_geometric.loader.NodeLoader.  Currently only a
            (feature store, graph store) pair whose second element is a
            cugraph_pyg.data.GraphStore is accepted.
        link_sampler: BaseSampler
            See torch_geometric.loader.LinkLoader.  Must be a
            cugraph_pyg.sampler.BaseSampler.
        edge_label_index: InputEdges
            See torch_geometric.loader.LinkLoader.
        edge_label: OptTensor
            See torch_geometric.loader.LinkLoader.
        edge_label_time: OptTensor
            Must be None; temporal sampling is currently unsupported.
        neg_sampling: Optional[NegativeSampling]
            This argument currently has no effect in this loader.
        neg_sampling_ratio: Optional[Union[int, float]]
            This argument currently has no effect in this loader.
        transform: Callable (optional, default=None)
            This argument currently has no effect.
        transform_sampler_output: Callable (optional, default=None)
            This argument currently has no effect.
        filter_per_worker: bool (optional, default=None)
            This argument currently has no effect.
        custom_cls: HeteroData
            This argument currently has no effect.  This loader will
            always return a Data or HeteroData object.
        input_id: OptTensor
            See torch_geometric.loader.LinkLoader.  If None, the input
            edges are numbered 0..num_edges-1.
        batch_size: int (optional, default=1)
            Number of input edges per batch.  In this class it is only
            used to decide how many trailing edges drop_last removes.
        shuffle: bool (optional, default=False)
            Whether to visit the input edges in random order.
        drop_last: bool (optional, default=False)
            Whether to drop the trailing edges that do not fill a
            complete batch.

        Raises
        ------
        NotImplementedError
            If data is not a cuGraph-backed store pair or link_sampler
            is not a cuGraph sampler.
        ValueError
            If edge_label_time is provided.
        """
        if not isinstance(data, (list, tuple)) or not isinstance(
            data[1], cugraph_pyg.data.GraphStore
        ):
            # Will eventually automatically convert these objects to cuGraph objects.
            raise NotImplementedError("Currently can't accept non-cugraph graphs")

        if not isinstance(link_sampler, cugraph_pyg.sampler.BaseSampler):
            raise NotImplementedError("Must provide a cuGraph sampler")

        if edge_label_time is not None:
            raise ValueError("Temporal sampling is currently unsupported")

        if filter_per_worker:
            warnings.warn("filter_per_worker is currently ignored")

        if custom_cls is not None:
            warnings.warn("custom_cls is currently ignored")

        if transform is not None:
            warnings.warn("transform is currently ignored.")

        if transform_sampler_output is not None:
            warnings.warn("transform_sampler_output is currently ignored.")

        # Neither argument is forwarded anywhere below, so tell the caller
        # instead of silently returning positive-only samples.
        if neg_sampling is not None or neg_sampling_ratio is not None:
            warnings.warn("neg_sampling and neg_sampling_ratio are currently ignored")

        (
            input_type,
            edge_label_index,
        ) = torch_geometric.loader.utils.get_edge_label_index(
            data,
            edge_label_index,
        )

        if input_id is None:
            # Default ids simply enumerate the input edges.
            input_id = torch.arange(
                edge_label_index.shape[-1], dtype=torch.int64, device="cuda"
            )

        self.__input_data = torch_geometric.sampler.EdgeSamplerInput(
            input_id=input_id,
            row=edge_label_index[0],
            col=edge_label_index[1],
            label=edge_label,
            time=edge_label_time,
            input_type=input_type,
        )

        self.__data = data

        self.__link_sampler = link_sampler

        self.__batch_size = batch_size
        self.__shuffle = shuffle
        self.__drop_last = drop_last

    @staticmethod
    def _batch_permutation(
        num_inputs: int, batch_size: int, shuffle: bool, drop_last: bool
    ) -> "torch.Tensor":
        """
        Returns the order in which the input edges are visited.

        The result is a 1-D index tensor over range(num_inputs), randomly
        permuted when shuffle is set.  When drop_last is set, the trailing
        `num_inputs % batch_size` entries are removed so only complete
        batches remain.
        """
        perm = torch.randperm(num_inputs) if shuffle else torch.arange(num_inputs)

        if drop_last:
            remainder = num_inputs % batch_size
            # Guard against remainder == 0: perm[:-0] is perm[:0], which
            # would discard every edge instead of none.
            if remainder > 0:
                perm = perm[:-remainder]

        return perm

    def __iter__(self):
        """
        Builds the (optionally shuffled and truncated) sampler input and
        returns a SampleIterator over the output of the link sampler's
        sample_from_edges call.
        """
        source = self.__input_data

        perm = self._batch_permutation(
            source.row.numel(),
            self.__batch_size,
            self.__shuffle,
            self.__drop_last,
        )

        # Use the same EdgeSamplerInput class as __init__
        # (torch_geometric.sampler) so both code paths agree.
        input_data = torch_geometric.sampler.EdgeSamplerInput(
            input_id=source.input_id[perm],
            row=source.row[perm],
            col=source.col[perm],
            label=None if source.label is None else source.label[perm],
            time=None if source.time is None else source.time[perm],
            input_type=source.input_type,
        )

        return cugraph_pyg.sampler.SampleIterator(
            self.__data, self.__link_sampler.sample_from_edges(input_data)
        )
10 changes: 5 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ def __init__(
input_id,
)

self.__input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
input_id=input_id,
self.__input_data = torch_geometric.sampler.NodeSamplerInput(
input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda")
if input_id is None
else input_id,
node=input_nodes,
time=None,
input_type=input_type,
Expand All @@ -136,9 +138,7 @@ def __iter__(self):
perm = perm[:-d]

input_data = torch_geometric.loader.node_loader.NodeSamplerInput(
input_id=None
if self.__input_data.input_id is None
else self.__input_data.input_id[perm],
input_id=self.__input_data.input_id[perm],
node=self.__input_data.node[perm],
time=None
if self.__input_data.time is None
Expand Down
83 changes: 79 additions & 4 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,20 @@ def __next__(self):
data.num_sampled_nodes = next_sample.num_sampled_nodes
data.num_sampled_edges = next_sample.num_sampled_edges

data.input_id = data.batch
data.seed_time = None
data.input_id = next_sample.metadata[0]
data.batch_size = data.input_id.size(0)

if len(next_sample.metadata) == 2:
data.seed_time = next_sample.metadata[1]
elif len(next_sample.metadata) == 4:
(
data.edge_label_index,
data.edge_label,
data.seed_time,
) = next_sample.metadata[1:]
else:
raise ValueError("Invalid metadata")

elif isinstance(next_sample, torch_geometric.sampler.HeteroSamplerOutput):
col = {}
for edge_type, col_idx in next_sample.col:
Expand Down Expand Up @@ -175,6 +185,9 @@ def __next__(self):
self.__base_reader
)

self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[
"input_offsets"
][0].clone()
self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[
"label_hop_offsets"
][0].clone()
Expand Down Expand Up @@ -266,6 +279,37 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
[num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
)

input_index = raw_sample_data["input_index"][
raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
index + 1
]
]

edge_inverse = (
(
raw_sample_data["edge_inverse"][
(raw_sample_data["input_offsets"][index] * 2) : (
raw_sample_data["input_offsets"][index + 1] * 2
)
]
)
if "edge_inverse" in raw_sample_data
else None
)

if edge_inverse is None:
metadata = (
input_index,
None, # TODO this will eventually include time
)
else:
metadata = (
input_index,
edge_inverse.view(2, -1),
None,
None, # TODO this will eventually include time
)

return torch_geometric.sampler.SamplerOutput(
node=renumber_map.cpu(),
row=minors,
Expand All @@ -274,6 +318,7 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
batch=renumber_map[:num_seeds],
num_sampled_nodes=num_sampled_nodes.cpu(),
num_sampled_edges=num_sampled_edges.cpu(),
metadata=metadata,
)

def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
Expand Down Expand Up @@ -319,6 +364,12 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
[num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
)

input_index = raw_sample_data["input_index"][
raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
index + 1
]
]

return torch_geometric.sampler.SamplerOutput(
node=renumber_map.cpu(),
row=minors,
Expand All @@ -327,6 +378,10 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
batch=renumber_map[:num_seeds],
num_sampled_nodes=num_sampled_nodes,
num_sampled_edges=num_sampled_edges,
metadata=(
input_index,
None, # TODO this will eventually include time
),
)

def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
Expand Down Expand Up @@ -358,7 +413,7 @@ def sample_from_nodes(
]
]:
reader = self.__sampler.sample_from_nodes(
index.node, batch_size=self.__batch_size, **kwargs
index.node, batch_size=self.__batch_size, input_id=index.input_id, **kwargs
)

edge_attrs = self.__graph_store.get_all_edge_attrs()
Expand All @@ -385,4 +440,24 @@ def sample_from_edges(
"torch_geometric.sampler.SamplerOutput",
]
]:
raise NotImplementedError("Edge sampling is currently unimplemented.")
if neg_sampling:
raise NotImplementedError("negative sampling is currently unsupported")

reader = self.__sampler.sample_from_edges(
torch.stack([index.row, index.col]), # reverse of usual convention
input_id=index.input_id,
**kwargs,
)

edge_attrs = self.__graph_store.get_all_edge_attrs()
if (
len(edge_attrs) == 1
and edge_attrs[0].edge_type[0] == edge_attrs[0].edge_type[2]
):
return HomogeneousSampleReader(reader)
else:
# TODO implement heterogeneous sampling
raise NotImplementedError(
"Sampling heterogeneous graphs is currently"
" unsupported in the non-dask API"
)
22 changes: 22 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def __write_minibatches_coo(self, minibatch_dict):
batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
start_batch_id = batch_id_array_p[0]

# Offsets delimiting each batch's inputs within this partition; one extra
# entry is taken so the end of the last batch is included.
input_offsets_p = minibatch_dict["input_offsets"][
    partition_start : (partition_end + 1)
]
# Slice the "input_index" array itself; indexing the dict with a slice
# raises instead of returning this partition's input indices.
input_index_p = minibatch_dict["input_index"][
    input_offsets_p[0] : input_offsets_p[-1]
]

start_ix, end_ix = label_hop_offsets_array_p[[0, -1]]
majors_array_p = minibatch_dict["majors"][start_ix:end_ix]
minors_array_p = minibatch_dict["minors"][start_ix:end_ix]
Expand Down Expand Up @@ -151,6 +156,8 @@ def __write_minibatches_coo(self, minibatch_dict):
"edge_id": edge_id_array_p,
"edge_type": edge_type_array_p,
"renumber_map_offsets": renumber_map_offsets_array_p,
"input_index": input_index_p,
"input_offsets": input_offsets_p,
}
)

Expand Down Expand Up @@ -201,6 +208,18 @@ def __write_minibatches_csr(self, minibatch_dict):
batch_id_array_p = minibatch_dict["batch_id"][partition_start:partition_end]
start_batch_id = batch_id_array_p[0]

# Offsets delimiting each batch's inputs within this partition; one extra
# entry is taken so the end of the last batch is included.
input_offsets_p = minibatch_dict["input_offsets"][
    partition_start : (partition_end + 1)
]
# Slice the "input_index" array itself; indexing the dict with a slice
# raises instead of returning this partition's input indices.
input_index_p = minibatch_dict["input_index"][
    input_offsets_p[0] : input_offsets_p[-1]
]
# "edge_inverse" is optional and is sliced over twice the input offset range.
edge_inverse_p = (
    minibatch_dict["edge_inverse"][
        (input_offsets_p[0] * 2) : (input_offsets_p[-1] * 2)
    ]
    if "edge_inverse" in minibatch_dict
    else None
)

# major offsets and minors
(
major_offsets_start_incl,
Expand Down Expand Up @@ -255,6 +274,9 @@ def __write_minibatches_csr(self, minibatch_dict):
"edge_id": edge_id_array_p,
"edge_type": edge_type_array_p,
"renumber_map_offsets": renumber_map_offsets_array_p,
"input_index": input_index_p,
"input_offsets": input_offsets_p,
"edge_inverse": edge_inverse_p,
}
)

Expand Down
Loading

0 comments on commit 5f21d2b

Please sign in to comment.