Skip to content

Commit

Permalink
heterogeneous sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Nov 18, 2024
1 parent d260ccb commit 35af4b4
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 11 deletions.
14 changes: 11 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,19 @@ def __init__(
# Will eventually automatically convert these objects to cuGraph objects.
raise NotImplementedError("Currently can't accept non-cugraph graphs")


feature_store, graph_store = data

if compression is None:
compression = "CSR"
compression = "CSR" if graph_store.is_homogeneous else 'COO'
elif compression not in ["CSR", "COO"]:
raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')")

if (not graph_store.is_homogeneous):
if compression != 'COO':
raise ValueError("Only COO format is supported for heterogeneous graphs!")
if directory is not None:
raise ValueError("Writing to disk is not supported for heterogeneous graphs!")

writer = (
None
Expand All @@ -203,8 +212,6 @@ def __init__(
)
)

feature_store, graph_store = data

if weight_attr is not None:
graph_store._set_weight_attr((feature_store, weight_attr))

Expand All @@ -221,6 +228,7 @@ def __init__(
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
heterogeneous=(not graph_store.is_homogeneous)
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
13 changes: 10 additions & 3 deletions python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,19 @@ def __init__(
# Will eventually automatically convert these objects to cuGraph objects.
raise NotImplementedError("Currently can't accept non-cugraph graphs")

feature_store, graph_store = data

if compression is None:
compression = "CSR"
compression = "CSR" if graph_store.is_homogeneous else 'COO'
elif compression not in ["CSR", "COO"]:
raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')")

if (not graph_store.is_homogeneous):
if compression != 'COO':
raise ValueError("Only COO format is supported for heterogeneous graphs!")
if directory is not None:
raise ValueError("Writing to disk is not supported for heterogeneous graphs!")

writer = (
None
if directory is None
Expand All @@ -196,8 +204,6 @@ def __init__(
)
)

feature_store, graph_store = data

if weight_attr is not None:
graph_store._set_weight_attr((feature_store, weight_attr))

Expand All @@ -214,6 +220,7 @@ def __init__(
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
heterogeneous=(not graph_store.is_homogeneous)
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
118 changes: 113 additions & 5 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,118 @@ def __iter__(self):
return self


class HeterogeneousSampleReader(SampleReader):
    """
    Subclass of SampleReader that reads heterogeneous output samples
    produced by the cuGraph distributed sampler.

    Only COO-style raw samples are decoded; raw samples that contain a
    ``"major_offsets"`` entry are rejected by ``_decode`` with a
    ``ValueError``.
    """

    def __init__(
        self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]
    ):
        """
        Constructs a new HeterogeneousSampleReader.

        Parameters
        ----------
        base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]
            The iterator responsible for loading saved samples produced by
            the cuGraph distributed sampler.  It is forwarded unchanged to
            the ``SampleReader`` constructor.
        """
        super().__init__(base_reader)


    def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
        """
        Decodes the ``index``-th batch of a COO-format raw sample into a
        ``torch_geometric.sampler.SamplerOutput``.

        Parameters
        ----------
        raw_sample_data: Dict[str, "torch.Tensor"]
            Raw sampler output.  This method reads the keys ``"fanout"``,
            ``"edge_type"``, ``"label_hop_offsets"``, ``"majors"``,
            ``"minors"``, ``"edge_id"``, ``"renumber_map_offsets"``,
            ``"map"``, ``"input_index"``, ``"input_offsets"`` and,
            if present, ``"edge_inverse"``.
        index: int
            Position of the batch to decode within ``raw_sample_data``.

        Returns
        -------
        torch_geometric.sampler.SamplerOutput
            ``row`` is taken from the minors and ``col`` from the majors.
            ``metadata`` is ``(input_index, None)`` when there is no
            ``"edge_inverse"`` entry, otherwise
            ``(input_index, edge_inverse.view(2, -1), None, None)``.
        """
        fanout_length = raw_sample_data['fanout'].numel()
        # Derived from the largest edge type id in the *entire* raw sample,
        # not just this batch.
        # NOTE(review): presumably edge type ids are contiguous from 0 and
        # every type appears at least once -- confirm.
        num_edge_types = raw_sample_data['edge_type'].max() + 1

        major_minor_start = raw_sample_data["label_hop_offsets"][index * num_edge_types * fanout_length]
        # NOTE(review): the start offset above is scaled by num_edge_types,
        # but ix_end (and the num_sampled_edges slice further down) is not.
        # The two only agree when num_edge_types == 1 -- confirm the intended
        # layout of "label_hop_offsets" for heterogeneous samples.
        ix_end = (index + 1) * fanout_length
        if ix_end == raw_sample_data["label_hop_offsets"].numel():
            # No closing offset stored for this position: run to the end of
            # the edge list.
            major_minor_end = raw_sample_data["majors"].numel()
        else:
            major_minor_end = raw_sample_data["label_hop_offsets"][ix_end]

        # Slice out this batch's edges.
        majors = raw_sample_data["majors"][major_minor_start:major_minor_end]
        minors = raw_sample_data["minors"][major_minor_start:major_minor_end]
        edge_id = raw_sample_data["edge_id"][major_minor_start:major_minor_end]
        # NOTE(review): edge_type is sliced here but never read again in this
        # method, and it is not passed to the returned SamplerOutput.
        edge_type = raw_sample_data['edge_type'][major_minor_start:major_minor_end]

        # Slice out this batch's renumber map.
        renumber_map_start = raw_sample_data["renumber_map_offsets"][index]
        renumber_map_end = raw_sample_data["renumber_map_offsets"][index + 1]

        renumber_map = raw_sample_data["map"][renumber_map_start:renumber_map_end]

        # Consecutive differences of the offsets give one edge count per
        # offset interval (fanout_length intervals); result is moved to CPU.
        num_sampled_edges = (
            raw_sample_data["label_hop_offsets"][
                index * fanout_length : (index + 1) * fanout_length + 1
            ]
            .diff()
            .cpu()
        )

        # Seed count = largest major id among the first interval's edges + 1,
        # kept as a 1-element CPU tensor.
        num_seeds = (majors[: num_sampled_edges[0]].max() + 1).reshape((1,)).cpu()
        # Cumulative node count after each interval: largest minor id among
        # all edges up to and including interval i, plus one.
        num_sampled_nodes_hops = torch.tensor(
            [
                minors[: num_sampled_edges[:i].sum()].max() + 1
                for i in range(1, fanout_length + 1)
            ],
            device="cpu",
        )

        # Convert the cumulative counts into per-interval counts, with the
        # seed count as the first entry.
        num_sampled_nodes = torch.concat(
            [num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
        )

        input_index = raw_sample_data["input_index"][
            raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
                index + 1
            ]
        ]

        # "edge_inverse" is optional.  When present, it is sliced with the
        # input offsets doubled and later viewed as a 2-row tensor.
        edge_inverse = (
            (
                raw_sample_data["edge_inverse"][
                    (raw_sample_data["input_offsets"][index] * 2) : (
                        raw_sample_data["input_offsets"][index + 1] * 2
                    )
                ]
            )
            if "edge_inverse" in raw_sample_data
            else None
        )

        if edge_inverse is None:
            metadata = (
                input_index,
                None,  # TODO this will eventually include time
            )
        else:
            metadata = (
                input_index,
                edge_inverse.view(2, -1),
                None,
                None,  # TODO this will eventually include time
            )

        # row/col are deliberately fed from minors/majors respectively.
        # batch is the leading num_seeds entries of the renumber map
        # (num_seeds is a 1-element tensor used as the slice bound).
        return torch_geometric.sampler.SamplerOutput(
            node=renumber_map.cpu(),
            row=minors,
            col=majors,
            edge=edge_id,
            batch=renumber_map[:num_seeds],
            num_sampled_nodes=num_sampled_nodes,
            num_sampled_edges=num_sampled_edges,
            metadata=metadata,
        )

    def _decode(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
        """
        Decodes the ``index``-th batch of ``raw_sample_data``.

        Raises
        ------
        ValueError
            If ``raw_sample_data`` contains a ``"major_offsets"`` entry,
            which this reader treats as CSR-format output.
        """
        if "major_offsets" in raw_sample_data:
            raise ValueError("CSR format not currently supported for heterogeneous graphs")
        else:
            return self.__decode_coo(raw_sample_data, index)


class HomogeneousSampleReader(SampleReader):
"""
Subclass of SampleReader that reads homogeneous output samples
Expand Down Expand Up @@ -465,11 +577,7 @@ def sample_from_nodes(
):
return HomogeneousSampleReader(reader)
else:
# TODO implement heterogeneous sampling
raise NotImplementedError(
"Sampling heterogeneous graphs is currently"
" unsupported in the non-dask API"
)
return HeterogeneousSampleReader(reader)

def sample_from_edges(
self,
Expand Down

0 comments on commit 35af4b4

Please sign in to comment.