
Commit

add init for DAG-GNN
INFO: requires a different implementation for D-Struct (DAG-GNN)
Signed-off-by: Jeroen <[email protected]>
jeroenbe committed May 23, 2022
1 parent 5014d26 commit b6a5620
Showing 4 changed files with 370 additions and 41 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ Untitled.ipynb
/wandb
/d-struct-notears
/d-struct-synth
/d-struct-test_experiment
224 changes: 224 additions & 0 deletions src/daggnn.py
@@ -0,0 +1,224 @@
from typing import Any
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

import pytorch_lightning as pl

import src.utils as ut


# Encoder adapted from DAG-GNN (Yu et al., 2019).
class DAGGNN_MLPEncoder(nn.Module):
"""MLP encoder module."""
def __init__(self, n_in, n_xdims, n_hid, n_out, adj_A, batch_size, do_prob=0., factor=True, tol = 0.1):
super(DAGGNN_MLPEncoder, self).__init__()

        # adj_A: the weighted adjacency matrix being learned, initialised from the A argument
        self.adj_A = nn.Parameter(Variable(torch.from_numpy(adj_A).double(), requires_grad=True))
self.factor = factor

self.Wa = nn.Parameter(torch.zeros(n_out), requires_grad=True)
self.fc1 = nn.Linear(n_xdims, n_hid, bias = True)
self.fc2 = nn.Linear(n_hid, n_out, bias = True)
self.dropout_prob = do_prob
self.batch_size = batch_size
self.z = nn.Parameter(torch.tensor(tol))
self.z_positive = nn.Parameter(torch.ones_like(torch.from_numpy(adj_A)).double())
self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()


def forward(self, inputs, rel_rec, rel_send):

        # adj_A != adj_A is True only for NaN entries, so this flags a NaN in the learned adjacency
        if torch.sum(self.adj_A != self.adj_A):
            print('nan error \n')

# to amplify the value of A and accelerate convergence.
adj_A1 = torch.sinh(3.*self.adj_A)

# adj_Aforz = I-A^T
adj_Aforz = ut.preprocess_adj_new(adj_A1)

adj_A = torch.eye(adj_A1.size()[0]).double()
H1 = F.relu((self.fc1(inputs)))
x = (self.fc2(H1))
logits = torch.matmul(x+self.Wa, adj_Aforz) -self.Wa

return x, logits, adj_A1, adj_A, self.z, self.z_positive, self.adj_A, self.Wa

# Decoder adapted from DAG-GNN (Yu et al., 2019).
class DAGGNN_MLPDecoder(nn.Module):
"""MLP decoder module."""

def __init__(self, n_in_node, n_in_z, n_out, encoder, data_variable_size, batch_size, n_hid,
do_prob=0.):
super(DAGGNN_MLPDecoder, self).__init__()

self.out_fc1 = nn.Linear(n_in_z, n_hid, bias = True)
self.out_fc2 = nn.Linear(n_hid, n_out, bias = True)

self.batch_size = batch_size
self.data_variable_size = data_variable_size

self.dropout_prob = do_prob

self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
m.bias.data.fill_(0.0)
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):

        # adj_A_new1 = (I - A^T)^(-1): inverts the encoder transform to map Z back to X-space
        adj_A_new1 = ut.preprocess_adj_new1(origin_A)
        mat_z = torch.matmul(input_z + Wa, adj_A_new1) - Wa

H3 = F.relu(self.out_fc1((mat_z)))
out = self.out_fc2(H3)

return mat_z, out, adj_A_tilt



class DAG_GNN(pl.LightningModule):
def __init__(
self,
dim: int,
n: int,
A: np.ndarray,
tau_A: float=.0,
lambda_A: float=.0,
c_A: float=1.,
lr: float=.001,
lr_decay: float=30,
gamma: float=.1,
):

super().__init__()

self.dim = dim
self.n = n
self.A = A
self.tau_A = tau_A
self.lambda_A = lambda_A
self.c_A = c_A

self.encoder = DAGGNN_MLPEncoder(
n_in=self.dim,
n_xdims=self.dim,
n_hid=self.dim,
n_out=self.dim,
adj_A=self.A,
batch_size=256
)

self.decoder = DAGGNN_MLPDecoder(
n_in_node=self.dim,
n_in_z=self.dim,
n_out=self.dim,
encoder=self.encoder,
data_variable_size=self.dim,
batch_size=256,
n_hid=self.dim
)

self.lr = lr
self.lr_decay = lr_decay
self.gamma = gamma

off_diag = np.ones([self.dim, self.dim]) - np.eye(self.dim)

rel_rec = np.array(ut.encode_onehot(np.where(off_diag)[1]), dtype=np.float64)
rel_send = np.array(ut.encode_onehot(np.where(off_diag)[0]), dtype=np.float64)
rel_rec = torch.DoubleTensor(rel_rec)
rel_send = torch.DoubleTensor(rel_send)

self.rel_rec = Variable(rel_rec)
self.rel_send = Variable(rel_send)

self.prox_plus = torch.nn.Threshold(0.,0.)


self.triu_indices = ut.get_triu_offdiag_indices(self.dim)
self.tril_indices = ut.get_tril_offdiag_indices(self.dim)

self.graph = None



def configure_optimizers(self):
optimizer = torch.optim.Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=self.lr)
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, step_size=self.lr_decay, gamma=self.gamma
)
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
}

def _h_A(self, A, m):
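        # Acyclicity measure. Assuming ut.matrix_poly(M, m) computes (I + M/m)^m, as in the
        # reference DAG-GNN code, this returns h(A) = tr[(I + A∘A/m)^m] - m, a polynomial
        # surrogate for the NOTEARS constraint tr(e^{A∘A}) - m; h(A) = 0 iff A encodes a DAG.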
expm_A = ut.matrix_poly(A*A, m)
h_A = torch.trace(expm_A) - m
return h_A

def stau(self, w, tau):
w1 = self.prox_plus(torch.abs(w)-tau)
return torch.sign(w)*w1

def training_step(self, batch, batch_idx):
(X,) = batch
enc_x, logits, origin_A, adj_A_tilt_encoder, z_gap, z_positive, myA, Wa = self.encoder(X, self.rel_rec, self.rel_send)
edges = logits
dec_x, output, adj_A_tilt_decoder = self.decoder(X, edges, self.dim, self.rel_rec, self.rel_send, origin_A, adj_A_tilt_encoder, Wa)


target=X
preds=output
variance=0.

# reconstruction accuracy loss
loss_nll = ut.nll_gaussian(preds, target, variance)

# KL loss
loss_kl = ut.kl_gaussian_sem(logits)

# ELBO loss:
loss = loss_kl + loss_nll

# add A loss
one_adj_A = origin_A # torch.mean(adj_A_tilt_decoder, dim =0)
sparse_loss = self.tau_A * torch.sum(torch.abs(one_adj_A))

h_A = self._h_A(origin_A, self.dim)
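        # Augmented-Lagrangian handling of h(A) = 0: lambda_A * h_A is the multiplier term,
        # 0.5 * c_A * h_A^2 the quadratic penalty, and 100 * tr(A∘A) penalises the diagonal
        # (self-loops), as in the reference DAG-GNN code. Note that lambda_A and c_A are fixed
        # at construction here rather than updated by dual ascent between training rounds.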
loss += self.lambda_A * h_A + 0.5 * self.c_A * h_A * h_A + 100. * torch.trace(origin_A*origin_A) + sparse_loss

self.log('loss', loss)

self.graph = origin_A.data.clone().numpy()

return loss

def test_step(self, batch, batch_idx) -> Any:
B_est = self.graph
B_true = self.trainer.datamodule.DAG
print(f"B_est: {B_est}")
print(f"B_true: {B_true}")
self.log_dict(ut.count_accuracy(B_true, B_est))
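
A minimal usage sketch for the new module (not part of the commit; the datamodule below is an assumed placeholder). DAG_GNN is a plain LightningModule, so it can be trained with a standard Trainer; the only interface requirements visible above are that batches arrive as 1-tuples (X,) and that the datamodule exposes a DAG attribute holding the ground-truth graph for test_step.

import numpy as np
import pytorch_lightning as pl

from src.daggnn import DAG_GNN

d = 10                                            # number of observed variables
model = DAG_GNN(dim=d, n=1000, A=np.zeros((d, d)))

# `dm` stands in for whatever LightningDataModule the experiments use; it must yield
# (X,) batches and expose a .DAG attribute (hypothetical placeholder, not from the diff).
dm = ...

trainer = pl.Trainer(max_epochs=100)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)                # logs ut.count_accuracy(dm.DAG, model.graph)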

85 changes: 48 additions & 37 deletions src/experiments.py
@@ -12,8 +12,9 @@
import src.utils as ut
from src.data import BetaP, Data
from src.dstruct import NOTEARS, DStruct, lit_NOTEARS
from src.daggnn import DAG_GNN

model_refs = {"notears-mlp": lit_NOTEARS, "dstruct-mlp": DStruct}
model_refs = {"notears-mlp": lit_NOTEARS, "dstruct-mlp": DStruct, "dag-gnn": DAG_GNN}


def experiment(
@@ -155,47 +156,57 @@ def main(
"batch_size": batch_size,
},
models={
"dstruct-mlp": {
# "dstruct-mlp": {
# "model": {
# "dim": d,
# "dsl": NOTEARS,
# "dsl_config": {"dim": d},
# "h_tol": nt_h_tol,
# "rho_max": nt_rho_max,
# "p": BetaP(k, bool(sort), bool(rand_sort)),
# "K": k,
# "lmbda": lmbda,
# "n": n,
# "s": s,
# "dag_type": graph_type,
# },
# "train": {
# "max_epochs": epochs,
# "callbacks": [
# EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
# EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
# ],
# },
# },
# "notears-mlp": {
# "model": {
# "model": NOTEARS(dim=d),
# "h_tol": nt_h_tol,
# "rho_max": nt_rho_max,
# "n": n,
# "s": s,
# "dim": d,
# "K": k,
# "dag_type": graph_type,
# },
# "train": {
# "max_epochs": epochs,
# "callbacks": [
# EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
# EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
# ],
# },
# },
"dag-gnn": {
"model": {
"dim": d,
"dsl": NOTEARS,
"dsl_config": {"dim": d},
"h_tol": nt_h_tol,
"rho_max": nt_rho_max,
"p": BetaP(k, bool(sort), bool(rand_sort)),
"K": k,
"lmbda": lmbda,
"n": n,
"s": s,
"dag_type": graph_type,
"A": np.zeros((d, d))
},
"train": {
"max_epochs": epochs,
"callbacks": [
EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
],
},
},
"notears-mlp": {
"model": {
"model": NOTEARS(dim=d),
"h_tol": nt_h_tol,
"rho_max": nt_rho_max,
"n": n,
"s": s,
"dim": d,
"K": k,
"dag_type": graph_type,
},
"train": {
"max_epochs": epochs,
"callbacks": [
EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
],
},
},
"max_epochs": epochs
}
}
},
)
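
As a reading aid, here is how the new "dag-gnn" entry is presumably consumed (an assumption; experiment() itself is not shown in this diff). Note that several keys in its "model" sub-dict (dsl, dsl_config, h_tol, rho_max, p, K, lmbda, s, dag_type) are carried over from the D-Struct configuration and are not parameters of DAG_GNN.__init__ above, which matches the commit's INFO note that DAG-GNN needs a different implementation.

cfg = models["dag-gnn"]
ModelCls = model_refs["dag-gnn"]          # -> DAG_GNN
model = ModelCls(**cfg["model"])          # would fail until the leftover D-Struct kwargs are removed
trainer = pl.Trainer(**cfg["train"])      # max_epochs, EarlyStopping callbacks
trainer.fit(model, datamodule=dm)         # dm: the experiment's datamodule (assumed)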

