# eval_sample.py (forked from ehoogeboom/e3_diffusion_for_molecules)
# Rdkit import should be first, do not move it
try:
    from rdkit import Chem
except ModuleNotFoundError:
    pass
import utils
import argparse
from configs.datasets_config import qm9_with_h, qm9_without_h
from qm9 import dataset
from qm9.models import get_model
from equivariant_diffusion.utils import assert_correctly_masked
import torch
import pickle
import qm9.visualizer as vis
from qm9.analyze import check_stability
from os.path import join
from qm9.sampling import sample_chain, sample
from configs.datasets_config import get_dataset_info

def check_mask_correct(variables, node_mask):
    for variable in variables:
        assert_correctly_masked(variable, node_mask)

def save_and_sample_chain(args, eval_args, device, flow,
                          n_tries, n_nodes, dataset_info, id_from=0,
                          num_chains=100):
    # Sample full diffusion chains and save each one as an xyz animation
    # under <model_path>/eval/chain_<i>/.
    for i in range(num_chains):
        target_path = f'eval/chain_{i}/'

        one_hot, charges, x = sample_chain(
            args, device, flow, n_tries, dataset_info)

        vis.save_xyz_file(
            join(eval_args.model_path, target_path), one_hot, charges, x,
            dataset_info, id_from, name='chain')

        vis.visualize_chain_uncertainty(
            join(eval_args.model_path, target_path), dataset_info,
            spheres_3d=True)

    return one_hot, charges, x

def sample_different_sizes_and_save(args, eval_args, device, generative_model,
                                    nodes_dist, dataset_info, n_samples=10):
    # Draw molecule sizes from the empirical node-count distribution, sample
    # one batch of molecules, and write them out as xyz files.
    nodesxsample = nodes_dist.sample(n_samples)
    one_hot, charges, x, node_mask = sample(
        args, device, generative_model, dataset_info,
        nodesxsample=nodesxsample)
    vis.save_xyz_file(
        join(eval_args.model_path, 'eval/molecules/'), one_hot, charges, x,
        id_from=0, name='molecule', dataset_info=dataset_info,
        node_mask=node_mask)

def sample_only_stable_different_sizes_and_save(
        args, eval_args, device, flow, nodes_dist,
        dataset_info, n_samples=10, n_tries=50):
    # Sample n_tries molecules and keep (up to) n_samples of them,
    # preferring molecules that pass the stability check.
    assert n_tries > n_samples
    nodesxsample = nodes_dist.sample(n_tries)
    one_hot, charges, x, node_mask = sample(
        args, device, flow, dataset_info,
        nodesxsample=nodesxsample)

    counter = 0
    for i in range(n_tries):
        num_atoms = int(node_mask[i:i+1].sum().item())
        atom_type = one_hot[i:i+1, :num_atoms].argmax(2).squeeze(0).cpu().detach().numpy()
        x_squeeze = x[i:i+1, :num_atoms].squeeze(0).cpu().detach().numpy()
        mol_stable = check_stability(x_squeeze, atom_type, dataset_info)[0]

        num_remaining_attempts = n_tries - i - 1
        num_remaining_samples = n_samples - counter

        # Save the molecule if it is stable, or if too few attempts remain
        # to still fill the quota with stable ones.
        if mol_stable or num_remaining_attempts <= num_remaining_samples:
            if mol_stable:
                print('Found stable mol.')

            vis.save_xyz_file(
                join(eval_args.model_path, 'eval/molecules/'),
                one_hot[i:i+1], charges[i:i+1], x[i:i+1],
                id_from=counter, name='molecule_stable',
                dataset_info=dataset_info,
                node_mask=node_mask[i:i+1])
            counter += 1

            if counter >= n_samples:
                break

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str,
                        default="outputs/edm_1",
                        help='Specify model path')
    parser.add_argument(
        '--n_tries', type=int, default=10,
        help='N tries to find stable molecule for gif animation')
    parser.add_argument('--n_nodes', type=int, default=19,
                        help='Number of atoms in molecule for gif animation')

    eval_args, unparsed_args = parser.parse_known_args()

    assert eval_args.model_path is not None

    # Restore the arguments the model was trained with.
    with open(join(eval_args.model_path, 'args.pickle'), 'rb') as f:
        args = pickle.load(f)

    # CAREFUL with this --> defaults for older checkpoints that predate these
    # arguments; they must match the settings used during training.
    if not hasattr(args, 'normalization_factor'):
        args.normalization_factor = 1
    if not hasattr(args, 'aggregation_method'):
        args.aggregation_method = 'sum'

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    args.device = device
    dtype = torch.float32
    utils.create_folders(args)
    print(args)

    dataset_info = get_dataset_info(args.dataset, args.remove_h)
    dataloaders, charge_scale = dataset.retrieve_dataloaders(args)

    # Build the model and load the checkpoint weights (EMA weights if EMA
    # was used during training).
    flow, nodes_dist, prop_dist = get_model(
        args, device, dataset_info, dataloaders['train'])
    flow.to(device)
    fn = 'generative_model_ema.npy' if args.ema_decay > 0 else 'generative_model.npy'
    flow_state_dict = torch.load(join(eval_args.model_path, fn),
                                 map_location=device)
    flow.load_state_dict(flow_state_dict)

    print('Sampling handful of molecules.')
    sample_different_sizes_and_save(
        args, eval_args, device, flow, nodes_dist,
        dataset_info=dataset_info, n_samples=30)

    print('Sampling stable molecules.')
    sample_only_stable_different_sizes_and_save(
        args, eval_args, device, flow, nodes_dist,
        dataset_info=dataset_info, n_samples=10, n_tries=2*10)

    print('Visualizing molecules.')
    vis.visualize(
        join(eval_args.model_path, 'eval/molecules/'), dataset_info,
        max_num=100, spheres_3d=True)

    print('Sampling visualization chain.')
    save_and_sample_chain(
        args, eval_args, device, flow,
        n_tries=eval_args.n_tries, n_nodes=eval_args.n_nodes,
        dataset_info=dataset_info)

if __name__ == "__main__":
    main()
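
# Example usage (a sketch, not part of the original file): --model_path should
# point at a training output directory containing args.pickle and
# generative_model_ema.npy (or generative_model.npy when EMA was disabled).
# Sampled molecules and chain animations are written under <model_path>/eval/.
#
#   python eval_sample.py --model_path outputs/edm_1 --n_tries 10 --n_nodes 19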