Commit b67d5ed · 1 parent 02c7210
Showing 8 changed files with 972 additions and 715 deletions.
Some changed files are not shown below: large diffs are not rendered by default, and one file was deleted.
@@ -0,0 +1,15 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .dataset import Dataset
from .ogbn_papers100M import OGBNPapers100MDataset
@@ -0,0 +1,54 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import Dict, Tuple


class Dataset:
    @property
    def edge_index_dict(self) -> Dict[Tuple[str, str, str], Dict[str, torch.Tensor]]:
        raise NotImplementedError()

    @property
    def x_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @property
    def y_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @property
    def train_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @property
    def test_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @property
    def val_dict(self) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @property
    def num_input_features(self) -> int:
        raise NotImplementedError()

    @property
    def num_labels(self) -> int:
        raise NotImplementedError()

    def num_nodes(self, node_type: str) -> int:
        raise NotImplementedError()

    def num_edges(self, edge_type: Tuple[str, str, str]) -> int:
        raise NotImplementedError()
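
For orientation only (not part of this commit), here is a minimal sketch of a toy in-memory subclass satisfying this interface. The ToyDataset name, tensor shapes, and values are invented for illustration, and it assumes the Dataset base class from the file above is in scope.

import torch
from typing import Dict, Tuple


class ToyDataset(Dataset):
    """Tiny in-memory dataset with one node type ('paper') and one edge type (illustrative only)."""

    def __init__(self):
        self._x = torch.randn(10, 4)                   # 10 nodes, 4 features each
        self._y = torch.randint(0, 3, (10,)).float()   # 3 classes
        self._ei = {'src': torch.tensor([0, 1, 2]),
                    'dst': torch.tensor([1, 2, 3])}

    @property
    def edge_index_dict(self) -> Dict[Tuple[str, str, str], Dict[str, torch.Tensor]]:
        return {('paper', 'cites', 'paper'): self._ei}

    @property
    def x_dict(self) -> Dict[str, torch.Tensor]:
        return {'paper': self._x}

    @property
    def y_dict(self) -> Dict[str, torch.Tensor]:
        return {'paper': self._y}

    @property
    def num_input_features(self) -> int:
        return self._x.shape[1]

    @property
    def num_labels(self) -> int:
        return int(self._y.max()) + 1

    def num_nodes(self, node_type: str) -> int:
        return self._x.shape[0]

    def num_edges(self, edge_type: Tuple[str, str, str]) -> int:
        return len(self._ei['src'])

    # train_dict / test_dict / val_dict would follow the same pattern and are
    # omitted here; calling them on this sketch still raises NotImplementedError
    # via the base class.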
@@ -0,0 +1,193 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .dataset import Dataset
from typing import Dict, Tuple

import pandas
import torch
import numpy as np

from sklearn.model_selection import train_test_split

import gc
import os


# TODO automatically generate this dataset and splits
class OGBNPapers100MDataset(Dataset):
    def __init__(self, *, replication_factor=1, dataset_dir='.', train_split=0.8, val_split=0.5):
        self.__replication_factor = replication_factor
        self.__disk_x = None
        self.__y = None
        self.__edge_index = None
        self.__dataset_dir = dataset_dir
        self.__train_split = train_split
        self.__val_split = val_split
        self.__train = None
        self.__test = None
        self.__val = None

    @property
    def edge_index_dict(self) -> Dict[Tuple[str, str, str], Dict[str, torch.Tensor]]:
        if self.__edge_index is None:
            parquet_path = os.path.join(
                self.__dataset_dir,
                'ogbn_papers100M',
                'parquet'
            )

            ei = pandas.read_parquet(
                os.path.join(parquet_path, 'paper__cites__paper', 'edge_index.parquet')
            )

            ei = {
                'src': torch.as_tensor(ei.src.values, device='cpu'),
                'dst': torch.as_tensor(ei.dst.values, device='cpu'),
            }

            print('sorting edge index...')
            ei['dst'], ix = torch.sort(ei['dst'])
            ei['src'] = ei['src'][ix]
            del ix
            gc.collect()

            print('processing replications...')
            orig_num_nodes = self.num_nodes('paper') // self.__replication_factor
            if self.__replication_factor > 1:
                orig_src = ei['src'].clone().detach()
                orig_dst = ei['dst'].clone().detach()
                for r in range(1, self.__replication_factor):
                    # Each replica is a disjoint copy of the graph whose node ids are
                    # offset by a multiple of the original node count.
                    ei['src'] = torch.concat([
                        ei['src'],
                        orig_src + int(r * orig_num_nodes),
                    ])

                    ei['dst'] = torch.concat([
                        ei['dst'],
                        orig_dst + int(r * orig_num_nodes),
                    ])

                del orig_src
                del orig_dst

                ei['src'] = ei['src'].contiguous()
                ei['dst'] = ei['dst'].contiguous()
            gc.collect()

            print(f"# edges: {len(ei['src'])}")
            self.__edge_index = {('paper', 'cites', 'paper'): ei}

        return self.__edge_index

    @property
    def x_dict(self) -> Dict[str, torch.Tensor]:
        node_type_path = os.path.join(
            self.__dataset_dir,
            'ogbn_papers100M',
            'npy',
            'paper'
        )

        if self.__disk_x is None:
            if self.__replication_factor == 1:
                full_path = os.path.join(node_type_path, 'node_feat.npy')
            else:
                full_path = os.path.join(node_type_path, f'node_feat_{self.__replication_factor}x.npy')

            # Features are memory-mapped from disk rather than loaded into host memory.
            self.__disk_x = {'paper': np.load(
                full_path,
                mmap_mode='r'
            )}

        return self.__disk_x

    @property
    def y_dict(self) -> Dict[str, torch.Tensor]:
        if self.__y is None:
            self.__get_labels()

        return self.__y

    @property
    def train_dict(self) -> Dict[str, torch.Tensor]:
        if self.__train is None:
            self.__get_labels()
        return self.__train

    @property
    def test_dict(self) -> Dict[str, torch.Tensor]:
        if self.__test is None:
            self.__get_labels()
        return self.__test

    @property
    def val_dict(self) -> Dict[str, torch.Tensor]:
        if self.__val is None:
            self.__get_labels()
        return self.__val

    @property
    def num_input_features(self) -> int:
        return self.x_dict['paper'].shape[1]

    @property
    def num_labels(self) -> int:
        return int(self.y_dict['paper'].max()) + 1

    def num_nodes(self, node_type: str) -> int:
        if node_type != 'paper':
            raise ValueError(f"Invalid node type {node_type}")

        return 111_059_956 * self.__replication_factor

    def num_edges(self, edge_type: Tuple[str, str, str]) -> int:
        if edge_type != ('paper', 'cites', 'paper'):
            raise ValueError(f"Invalid edge type {edge_type}")

        return 1_615_685_872 * self.__replication_factor

    def __get_labels(self):
        label_path = os.path.join(
            self.__dataset_dir,
            'ogbn_papers100M',
            'parquet',
            'paper',
            'node_label.parquet'
        )

        node_label = pandas.read_parquet(label_path)

        if self.__replication_factor > 1:
            orig_num_nodes = self.num_nodes('paper') // self.__replication_factor
            dfr = pandas.DataFrame({
                'node': pandas.concat([node_label.node + (r * orig_num_nodes) for r in range(1, self.__replication_factor)]),
                'label': pandas.concat([node_label.label for r in range(1, self.__replication_factor)]),
            })
            node_label = pandas.concat([node_label, dfr]).reset_index(drop=True)

        num_nodes = self.num_nodes("paper")
        # Unlabeled nodes keep the sentinel value -1.
        node_label_tensor = torch.full((num_nodes,), -1, dtype=torch.float32, device='cpu')
        node_label_tensor[torch.as_tensor(node_label.node.values, device='cpu')] = \
            torch.as_tensor(node_label.label.values, device='cpu')

        self.__y = {'paper': node_label_tensor.contiguous()}

        # Split the labeled nodes into train / test / val boolean masks.
        train_ix, test_val_ix = train_test_split(
            torch.as_tensor(node_label.node.values),
            train_size=self.__train_split,
            random_state=num_nodes,
        )
        test_ix, val_ix = train_test_split(
            test_val_ix,
            test_size=self.__val_split,
            random_state=num_nodes,
        )

        train_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device='cpu')
        train_tensor[train_ix] = 1
        self.__train = {'paper': train_tensor}

        test_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device='cpu')
        test_tensor[test_ix] = 1
        self.__test = {'paper': test_tensor}

        val_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device='cpu')
        val_tensor[val_ix] = 1
        self.__val = {'paper': val_tensor}