Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Nov 22, 2024
1 parent 35af4b4 commit 8da5c95
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 49 deletions.
35 changes: 35 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __clear_graph(self):
self.__graph = None
self.__vertex_offsets = None
self.__weight_attr = None
self.__numeric_edge_types = None

def _put_edge_index(
self,
Expand Down Expand Up @@ -173,6 +174,7 @@ def _graph(self) -> Union[pylibcugraph.SGGraph, pylibcugraph.MGGraph]:
else None,
)
else:
print(edgelist_dict)
self.__graph = pylibcugraph.SGGraph(
self._resource_handle,
graph_properties,
Expand Down Expand Up @@ -270,6 +272,39 @@ def __get_weight_tensor(

return torch.concat(weights)

@property
def _numeric_edge_types(self) -> Tuple[List, "torch.Tensor", "torch.Tensor"]:
"""
Returns the canonical edge types in order (the 0th canonical type corresponds
to numeric edge type 0, etc.), along with the numeric source and destination
vertex types for each edge type.
"""

if self.__numeric_edge_types is None:
sorted_keys = sorted(
list(self.__edge_indices.keys(leaves_only=True, include_nested=True))
)

vtype_table = {
k: i
for i, k in enumerate(sorted(self._vertex_offsets.keys()))
}

srcs = []
dsts = []

for can_etype in sorted_keys:
srcs.append(
vtype_table[can_etype[0]]
)
dsts.append(
vtype_table[can_etype[2]]
)

self.__numeric_edge_types = (sorted_keys, torch.tensor(srcs,device='cuda',dtype=torch.int32), torch.tensor(dsts,device='cuda',dtype=torch.int32))

return self.__numeric_edge_types

def __get_edgelist(self):
"""
Returns
Expand Down
4 changes: 4 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def __init__(
(None, edge_label_index),
)

# Note reverse of standard convention here
edge_label_index[0] += data[1]._vertex_offsets[input_type[0]]
edge_label_index[1] += data[1]._vertex_offsets[input_type[2]]

self.__input_data = torch_geometric.sampler.EdgeSamplerInput(
input_id=torch.arange(
edge_label_index[0].numel(), dtype=torch.int64, device="cuda"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def __init__(
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
heterogeneous=(not graph_store.is_homogeneous)
heterogeneous=(not graph_store.is_homogeneous),
num_edge_types=len(graph_store.get_all_edge_attrs()),
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
3 changes: 2 additions & 1 deletion python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def __init__(
with_replacement=replace,
local_seeds_per_call=local_seeds_per_call,
biased=(weight_attr is not None),
heterogeneous=(not graph_store.is_homogeneous)
heterogeneous=(not graph_store.is_homogeneous),
num_edge_types=len(graph_store.get_all_edge_attrs()),
),
(feature_store, graph_store),
batch_size=batch_size,
Expand Down
2 changes: 2 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __init__(
input_id,
)

input_nodes += data[1]._vertex_offsets[input_type]

self.__input_data = torch_geometric.sampler.NodeSamplerInput(
input_id=torch.arange(len(input_nodes), dtype=torch.int64, device="cuda")
if input_id is None
Expand Down
125 changes: 78 additions & 47 deletions python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Iterator, Union, Dict, Tuple
from typing import Optional, Iterator, Union, Dict, Tuple, List

from cugraph.utilities.utils import import_optional
from cugraph.gnn import DistSampler
Expand Down Expand Up @@ -189,12 +189,14 @@ def __next__(self):
self.__raw_sample_data, start_inclusive, end_inclusive = next(
self.__base_reader
)
print(self.__raw_sample_data)
lho_name = "label_type_hop_offsets" if "label_type_hop_offsets" in self.__raw_sample_data else "label_type_hop_offsets"

self.__raw_sample_data["input_offsets"] -= self.__raw_sample_data[
"input_offsets"
][0].clone()
self.__raw_sample_data["label_hop_offsets"] -= self.__raw_sample_data[
"label_hop_offsets"
self.__raw_sample_data[lho_name] -= self.__raw_sample_data[
lho_name
][0].clone()
self.__raw_sample_data["renumber_map_offsets"] -= self.__raw_sample_data[
"renumber_map_offsets"
Expand Down Expand Up @@ -223,7 +225,7 @@ class HeterogeneousSampleReader(SampleReader):
"""

def __init__(
self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]
self, base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]], src_types: "torch.Tensor", dst_types: "torch.Tensor", edge_types: List[Tuple[str, str, str]]
):
"""
Constructs a new HeterogeneousSampleReader
Expand All @@ -233,51 +235,71 @@ def __init__(
base_reader: Iterator[Tuple[Dict[str, "torch.Tensor"], int, int]]
The iterator responsible for loading saved samples produced by
the cuGraph distributed sampler.
src_types: torch.Tensor
Integer source type for each integer edge type.
dst_types: torch.Tensor
Integer destination type for each integer edge type.
edge_types: List[Tuple[str, str, str]]
List of edge types in the graph in order, so they can be
mapped to numeric edge types.
"""

self.__src_types = src_types
self.__dst_types = dst_types
self.__edge_types = edge_types
self.__num_vertex_types = max(self.__src_types.max(), self.__dst_types.max()) + 1

super().__init__(base_reader)


def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
fanout_length = raw_sample_data['fanout'].numel()
num_edge_types = raw_sample_data['edge_type'].max() + 1

major_minor_start = raw_sample_data["label_hop_offsets"][index * num_edge_types * fanout_length]
ix_end = (index + 1) * fanout_length
if ix_end == raw_sample_data["label_hop_offsets"].numel():
major_minor_end = raw_sample_data["majors"].numel()
else:
major_minor_end = raw_sample_data["label_hop_offsets"][ix_end]
num_edge_types = self.__src_types.numel()

num_sampled_edges = {}
node = {}
row = {}
col = {}
edge = {}
for etype in range(num_edge_types):
pyg_can_etype = self.__edge_types[etype]

jx = self.__src_types[etype] + index * self.__num_vertex_types
map_ptr_src_beg, map_ptr_src_end = raw_sample_data["renumber_map_offsets"][
[jx, jx + 1]
]
map_src = raw_sample_data["renumber_map"][map_ptr_src_beg:map_ptr_src_end]
node[pyg_can_etype[0]] = map_src.cpu()

majors = raw_sample_data["majors"][major_minor_start:major_minor_end]
minors = raw_sample_data["minors"][major_minor_start:major_minor_end]
edge_id = raw_sample_data["edge_id"][major_minor_start:major_minor_end]
edge_type = raw_sample_data['edge_type'][major_minor_start:major_minor_end]
kx = self.__dst_types[etype] + index * self.__num_vertex_types
map_ptr_dst_beg, map_ptr_dst_end = raw_sample_data["renumber_map_offsets"][
[kx, kx + 1]
]
map_dst = raw_sample_data["renumber_map"][map_ptr_dst_beg:map_ptr_dst_end]
node[pyg_can_etype[2]] = map_dst.cpu()

renumber_map_start = raw_sample_data["renumber_map_offsets"][index]
renumber_map_end = raw_sample_data["renumber_map_offsets"][index + 1]
edge_ptr_beg = index * num_edge_types * fanout_length + etype * fanout_length
edge_ptr_end = index * num_edge_types * fanout_length + (etype+1) * fanout_length
lho = raw_sample_data['label_type_hop_offsets'][
edge_ptr_beg:edge_ptr_end
]

renumber_map = raw_sample_data["map"][renumber_map_start:renumber_map_end]
num_sampled_edges[pyg_can_etype] = (
lho
).diff().cpu()

num_sampled_edges = (
raw_sample_data["label_hop_offsets"][
index * fanout_length : (index + 1) * fanout_length + 1
eid_i = raw_sample_data["edge_id"][edge_ptr_beg:edge_ptr_end]
eirx = (index * num_edge_types) + etype
edge_id_ptr_beg, edge_id_ptr_end = raw_sample_data["edge_renumber_map_offsets"][
[eirx, eirx + 1]
]
.diff()
.cpu()
)
emap = raw_sample_data["edge_renumber_map"][edge_id_ptr_beg:edge_id_ptr_end]
edge[pyg_can_etype] = emap[eid_i]

num_seeds = (majors[: num_sampled_edges[0]].max() + 1).reshape((1,)).cpu()
num_sampled_nodes_hops = torch.tensor(
[
minors[: num_sampled_edges[:i].sum()].max() + 1
for i in range(1, fanout_length + 1)
],
device="cpu",
)
col[pyg_can_etype] = raw_sample_data['majors'][edge_ptr_beg:edge_ptr_end]
row[pyg_can_etype] = raw_sample_data['minors'][edge_ptr_beg:edge_ptr_end]

num_sampled_nodes = torch.concat(
[num_seeds, num_sampled_nodes_hops.diff(prepend=num_seeds)]
)
num_sampled_nodes = {}

input_index = raw_sample_data["input_index"][
raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][
Expand Down Expand Up @@ -310,12 +332,12 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int):
None, # TODO this will eventually include time
)

return torch_geometric.sampler.SamplerOutput(
node=renumber_map.cpu(),
row=minors,
col=majors,
edge=edge_id,
batch=renumber_map[:num_seeds],
return torch_geometric.sampler.HeteroSamplerOutput(
node=node,
row=row,
col=col,
edge=edge,
batch=None,
num_sampled_nodes=num_sampled_nodes,
num_sampled_edges=num_sampled_edges,
metadata=metadata,
Expand Down Expand Up @@ -577,7 +599,14 @@ def sample_from_nodes(
):
return HomogeneousSampleReader(reader)
else:
return HeterogeneousSampleReader(reader)
edge_types,src_types,dst_types = self.__graph_store._numeric_edge_types

return HeterogeneousSampleReader(
reader,
src_types=src_types,
dst_types=dst_types,
edge_types=edge_types,
)

def sample_from_edges(
self,
Expand Down Expand Up @@ -641,8 +670,10 @@ def sample_from_edges(
):
return HomogeneousSampleReader(reader)
else:
# TODO implement heterogeneous sampling
raise NotImplementedError(
"Sampling heterogeneous graphs is currently"
" unsupported in the non-dask API"
edge_types,src_types,dst_types = self.__graph_store._numeric_edge_types
return HeterogeneousSampleReader(
reader,
src_types=src_types,
dst_types=dst_types,
edge_types=edge_types,
)

0 comments on commit 8da5c95

Please sign in to comment.