From b6a5620a093c393cdf57165949b5736c4ee986a3 Mon Sep 17 00:00:00 2001
From: Jeroen
Date: Mon, 23 May 2022 11:19:04 +0200
Subject: [PATCH] add init for DAG-GNN

INFO: requires different implementation for D-Struct (DAG-GNN)

Signed-off-by: Jeroen
---
 .gitignore         |   1 +
 src/daggnn.py      | 224 +++++++++++++++++++++++++++++++++++++++++++++
 src/experiments.py |  85 +++++++++--------
 src/utils.py       | 101 +++++++++++++++++++-
 4 files changed, 370 insertions(+), 41 deletions(-)
 create mode 100644 src/daggnn.py

diff --git a/.gitignore b/.gitignore
index 2f9191a..f16f2bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ Untitled.ipynb
 /wandb
 /d-struct-notears
 /d-struct-synth
+/d-struct-test_experiment

diff --git a/src/daggnn.py b/src/daggnn.py
new file mode 100644
index 0000000..957030b
--- /dev/null
+++ b/src/daggnn.py
@@ -0,0 +1,224 @@
+from typing import Any
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.autograd import Variable
+
+import pytorch_lightning as pl
+
+import src.utils as ut
+
+
+# Yue et al.
+class DAGGNN_MLPEncoder(nn.Module):
+    """MLP encoder module."""
+
+    def __init__(self, n_in, n_xdims, n_hid, n_out, adj_A, batch_size, do_prob=0., factor=True, tol=0.1):
+        super(DAGGNN_MLPEncoder, self).__init__()
+
+        self.adj_A = nn.Parameter(Variable(torch.from_numpy(adj_A).double(), requires_grad=True))
+        self.factor = factor
+
+        self.Wa = nn.Parameter(torch.zeros(n_out), requires_grad=True)
+        self.fc1 = nn.Linear(n_xdims, n_hid, bias=True)
+        self.fc2 = nn.Linear(n_hid, n_out, bias=True)
+        self.dropout_prob = do_prob
+        self.batch_size = batch_size
+        self.z = nn.Parameter(torch.tensor(tol))
+        self.z_positive = nn.Parameter(torch.ones_like(torch.from_numpy(adj_A)).double())
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+            elif isinstance(m, nn.BatchNorm1d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, inputs, rel_rec, rel_send):
+
+        if torch.sum(self.adj_A != self.adj_A):
+            print('nan error \n')
+
+        # to amplify the value of A and accelerate convergence.
+        adj_A1 = torch.sinh(3. * self.adj_A)
+
+        # adj_Aforz = I - A^T
+        adj_Aforz = ut.preprocess_adj_new(adj_A1)
+
+        adj_A = torch.eye(adj_A1.size()[0]).double()
+        H1 = F.relu((self.fc1(inputs)))
+        x = (self.fc2(H1))
+        logits = torch.matmul(x + self.Wa, adj_Aforz) - self.Wa
+
+        return x, logits, adj_A1, adj_A, self.z, self.z_positive, self.adj_A, self.Wa
+
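+# Note on the encoder/decoder pair (informal sketch, not part of the reference code):
+# in the linear-SEM view X = A^T X + Z used by DAG-GNN, the encoder maps data to
+# latents via Z = (I - A^T)(f(X) + Wa) - Wa, where preprocess_adj_new builds I - A^T,
+# and the decoder inverts this with (I - A^T)^{-1} from preprocess_adj_new1.
+# For a single weighted edge 1 -> 2, i.e. A = [[0, a], [0, 0]]:
+#     I - A^T        = [[1, 0], [-a, 1]]
+#     (I - A^T)^{-1} = [[1, 0], [ a, 1]]
+# so decoding exactly undoes encoding when the same A is used on both sides.
+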
+# Yue et al.
+class DAGGNN_MLPDecoder(nn.Module):
+    """MLP decoder module."""
+
+    def __init__(self, n_in_node, n_in_z, n_out, encoder, data_variable_size, batch_size, n_hid,
+                 do_prob=0.):
+        super(DAGGNN_MLPDecoder, self).__init__()
+
+        self.out_fc1 = nn.Linear(n_in_z, n_hid, bias=True)
+        self.out_fc2 = nn.Linear(n_hid, n_out, bias=True)
+
+        self.batch_size = batch_size
+        self.data_variable_size = data_variable_size
+
+        self.dropout_prob = do_prob
+
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight.data)
+                m.bias.data.fill_(0.0)
+            elif isinstance(m, nn.BatchNorm1d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, inputs, input_z, n_in_node, rel_rec, rel_send, origin_A, adj_A_tilt, Wa):
+
+        # adj_A_new1 = (I - A^T)^(-1)
+        adj_A_new1 = ut.preprocess_adj_new1(origin_A)
+        mat_z = torch.matmul(input_z + Wa, adj_A_new1) - Wa
+
+        H3 = F.relu(self.out_fc1((mat_z)))
+        out = self.out_fc2(H3)
+
+        return mat_z, out, adj_A_tilt
+
+
+class DAG_GNN(pl.LightningModule):
+    def __init__(
+        self,
+        dim: int,
+        n: int,
+        A: np.ndarray,
+        tau_A: float = .0,
+        lambda_A: float = .0,
+        c_A: float = 1.,
+        lr: float = .001,
+        lr_decay: float = 30,
+        gamma: float = .1,
+    ):
+
+        super().__init__()
+
+        self.dim = dim
+        self.n = n
+        self.A = A
+        self.tau_A = tau_A
+        self.lambda_A = lambda_A
+        self.c_A = c_A
+
+        self.encoder = DAGGNN_MLPEncoder(
+            n_in=self.dim,
+            n_xdims=self.dim,
+            n_hid=self.dim,
+            n_out=self.dim,
+            adj_A=self.A,
+            batch_size=256
+        )
+
+        self.decoder = DAGGNN_MLPDecoder(
+            n_in_node=self.dim,
+            n_in_z=self.dim,
+            n_out=self.dim,
+            encoder=self.encoder,
+            data_variable_size=self.dim,
+            batch_size=256,
+            n_hid=self.dim
+        )
+
+        self.lr = lr
+        self.lr_decay = lr_decay
+        self.gamma = gamma
+
+        off_diag = np.ones([self.dim, self.dim]) - np.eye(self.dim)
+
+        rel_rec = np.array(ut.encode_onehot(np.where(off_diag)[1]), dtype=np.float64)
+        rel_send = np.array(ut.encode_onehot(np.where(off_diag)[0]), dtype=np.float64)
+        rel_rec = torch.DoubleTensor(rel_rec)
+        rel_send = torch.DoubleTensor(rel_send)
+
+        self.rel_rec = Variable(rel_rec)
+        self.rel_send = Variable(rel_send)
+
+        self.prox_plus = torch.nn.Threshold(0., 0.)
+
+        self.triu_indices = ut.get_triu_offdiag_indices(self.dim)
+        self.tril_indices = ut.get_tril_offdiag_indices(self.dim)
+
+        self.graph = None
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=self.lr)
+        scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer, step_size=self.lr_decay, gamma=self.gamma
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+        }
+
+    def _h_A(self, A, m):
+        expm_A = ut.matrix_poly(A * A, m)
+        h_A = torch.trace(expm_A) - m
+        return h_A
+
+    def stau(self, w, tau):
+        w1 = self.prox_plus(torch.abs(w) - tau)
+        return torch.sign(w) * w1
+
+    def training_step(self, batch, batch_idx):
+        (X,) = batch
+        enc_x, logits, origin_A, adj_A_tilt_encoder, z_gap, z_positive, myA, Wa = self.encoder(X, self.rel_rec, self.rel_send)
+        edges = logits
+        dec_x, output, adj_A_tilt_decoder = self.decoder(X, edges, self.dim, self.rel_rec, self.rel_send, origin_A, adj_A_tilt_encoder, Wa)
+
+        target = X
+        preds = output
+        variance = 0.
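+
+        # Objective (informal summary of the terms computed below): an ELBO-style
+        # reconstruction + KL loss, plus penalties that push the learned adjacency
+        # matrix A towards a sparse DAG:
+        #   - nll_gaussian(output, X): Gaussian reconstruction error,
+        #   - kl_gaussian_sem(logits): 0.5 * ||Z||^2 / batch, i.e. KL(N(Z, I) || N(0, I)),
+        #   - lambda_A * h(A) + 0.5 * c_A * h(A)^2: augmented-Lagrangian acyclicity
+        #     penalty with h(A) = tr((I + A*A/d)^d) - d (zero exactly when A has no
+        #     directed cycles, see ut.matrix_poly),
+        #   - 100 * tr(A*A): discourages self-loops (diagonal entries of A),
+        #   - tau_A * ||A||_1: sparsity on the adjacency matrix.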
+
+        # reconstruction accuracy loss
+        loss_nll = ut.nll_gaussian(preds, target, variance)
+
+        # KL loss
+        loss_kl = ut.kl_gaussian_sem(logits)
+
+        # ELBO loss:
+        loss = loss_kl + loss_nll
+
+        # add A loss
+        one_adj_A = origin_A  # torch.mean(adj_A_tilt_decoder, dim=0)
+        sparse_loss = self.tau_A * torch.sum(torch.abs(one_adj_A))
+
+        h_A = self._h_A(origin_A, self.dim)
+        loss += self.lambda_A * h_A + 0.5 * self.c_A * h_A * h_A + 100. * torch.trace(origin_A * origin_A) + sparse_loss
+
+        self.log('loss', loss)
+
+        self.graph = origin_A.data.clone().numpy()
+
+        return loss
+
+    def test_step(self, batch, batch_idx) -> Any:
+        B_est = self.graph
+        B_true = self.trainer.datamodule.DAG
+        print(f"B_est: {B_est}")
+        print(f"B_true: {B_true}")
+        self.log_dict(ut.count_accuracy(B_true, B_est))
+
diff --git a/src/experiments.py b/src/experiments.py
index d8fdbc1..09051a9 100644
--- a/src/experiments.py
+++ b/src/experiments.py
@@ -12,8 +12,9 @@
 import src.utils as ut
 from src.data import BetaP, Data
 from src.dstruct import NOTEARS, DStruct, lit_NOTEARS
+from src.daggnn import DAG_GNN
 
-model_refs = {"notears-mlp": lit_NOTEARS, "dstruct-mlp": DStruct}
+model_refs = {"notears-mlp": lit_NOTEARS, "dstruct-mlp": DStruct, "dag-gnn": DAG_GNN}
 
 
 def experiment(
@@ -155,47 +156,57 @@ def main(
             "batch_size": batch_size,
         },
         models={
-            "dstruct-mlp": {
+            # "dstruct-mlp": {
+            #     "model": {
+            #         "dim": d,
+            #         "dsl": NOTEARS,
+            #         "dsl_config": {"dim": d},
+            #         "h_tol": nt_h_tol,
+            #         "rho_max": nt_rho_max,
+            #         "p": BetaP(k, bool(sort), bool(rand_sort)),
+            #         "K": k,
+            #         "lmbda": lmbda,
+            #         "n": n,
+            #         "s": s,
+            #         "dag_type": graph_type,
+            #     },
+            #     "train": {
+            #         "max_epochs": epochs,
+            #         "callbacks": [
+            #             EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
+            #             EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
+            #         ],
+            #     },
+            # },
+            # "notears-mlp": {
+            #     "model": {
+            #         "model": NOTEARS(dim=d),
+            #         "h_tol": nt_h_tol,
+            #         "rho_max": nt_rho_max,
+            #         "n": n,
+            #         "s": s,
+            #         "dim": d,
+            #         "K": k,
+            #         "dag_type": graph_type,
+            #     },
+            #     "train": {
+            #         "max_epochs": epochs,
+            #         "callbacks": [
+            #             EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
+            #             EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
+            #         ],
+            #     },
+            # },
+            "dag-gnn": {
                 "model": {
                     "dim": d,
-                    "dsl": NOTEARS,
-                    "dsl_config": {"dim": d},
-                    "h_tol": nt_h_tol,
-                    "rho_max": nt_rho_max,
-                    "p": BetaP(k, bool(sort), bool(rand_sort)),
-                    "K": k,
-                    "lmbda": lmbda,
                     "n": n,
-                    "s": s,
-                    "dag_type": graph_type,
+                    "A": np.zeros((d, d))
                 },
                 "train": {
-                    "max_epochs": epochs,
-                    "callbacks": [
-                        EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
-                        EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
-                    ],
-                },
-            },
-            "notears-mlp": {
-                "model": {
-                    "model": NOTEARS(dim=d),
-                    "h_tol": nt_h_tol,
-                    "rho_max": nt_rho_max,
-                    "n": n,
-                    "s": s,
-                    "dim": d,
-                    "K": k,
-                    "dag_type": graph_type,
-                },
-                "train": {
-                    "max_epochs": epochs,
-                    "callbacks": [
-                        EarlyStopping(monitor="h", stopping_threshold=nt_h_tol),
-                        EarlyStopping(monitor="rho", stopping_threshold=nt_rho_max),
-                    ],
-                },
-            },
+                    "max_epochs": epochs
+                }
+            }
         },
     )

diff --git a/src/utils.py b/src/utils.py
index 06d3c75..16580ac 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,4 +1,3 @@
-# Source from Zheng et al. 2020 (https://github.com/xunzheng/notears/blob/master/notears/locally_connected.py)
 import logging
 import math
 
@@ -8,12 +7,106 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
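+# The "Yue et al." helpers below come from the DAG-GNN reference implementation;
+# the ones src/daggnn.py relies on most are (informal summary):
+#   - preprocess_adj_new / preprocess_adj_new1: (I - A^T) and its inverse,
+#   - matrix_poly(A*A, d): (I + A*A/d)^d, a polynomial stand-in for exp(A*A), so
+#     tr(matrix_poly(A*A, d)) - d plays the role of the NOTEARS acyclicity measure
+#     tr(e^{A*A}) - d. For example, with A = [[0, a], [b, 0]] (a 2-cycle) this equals
+#     a^2 * b^2 / 2 > 0, and it drops to exactly 0 once either edge is removed.
+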
+# Zheng et al.
 def is_dag(W):
     G = ig.Graph.Weighted_Adjacency(W.tolist())
     return G.is_dag()
 
-
+# Yue et al.
+def my_softmax(input, axis=1):
+    trans_input = input.transpose(axis, 0).contiguous()
+    soft_max_1d = F.softmax(trans_input)
+    return soft_max_1d.transpose(axis, 0)
+
+# Yue et al.
+def preprocess_adj_new(adj):
+    adj_normalized = (torch.eye(adj.shape[0]).double() - (adj.transpose(0, 1)))
+    return adj_normalized
+
+# Yue et al.
+def preprocess_adj_new1(adj):
+    adj_normalized = torch.inverse(torch.eye(adj.shape[0]).double() - adj.transpose(0, 1))
+    return adj_normalized
+
+# Yue et al.
+def get_triu_indices(num_nodes):
+    """Linear triu (upper triangular) indices."""
+    ones = torch.ones(num_nodes, num_nodes)
+    eye = torch.eye(num_nodes, num_nodes)
+    triu_indices = (ones.triu() - eye).nonzero().t()
+    triu_indices = triu_indices[0] * num_nodes + triu_indices[1]
+    return triu_indices
+
+# Yue et al.
+def get_tril_indices(num_nodes):
+    """Linear tril (lower triangular) indices."""
+    ones = torch.ones(num_nodes, num_nodes)
+    eye = torch.eye(num_nodes, num_nodes)
+    tril_indices = (ones.tril() - eye).nonzero().t()
+    tril_indices = tril_indices[0] * num_nodes + tril_indices[1]
+    return tril_indices
+
+# Yue et al.
+def get_offdiag_indices(num_nodes):
+    """Linear off-diagonal indices."""
+    ones = torch.ones(num_nodes, num_nodes)
+    eye = torch.eye(num_nodes, num_nodes)
+    offdiag_indices = (ones - eye).nonzero().t()
+    offdiag_indices = offdiag_indices[0] * num_nodes + offdiag_indices[1]
+    return offdiag_indices
+
+# Yue et al.
+def get_triu_offdiag_indices(num_nodes):
+    """Linear triu (upper) indices w.r.t. vector of off-diagonal elements."""
+    triu_idx = torch.zeros(num_nodes * num_nodes)
+    triu_idx[get_triu_indices(num_nodes)] = 1.
+    triu_idx = triu_idx[get_offdiag_indices(num_nodes)]
+    return triu_idx.nonzero()
+
+# Yue et al.
+def get_tril_offdiag_indices(num_nodes):
+    """Linear tril (lower) indices w.r.t. vector of off-diagonal elements."""
+    tril_idx = torch.zeros(num_nodes * num_nodes)
+    tril_idx[get_tril_indices(num_nodes)] = 1.
+    tril_idx = tril_idx[get_offdiag_indices(num_nodes)]
+    return tril_idx.nonzero()
+
+# Yue et al.
+def nll_gaussian(preds, target, variance, add_const=False):
+    mean1 = preds
+    mean2 = target
+    neg_log_p = variance + torch.div(torch.pow(mean1 - mean2, 2), 2. * np.exp(2. * variance))
+    if add_const:
+        # `variance` acts as a log-scale above, so the remaining constant is 0.5 * log(2*pi)
+        const = 0.5 * np.log(2 * np.pi)
+        neg_log_p += const
+    return neg_log_p.sum() / (target.size(0))
+
+# Yue et al.
+def kl_gaussian_sem(preds):
+    mu = preds
+    kl_div = mu * mu
+    kl_sum = kl_div.sum()
+    return (kl_sum / (preds.size(0))) * 0.5
+
+# Yue et al.
+def encode_onehot(labels):
+    classes = set(labels)
+    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
+                    enumerate(classes)}
+    labels_onehot = np.array(list(map(classes_dict.get, labels)),
+                             dtype=np.int32)
+    return labels_onehot
+
+# Yue et al.
+def matrix_poly(matrix, d):
+    x = torch.eye(d).double() + torch.div(matrix, d)
+    return torch.matrix_power(x, d)
+
+
+# Zheng et al.
 class LBFGSBScipy(torch.optim.Optimizer):
     """Wrap L-BFGS-B algorithm, using scipy routines.
 
@@ -108,7 +201,7 @@ def wrapped_closure(flat_params):
 
             self._distribute_flat_params(final_params)
 
-
+# Zheng et al.
 class LocallyConnected(nn.Module):
     """Local linear layer, i.e. Conv1dLocal() with filter size 1.
     Args:
@@ -169,7 +262,7 @@ def extra_repr(self):
             self.bias is not None,
         )
 
-
+# Zheng et al.
 def count_accuracy(B_true, B_est):
     """Compute various accuracy metrics for B_est.
     true positive = predicted association exists in condition in correct direction