From 0f1a14439842f7e0617e98c669c41938ce375964 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Wed, 27 Sep 2023 15:02:35 -0700
Subject: [PATCH] initial rewrite

---
 .../cugraph_pyg/data/cugraph_store.py         |  36 ++---
 .../cugraph_pyg/loader/cugraph_node_loader.py | 153 ++++++++++++------
 .../cugraph_pyg/sampler/cugraph_sampler.py    | 115 ++++++++++++-
 .../cugraph_pyg/tests/test_cugraph_loader.py  |  20 +--
 4 files changed, 243 insertions(+), 81 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
index e0d318adbe0..15379d2ab69 100644
--- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
+++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
@@ -819,8 +819,8 @@ def _get_renumbered_edge_groups_from_sample(
         before this one to get the noi_index.

         Example Input: Series({
-            'sources': [0, 5, 11, 3],
-            'destinations': [8, 2, 3, 5]},
+            'majors': [0, 5, 11, 3],
+            'minors': [8, 2, 3, 5]},
             'edge_type': [1, 3, 5, 14]
         }),
         {
@@ -865,23 +865,23 @@ def _get_renumbered_edge_groups_from_sample(
                     index=cupy.asarray(id_table),
                 ).sort_index()

-                # Renumber the sources using binary search
+                # Renumber the majors using binary search
                 # Step 1: get the index of the new id
                 ix_r = torch.searchsorted(
                     torch.as_tensor(id_map.index.values, device="cuda"),
-                    torch.as_tensor(sampling_results.sources.values, device="cuda"),
+                    torch.as_tensor(sampling_results.majors.values, device="cuda"),
                 )
                 # Step 2: Go from id indices to actual ids
                 row_dict[t_pyg_type] = torch.as_tensor(id_map.values, device="cuda")[
                     ix_r
                 ]

-                # Renumber the destinations using binary search
+                # Renumber the minors using binary search
                 # Step 1: get the index of the new id
                 ix_c = torch.searchsorted(
                     torch.as_tensor(id_map.index.values, device="cuda"),
                     torch.as_tensor(
-                        sampling_results.destinations.values, device="cuda"
+                        sampling_results.minors.values, device="cuda"
                     ),
                 )
                 # Step 2: Go from id indices to actual ids
@@ -897,7 +897,7 @@ def _get_renumbered_edge_groups_from_sample(
                         "new_id": cupy.arange(dst_id_table.shape[0]),
                     }
                 ).set_index("dst")
-                dst = dst_id_map["new_id"].loc[sampling_results.destinations]
+                dst = dst_id_map["new_id"].loc[sampling_results.minors]
                 col_dict[t_pyg_type] = torch.as_tensor(dst.values, device="cuda")

                 src_id_table = noi_index[src_type]
@@ -907,7 +907,7 @@ def _get_renumbered_edge_groups_from_sample(
                         "new_id": cupy.arange(src_id_table.shape[0]),
                     }
                 ).set_index("src")
-                src = src_id_map["new_id"].loc[sampling_results.sources]
+                src = src_id_map["new_id"].loc[sampling_results.majors]
                 row_dict[t_pyg_type] = torch.as_tensor(src.values, device="cuda")

         else:
@@ -929,12 +929,12 @@ def _get_renumbered_edge_groups_from_sample(
             else:  # CSC
                 dst_type, _, src_type = pyg_can_edge_type

-                # Get the de-offsetted destinations
+                # Get the de-offsetted minors
                 dst_num_type = self._numeric_vertex_type_from_name(dst_type)
-                destinations = torch.as_tensor(
-                    sampling_results.destinations.iloc[ix].values, device="cuda"
+                minors = torch.as_tensor(
+                    sampling_results.minors.iloc[ix].values, device="cuda"
                 )
-                destinations -= self.__vertex_type_offsets["start"][dst_num_type]
+                minors -= self.__vertex_type_offsets["start"][dst_num_type]

                 # Create the col entry for this type
                 dst_id_table = noi_index[dst_type]
@@ -944,15 +944,15 @@ def _get_renumbered_edge_groups_from_sample(
                     .rename(columns={"index": "new_id"})
                     .set_index("dst")
                 )
-                dst = dst_id_map["new_id"].loc[cupy.asarray(destinations)]
+                dst = dst_id_map["new_id"].loc[cupy.asarray(minors)]
                 col_dict[pyg_can_edge_type] = torch.as_tensor(dst.values, device="cuda")

-                # Get the de-offsetted sources
+                # Get the de-offsetted majors
                 src_num_type = self._numeric_vertex_type_from_name(src_type)
-                sources = torch.as_tensor(
-                    sampling_results.sources.iloc[ix].values, device="cuda"
+                majors = torch.as_tensor(
+                    sampling_results.majors.iloc[ix].values, device="cuda"
                 )
-                sources -= self.__vertex_type_offsets["start"][src_num_type]
+                majors -= self.__vertex_type_offsets["start"][src_num_type]

                 # Create the row entry for this type
                 src_id_table = noi_index[src_type]
@@ -962,7 +962,7 @@ def _get_renumbered_edge_groups_from_sample(
                     .rename(columns={"index": "new_id"})
                     .set_index("src")
                 )
-                src = src_id_map["new_id"].loc[cupy.asarray(sources)]
+                src = src_id_map["new_id"].loc[cupy.asarray(majors)]
                 row_dict[pyg_can_edge_type] = torch.as_tensor(src.values, device="cuda")

         return row_dict, col_dict
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
index cf7eb330d67..2e1bf6af739 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py
@@ -25,7 +25,9 @@
 from cugraph_pyg.data import CuGraphStore
 from cugraph_pyg.sampler.cugraph_sampler import (
     _sampler_output_from_sampling_results_heterogeneous,
-    _sampler_output_from_sampling_results_homogeneous,
+    _sampler_output_from_sampling_results_homogeneous_csr,
+    _sampler_output_from_sampling_results_homogeneous_coo,
+    filter_cugraph_store_csc,
 )

 from typing import Union, Tuple, Sequence, List, Dict
@@ -58,6 +60,7 @@ def __init__(
         # Sampler args
         num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]] = None,
         replace: bool = True,
+        compression: str = "COO",
         # Other kwargs for the BulkSampler
         **kwargs,
     ):
@@ -174,6 +177,10 @@ def __init__(
             with_replacement=replace,
             batches_per_partition=self.__batches_per_partition,
             renumber=renumber,
+            use_legacy_names=False,
+            deduplicate_sources=True,
+            prior_sources_behavior="exclude",
+            include_hop_column=(compression == "COO"),
             **kwargs,
         )
@@ -245,51 +252,77 @@ def __next__(self):
                 fname,
             )

-            columns = {
-                "sources": "int64",
-                "destinations": "int64",
-                # 'edge_id':'int64',
-                "edge_type": "int32",
-                "batch_id": "int32",
-                "hop_id": "int32",
-            }
-
             raw_sample_data = cudf.read_parquet(parquet_path)
+
             if "map" in raw_sample_data.columns:
-                num_batches = end_inclusive - self.__start_inclusive + 1
+                if "renumber_map_offsets" not in raw_sample_data.columns:
+                    num_batches = end_inclusive - self.__start_inclusive + 1

-                map_end = raw_sample_data["map"].iloc[num_batches]
+                    map_end = raw_sample_data["map"].iloc[num_batches]

-                map = torch.as_tensor(
-                    raw_sample_data["map"].iloc[0:map_end], device="cuda"
-                )
-                raw_sample_data.drop("map", axis=1, inplace=True)
+                    map = torch.as_tensor(
+                        raw_sample_data["map"].iloc[0:map_end], device="cuda"
+                    )
+                    raw_sample_data.drop("map", axis=1, inplace=True)

-                self.__renumber_map_offsets = map[0 : num_batches + 1] - map[0]
-                self.__renumber_map = map[num_batches + 1 :]
+                    self.__renumber_map_offsets = map[0 : num_batches + 1] - map[0]
+                    self.__renumber_map = map[num_batches + 1 :]
+                else:
+                    self.__renumber_map = raw_sample_data["map"]
+                    self.__renumber_map_offsets = raw_sample_data[
+                        "renumber_map_offsets"
+                    ]
+                    raw_sample_data.drop(
+                        columns=["map", "renumber_map_offsets"], inplace=True
+                    )
+
+                    self.__renumber_map.dropna(inplace=True)
+                    self.__renumber_map = torch.as_tensor(
+                        self.__renumber_map, device="cuda"
+                    )
+
+                    self.__renumber_map_offsets.dropna(inplace=True)
+                    self.__renumber_map_offsets = torch.as_tensor(
+                        self.__renumber_map_offsets, device="cuda"
+                    )
+
             else:
                 self.__renumber_map = None

-            self.__data = raw_sample_data[list(columns.keys())].astype(columns)
-            self.__data.dropna(inplace=True)
+            self.__data = raw_sample_data
+            self.__coo = "majors" in self.__data.columns
+            if self.__coo:
+                self.__data.dropna(inplace=True)

             if (
                 len(self.__graph_store.edge_types) == 1
                 and len(self.__graph_store.node_types) == 1
             ):
-                group_cols = ["batch_id", "hop_id"]
-                self.__data_index = self.__data.groupby(group_cols, as_index=True).agg(
-                    {"sources": "max", "destinations": "max"}
-                )
-                self.__data_index.rename(
-                    columns={"sources": "src_max", "destinations": "dst_max"},
-                    inplace=True,
-                )
-                self.__data_index = self.__data_index.to_dict(orient="index")
+                if self.__coo:
+                    group_cols = ["batch_id", "hop_id"]
+                    self.__data_index = self.__data.groupby(
+                        group_cols, as_index=True
+                    ).agg({"majors": "max", "minors": "max"})
+                    self.__data_index.rename(
+                        columns={"majors": "src_max", "minors": "dst_max"},
+                        inplace=True,
+                    )
+                    self.__data_index = self.__data_index.to_dict(orient="index")
+                else:
+                    self.__data_index = None
+
+                    self.__label_hop_offsets = self.__data["label_hop_offsets"]
+                    self.__data.drop(columns=["label_hop_offsets"], inplace=True)
+                    self.__label_hop_offsets.dropna(inplace=True)
+                    self.__label_hop_offsets = torch.as_tensor(
+                        self.__label_hop_offsets, device="cuda"
+                    )
+                    self.__label_hop_offsets -= self.__label_hop_offsets[0].clone()
+
+                    self.__major_offsets = self.__data["major_offsets"]
+                    self.__data.drop(columns="major_offsets", inplace=True)
+                    self.__major_offsets.dropna(inplace=True)
+                    self.__major_offsets = torch.as_tensor(
+                        self.__major_offsets, device="cuda"
+                    )
+                    self.__major_offsets -= self.__major_offsets[0].clone()
+
+                    self.__minors = self.__data["minors"]
+                    self.__data.drop(columns="minors", inplace=True)
+                    self.__minors.dropna(inplace=True)
+                    self.__minors = torch.as_tensor(self.__minors, device="cuda")

         # Pull the next set of sampling results out of the dataframe in memory
-        f = self.__data["batch_id"] == self.__next_batch
+        if self.__coo:
+            f = self.__data["batch_id"] == self.__next_batch

         if self.__renumber_map is not None:
             i = self.__next_batch - self.__start_inclusive
@@ -306,13 +339,29 @@ def __next__(self):
             len(self.__graph_store.edge_types) == 1
             and len(self.__graph_store.node_types) == 1
         ):
-            sampler_output = _sampler_output_from_sampling_results_homogeneous(
-                self.__data[f],
-                current_renumber_map,
-                self.__graph_store,
-                self.__data_index,
-                self.__next_batch,
-            )
+            if self.__coo:
+                sampler_output = _sampler_output_from_sampling_results_homogeneous_coo(
+                    self.__data[f],
+                    current_renumber_map,
+                    self.__graph_store,
+                    self.__data_index,
+                    self.__next_batch,
+                )
+            else:
+                i = (self.__next_batch - self.__start_inclusive) * self.__fanout_length
+                current_label_hop_offsets = self.__label_hop_offsets[
+                    i : i + self.__fanout_length + 1
+                ]
+                current_major_offsets = self.__major_offsets[
+                    current_label_hop_offsets[0] : (current_label_hop_offsets[-1] + 1)
+                ]
+                current_minors = self.__minors[
+                    current_major_offsets[0] : current_major_offsets[-1]
+                ]
+
+                sampler_output = _sampler_output_from_sampling_results_homogeneous_csr(
+                    current_major_offsets,
+                    current_minors,
+                    current_renumber_map,
+                    self.__graph_store,
+                    current_label_hop_offsets,
+                    self.__next_batch,
+                )
         else:
             sampler_output = _sampler_output_from_sampling_results_heterogeneous(
                 self.__data[f], current_renumber_map, self.__graph_store
@@ -322,17 +371,29 @@ def __next__(self):
         self.__next_batch += 1

         # Create a PyG HeteroData object, loading the required features
-        out = torch_geometric.loader.utils.filter_custom_store(
-            self.__feature_store,
-            self.__graph_store,
-            sampler_output.node,
-            sampler_output.row,
-            sampler_output.col,
-            sampler_output.edge,
-        )
+        if self.__coo:
+            out = torch_geometric.loader.utils.filter_custom_store(
+                self.__feature_store,
+                self.__graph_store,
+                sampler_output.node,
+                sampler_output.row,
+                sampler_output.col,
+                sampler_output.edge,
+            )
+        else:
+            if self.__graph_store.order == "CSR":
+                raise ValueError("CSR format incompatible with CSC output")
+
+            out = filter_cugraph_store_csc(
+                self.__feature_store,
+                self.__graph_store,
+                sampler_output.node,
+                sampler_output.row,
+                sampler_output.col,
+                sampler_output.edge,
+            )

         # Account for CSR format in cuGraph vs. CSC format in PyG
-        if self.__graph_store.order == "CSC":
+        if self.__coo and self.__graph_store.order == "CSC":
             for node_type in out.edge_index_dict:
                 out[node_type].edge_index[0], out[node_type].edge_index[1] = (
                     out[node_type].edge_index[1],
diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py
index 6e8c4322418..ab9a8613764 100644
--- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py
+++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py
@@ -53,10 +53,10 @@ def _get_unique_nodes(
         The unique nodes of the given node type.
     """
     if node_position == "src":
-        edge_index = "sources"
+        edge_index = "majors"
         edge_sel = 0
     elif node_position == "dst":
-        edge_index = "destinations"
+        edge_index = "minors"
         edge_sel = -1
     else:
         raise ValueError(f"Illegal value {node_position} for node_position")
@@ -78,7 +78,7 @@ def _get_unique_nodes(
     return sampling_results_node[edge_index]


-def _sampler_output_from_sampling_results_homogeneous(
+def _sampler_output_from_sampling_results_homogeneous_coo(
     sampling_results: cudf.DataFrame,
     renumber_map: torch.Tensor,
     graph_store: CuGraphStore,
@@ -133,11 +133,11 @@ def _sampler_output_from_sampling_results_homogeneous(
     noi_index = {node_type: torch.as_tensor(renumber_map, device="cuda")}

     row_dict = {
-        edge_type: torch.as_tensor(sampling_results.sources, device="cuda"),
+        edge_type: torch.as_tensor(sampling_results.majors, device="cuda"),
     }

     col_dict = {
-        edge_type: torch.as_tensor(sampling_results.destinations, device="cuda"),
+        edge_type: torch.as_tensor(sampling_results.minors, device="cuda"),
     }

     num_nodes_per_hop_dict[node_type][0] = data_index[batch_id, 0]["src_max"] + 1
@@ -177,6 +177,78 @@
     )


+def _sampler_output_from_sampling_results_homogeneous_csr(
+    major_offsets: torch.Tensor,
+    minors: torch.Tensor,
+    renumber_map: torch.Tensor,
+    graph_store: CuGraphStore,
+    label_hop_offsets: torch.Tensor,
+    batch_id: int,
+    metadata: Sequence = None,
+) -> HeteroSamplerOutput:
+    """
+    Parameters
+    ----------
+    major_offsets: torch.Tensor
+        The major offsets for the CSC/CSR matrix ("row pointer")
+    minors: torch.Tensor
+        The minors for the CSC/CSR matrix ("col index")
+    renumber_map: torch.Tensor
+        The tensor containing the renumber map.
+        Required.
+    graph_store: CuGraphStore
+        The graph store containing the structure of the sampled graph.
+    label_hop_offsets: torch.Tensor
+        The tensor containing the label-hop offsets.
+    batch_id: int
+        The current batch id, whose samples are being retrieved
+        from the sampling results.
+    metadata: Sequence
+        The metadata for the sampled batch.
+
+    Returns
+    -------
+    HeteroSamplerOutput
+    """
+
+    if len(graph_store.edge_types) > 1 or len(graph_store.node_types) > 1:
+        raise ValueError("Graph is heterogeneous")
+
+    if renumber_map is None:
+        raise ValueError("Renumbered input is expected for homogeneous graphs")
+
+    node_type = graph_store.node_types[0]
+    edge_type = graph_store.edge_types[0]
+
+    major_offsets = major_offsets.clone() - major_offsets[0]
+    label_hop_offsets = label_hop_offsets.clone() - label_hop_offsets[0]
+
+    num_nodes_per_hop_dict = {node_type: label_hop_offsets.diff()}
+    num_edges_per_hop_dict = {edge_type: major_offsets[label_hop_offsets].diff()}
+
+    noi_index = {node_type: torch.as_tensor(renumber_map, device="cuda")}
+
+    col_dict = {
+        edge_type: major_offsets,
+    }
+
+    row_dict = {
+        edge_type: minors,
+    }
+
+    if HeteroSamplerOutput is None:
+        raise ImportError("Error importing from pyg")
+
+    return HeteroSamplerOutput(
+        node=noi_index,
+        row=row_dict,
+        col=col_dict,
+        edge=None,
+        num_sampled_nodes=num_nodes_per_hop_dict,
+        num_sampled_edges=num_edges_per_hop_dict,
+        metadata=metadata,
+    )
+
+
 def _sampler_output_from_sampling_results_heterogeneous(
     sampling_results: cudf.DataFrame,
     renumber_map: cudf.Series,
@@ -244,8 +316,8 @@
         cudf.Series(
             torch.concat(
                 [
-                    torch.as_tensor(sampling_results_hop_0.sources, device="cuda"),
-                    torch.as_tensor(sampling_results.destinations, device="cuda"),
+                    torch.as_tensor(sampling_results_hop_0.majors, device="cuda"),
+                    torch.as_tensor(sampling_results.minors, device="cuda"),
                 ]
             ),
             name="nodes_of_interest",
@@ -320,3 +392,32 @@
         num_sampled_edges=num_edges_per_hop_dict,
         metadata=metadata,
     )
+
+
+def filter_cugraph_store_csc(
+    feature_store: torch_geometric.data.FeatureStore,
+    graph_store: torch_geometric.data.GraphStore,
+    node_dict: Dict[str, torch.Tensor],
+    row_dict: Dict[str, torch.Tensor],
+    col_dict: Dict[str, torch.Tensor],
+    edge_dict: Dict[str, Tuple[torch.Tensor]],
+) -> torch_geometric.data.HeteroData:
+    data = torch_geometric.data.HeteroData()
+
+    for attr in graph_store.get_all_edge_attrs():
+        key = attr.edge_type
+        if key in row_dict and key in col_dict:
+            data.put_edge_index(
+                (row_dict[key], col_dict[key]),
+                edge_type=key,
+                layout="csc",
+                is_sorted=True,
+            )
+
+    required_attrs = []
+    for attr in feature_store.get_all_tensor_attrs():
+        if attr.group_name in node_dict:
+            attr.index = node_dict[attr.group_name]
+            required_attrs.append(attr)
+            data[attr.group_name].num_nodes = attr.index.size(0)
+
+    tensors = feature_store.multi_get_tensor(required_attrs)
+    for i, attr in enumerate(required_attrs):
+        data[attr.group_name][attr.attr_name] = tensors[i]
+
+    return data
\ No newline at end of file
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
index 48a21cb7fd6..fe751b7d417 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_loader.py
@@ -98,8 +98,8 @@ def test_cugraph_loader_from_disk():

     bogus_samples = cudf.DataFrame(
         {
-            "sources": [0, 1, 2, 3, 4, 5, 6, 6],
-            "destinations": [5, 4, 3, 2, 2, 6, 5, 2],
+            "majors": [0, 1, 2, 3, 4, 5, 6, 6],
+            "minors": [5, 4, 3, 2, 2, 6, 5, 2],
             "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"),
             "edge_id": [5, 10, 15, 20, 25, 30, 35, 40],
             "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"),
@@ -131,11 +131,11 @@ def test_cugraph_loader_from_disk():
         assert (
             edge_index[0].tolist()
-            == bogus_samples.sources.dropna().values_host.tolist()
+            == bogus_samples.majors.dropna().values_host.tolist()
         )
         assert (
             edge_index[1].tolist()
-            == bogus_samples.destinations.dropna().values_host.tolist()
+            == bogus_samples.minors.dropna().values_host.tolist()
         )

     assert num_samples == 256
@@ -157,8 +157,8 @@ def test_cugraph_loader_from_disk_subset():

     bogus_samples = cudf.DataFrame(
         {
-            "sources": [0, 1, 2, 3, 4, 5, 6, 6],
-            "destinations": [5, 4, 3, 2, 2, 6, 5, 2],
+            "majors": [0, 1, 2, 3, 4, 5, 6, 6],
+            "minors": [5, 4, 3, 2, 2, 6, 5, 2],
             "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"),
             "edge_id": [5, 10, 15, 20, 25, 30, 35, 40],
             "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"),
@@ -191,11 +191,11 @@ def test_cugraph_loader_from_disk_subset():
         assert (
             edge_index[0].tolist()
-            == bogus_samples.sources.dropna().values_host.tolist()
+            == bogus_samples.majors.dropna().values_host.tolist()
         )
         assert (
             edge_index[1].tolist()
-            == bogus_samples.destinations.dropna().values_host.tolist()
+            == bogus_samples.minors.dropna().values_host.tolist()
        )

     assert num_samples == 100
@@ -215,8 +215,8 @@ def test_cugraph_loader_e2e_coo():

     bogus_samples = cudf.DataFrame(
         {
-            "sources": [0, 1, 2, 3, 4, 5, 6, 6],
-            "destinations": [5, 4, 3, 2, 2, 6, 5, 2],
+            "majors": [0, 1, 2, 3, 4, 5, 6, 6],
+            "minors": [5, 4, 3, 2, 2, 6, 5, 2],
             "edge_type": cudf.Series([0, 0, 0, 0, 0, 0, 0, 0], dtype="int32"),
             "edge_id": [5, 10, 15, 20, 25, 30, 35, 40],
             "hop_id": cudf.Series([0, 0, 0, 1, 1, 1, 2, 2], dtype="int32"),
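
A minimal usage sketch of the `compression` kwarg this patch adds, assuming an
already-constructed feature/graph store pair and a tensor of seed vertices;
`feature_store`, `graph_store`, and `train_nodes` are illustrative placeholders,
not part of the patch, and it is assumed that `CuGraphNeighborLoader` forwards
extra keyword arguments through to `BulkSampleLoader` as in the file above.

# Hypothetical example, not part of this patch. Assumes feature_store,
# graph_store (e.g., the same CuGraphStore passed twice), and train_nodes
# already exist.
from cugraph_pyg.loader import CuGraphNeighborLoader

loader = CuGraphNeighborLoader(
    (feature_store, graph_store),
    train_nodes,               # seed vertices to sample from
    batch_size=512,
    num_neighbors=[10, 10],    # two-hop fanout
    compression="CSR",         # new in this patch; the default remains "COO"
)

for hetero_data in loader:
    # With compression="CSR", batches are assembled by filter_cugraph_store_csc,
    # which registers each edge type via put_edge_index(..., layout="csc").
    # Note the loader raises ValueError if the graph store was created with
    # order="CSR", since CSR-ordered stores are incompatible with CSC output.
    ...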