-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
40 lines (28 loc) · 1.5 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os.path as osp
from torch_geometric.datasets import Planetoid, CitationFull, WikiCS, Coauthor, Amazon
import torch_geometric.transforms as T
# The OGB dataset is not run in these experiments, so its import stays disabled.
# from ogb.nodeproppred import PygNodePropPredDataset
def get_dataset(path, name):
    """Build the torch_geometric dataset registered under ``name``.

    Every dataset is created with ``T.NormalizeFeatures()`` as its transform.

    Args:
        path: Root directory handed to the Coauthor, WikiCS and Amazon
            loaders. The citation datasets ('Cora', 'CiteSeer', 'PubMed',
            'DBLP') do not use it; they are always rooted at
            ``~/datasets/Citation``.
        name: One of 'Cora', 'CiteSeer', 'PubMed', 'DBLP', 'WikiCS',
            'Coauthor-CS', 'Coauthor-Phy', 'Amazon-Computers',
            'Amazon-Photo' (case-sensitive).

    Returns:
        The dataset object constructed by the matching loader class.

    Raises:
        ValueError: If ``name`` is not one of the names listed above.
    """
    citation_names = ('Cora', 'CiteSeer', 'PubMed', 'DBLP')
    path_rooted_names = ('WikiCS', 'Coauthor-CS', 'Coauthor-Phy',
                         'Amazon-Computers', 'Amazon-Photo')
    supported = citation_names + path_rooted_names

    # Validate with an explicit raise rather than `assert`: assertions are
    # removed under `python -O`, which would let an unknown name fall through
    # to the Planetoid loader. Names without a loader branch in this function
    # (e.g. the disabled OGB datasets) are rejected here instead of failing
    # later inside a loader.
    if name not in supported:
        raise ValueError(
            f"Unsupported dataset name {name!r}; "
            f"expected one of: {', '.join(supported)}"
        )

    transform = T.NormalizeFeatures()

    if name == 'Coauthor-CS':
        return Coauthor(root=path, name='cs', transform=transform)
    if name == 'Coauthor-Phy':
        return Coauthor(root=path, name='physics', transform=transform)
    if name == 'WikiCS':
        return WikiCS(root=path, transform=transform)
    if name == 'Amazon-Computers':
        return Amazon(root=path, name='computers', transform=transform)
    if name == 'Amazon-Photo':
        return Amazon(root=path, name='photo', transform=transform)

    # Only the citation datasets remain; they share one fixed root directory
    # that is independent of `path`.
    citation_root = osp.join(osp.expanduser('~/datasets'), 'Citation')
    if name == 'DBLP':
        # CitationFull is called with the lower-case dataset name.
        return CitationFull(citation_root, 'dblp', transform=transform)
    return Planetoid(citation_root, name, transform=transform)
def get_path(base_path, name):
    """Return the directory to use for dataset ``name`` under ``base_path``.

    Args:
        base_path: Base directory for datasets.
        name: Dataset name.

    Returns:
        ``base_path`` itself for 'Cora', 'CiteSeer' and 'PubMed'; otherwise
        ``base_path`` joined with ``name``.
    """
    # These three names map to the base directory itself; every other
    # dataset gets its own sub-directory named after it.
    if name in ('Cora', 'CiteSeer', 'PubMed'):
        return base_path
    return osp.join(base_path, name)