From 5bcba3c100b8f44189fb4bc867797b3888aebfe9 Mon Sep 17 00:00:00 2001
From: tai-dang11
Date: Thu, 5 Sep 2024 01:44:55 +0000
Subject: [PATCH] update models

---
 src/protein_ligand_embedder/model_wrapper.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/protein_ligand_embedder/model_wrapper.py b/src/protein_ligand_embedder/model_wrapper.py
index 4518021..8095bff 100644
--- a/src/protein_ligand_embedder/model_wrapper.py
+++ b/src/protein_ligand_embedder/model_wrapper.py
@@ -5,6 +5,7 @@
 from torch_geometric.loader import DataLoader
 from DockingModels import EquivariantElucidatedDiffusion
 from datasets import Dataset
+from torch_scatter import scatter, scatter_mean
 
 
 class ProteinLigandWrapper(nn.Module):
@@ -37,11 +38,11 @@ def get_emnbeddings(
         tqdm_dataloader = tqdm(dataloader)
         for batch in tqdm_dataloader:
             batch = batch.to(self.diffusion_model.device)
-            _, x_t = self.diffusion_model.sample(batch, num_steps, dtype)
-            lig_seq_len = torch.bincount(batch['ligand'].batch).tolist()
-            lig_coords = torch.split(x_t, lig_seq_len)
-
-            col_emb.append(lig_coords)
+            x_t, rec_node_attr, lig_node_attr = self.diffusion_model.sample(batch, num_steps, dtype)
+            rec_node_attr = scatter_mean(rec_node_attr, batch['receptor'].batch, dim=0)
+            lig_node_attr = scatter_mean(lig_node_attr, batch['ligand'].batch, dim=0)
+            embedding = torch.cat([lig_node_attr, rec_node_attr], dim=1)
+            col_emb.append(embedding)
 
         emb_dict['embedding'] = col_emb
         return Dataset.from_dict(emb_dict)
\ No newline at end of file
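
For reference, a minimal sketch (not part of the patch) of the pooling the new lines perform: per-node features are mean-pooled over each graph in the batch with scatter_mean, then the ligand and receptor pools are concatenated into one embedding per complex. The tensor names, shapes, and feature dimension below are assumptions for illustration only, not the DockingModels API.

    import torch
    from torch_scatter import scatter_mean

    # Hypothetical per-node features for a batch of 2 protein-ligand complexes.
    lig_node_attr = torch.randn(7, 64)    # 7 ligand atoms, 64-dim features (assumed)
    rec_node_attr = torch.randn(10, 64)   # 10 receptor nodes, 64-dim features (assumed)

    # Graph index per node, as PyG's Batch objects provide via .batch.
    lig_batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    rec_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])

    # Mean-pool node features per graph, then concatenate ligand and receptor pools.
    lig_pooled = scatter_mean(lig_node_attr, lig_batch, dim=0)  # (2, 64)
    rec_pooled = scatter_mean(rec_node_attr, rec_batch, dim=0)  # (2, 64)
    embedding = torch.cat([lig_pooled, rec_pooled], dim=1)      # (2, 128)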