-
Notifications
You must be signed in to change notification settings - Fork 25
/
model_factory.py
50 lines (43 loc) · 1.52 KB
/
model_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from config import FLAGS
from collections import Counter
from model_text_gnn import TextGNN
from utils import parse_as_int_list
import torch
def create_model(dataset):
    """Build the model selected by the ``model`` flag.

    The flag value has the form ``NAME`` or ``NAME:key=value,key=value,...``.
    ``NAME`` selects a constructor from ``model_ctors``; the ``key=value``
    pairs are collected into a dict of strings and handed to that
    constructor together with ``dataset``.

    Args:
        dataset: Passed through unchanged to the selected constructor.

    Returns:
        Whatever the selected constructor returns.

    Raises:
        ValueError: If the flag contains more than one ``:`` or names a
            model that is not registered in ``model_ctors``.
    """
    model_spec = vars(FLAGS)["model"]
    parts = model_spec.split(':')
    name = parts[0]

    # Explicit check rather than `assert`: assertions are removed when Python
    # runs with -O, which would let extra ':' segments be silently ignored.
    if len(parts) > 2:
        raise ValueError(
            "Malformed model spec {!r}: expected 'NAME' or "
            "'NAME:key=value,...'".format(model_spec)
        )

    layer_info = {}
    if len(parts) == 2:
        for spec in parts[1].split(','):
            # Split on the first '=' only, so a value may itself contain '='.
            # A spec without '=' maps its text to the empty string.
            key, _, value = spec.partition('=')
            layer_info[key] = value

    if name not in model_ctors:
        raise ValueError("Model not implemented {}".format(name))
    return model_ctors[name](layer_info, dataset)
def create_text_gnn(layer_info, dataset):
    """Construct a ``TextGNN`` from parsed layer options and a dataset.

    Args:
        layer_info: Dict of string options. Keys read here:
            ``layer_dim_list``, ``class_weights``, ``pred_type``,
            ``node_embd_type``, ``num_layers``, ``act`` and ``dropout``.
        dataset: Object providing ``node_feats``, ``label_inds``,
            ``node_ids`` and ``label_dict``.

    Returns:
        The constructed ``TextGNN`` instance.

    Raises:
        KeyError: If one of the keys listed above is missing from
            ``layer_info``.
    """
    # Prepend the input feature width (second dimension of node_feats) to the
    # configured layer sizes.
    lyr_dims = parse_as_int_list(layer_info["layer_dim_list"])
    lyr_dims = [dataset.node_feats.shape[1]] + lyr_dims

    num_labels = len(dataset.label_dict)

    weights = None
    if layer_info["class_weights"].lower() == "true":
        counts = Counter(dataset.label_inds[dataset.node_ids])
        min_count = min(counts.values())
        # Size the vector by num_labels (the same class count handed to
        # TextGNN) instead of by the number of classes seen in node_ids, so a
        # class missing from node_ids neither shortens the vector nor causes
        # an IndexError below. Unseen classes keep weight 0.
        # NOTE(review): assumes label indices lie in [0, num_labels) - confirm.
        weights = [0.0] * num_labels
        for label, count in counts.items():
            # min_count / count: the least frequent observed class gets 1.0,
            # more frequent classes get proportionally less.
            weights[label] = min_count / float(count)
        weights = torch.tensor(weights, device=FLAGS.device)

    return TextGNN(
        pred_type=layer_info["pred_type"],
        node_embd_type=layer_info["node_embd_type"],
        num_layers=int(layer_info["num_layers"]),
        layer_dim_list=lyr_dims,
        act=layer_info["act"],
        bn=False,
        num_labels=num_labels,
        class_weights=weights,
        # Forwarded as the raw string from layer_info; presumably TextGNN
        # converts it itself - TODO confirm.
        dropout=layer_info["dropout"],
    )
# Registry mapping a model name (the part of the ``model`` flag before ':')
# to the factory function that builds it; consulted by create_model.
model_ctors = {"TextGNN": create_text_gnn}