module.py
"""
This files contains lighting AI module for training
"""
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import Bert
from transformers import BertModel, BertTokenizer
from graph import Graph
class GraphEncoder(nn.Module):
def __init__(self):
super(GraphEncoder, self).__init__()
def forward(self, x):
pass
class GraphDecoder:
"""
TODO: See if you can load a pretrained language decoder
"""
def __init__(self):
super(GraphDecoder, self).__init__()
def forward(self, x):
pass
class LitModel(L.LightningModule):
def __init__(self, model, graph: Graph):
super().__init__()
# CHECK: May load pretrained
self.graph = graph
self.encoder = BertModel()
self.decoder = GraphDecoder() # Perhaps load some pretrained dcoder here
self.model = model
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, x):
graph_trans_embeds = self.encoder(x)
# Travel the Graph
center_position = self.graph.navigate(graph_trans_embeds)
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = self.loss_fn(logits, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
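

if __name__ == "__main__":
    # Usage sketch (illustrative, not from the original file): how LitModel
    # might be wired into a standard Lightning training loop. Assumptions are
    # flagged below: Graph() as a no-argument constructor, the Embedding/Linear
    # backbone, and the random token-id dataset are hypothetical stand-ins, and
    # the run only completes once graph.navigate accepts the BERT encoder output.
    from torch.utils.data import DataLoader, TensorDataset

    graph = Graph()  # ASSUMPTION: hypothetical no-arg constructor; real signature lives in graph.py
    backbone = nn.Sequential(  # ASSUMPTION: placeholder model mapping token ids to class logits
        nn.Embedding(30522, 64),  # BERT vocab size -> small embeddings
        nn.Flatten(),             # (B, 128, 64) -> (B, 128 * 64)
        nn.Linear(128 * 64, 10),  # -> 10-class logits
    )
    lit_model = LitModel(backbone, graph)

    # Random BERT-style token ids (batch of 32, sequence length 128) and integer labels.
    dataset = TensorDataset(
        torch.randint(0, 30522, (32, 128)),
        torch.randint(0, 10, (32,)),
    )
    loader = DataLoader(dataset, batch_size=8)

    trainer = L.Trainer(max_epochs=1)
    trainer.fit(lit_model, train_dataloaders=loader)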