diff --git a/.gitignore b/.gitignore index 46a5434..02e114f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,11 @@ __pycache__ *.pyc *.egg-info/ *.ipynb_checkpoints/ -*.pt \ No newline at end of file +*.pt + +.DS_Store +.vscode/ +build/ +datasets/ +dist/ +log/ \ No newline at end of file diff --git a/cpa/model.py b/cpa/model.py index 538b6de..de20b2d 100644 --- a/cpa/model.py +++ b/cpa/model.py @@ -361,8 +361,6 @@ def predict( """ genes, drugs, covariates = self.move_inputs_(genes, drugs, covariates) - if self.loss_ae == 'nb': - genes = torch.log1p(genes) latent_basal = self.encoder(genes) @@ -378,19 +376,21 @@ def predict( ) #argmax because OHE gene_reconstructions = self.decoder(latent_treated) + dim = gene_reconstructions.size(1) // 2 if self.loss_ae == 'gauss': # convert variance estimates to a positive value in [1e-3, \infty) - dim = gene_reconstructions.size(1) // 2 gene_means = gene_reconstructions[:, :dim] gene_vars = F.softplus(gene_reconstructions[:, dim:]).add(1e-3) #gene_vars = gene_reconstructions[:, dim:].exp().add(1).log().add(1e-3) + gene_reconstructions = torch.cat([gene_means, gene_vars], dim=1) if self.loss_ae == 'nb': - gene_means = F.softplus(gene_means).add(1e-3) + gene_mus = F.softplus(gene_reconstructions[:, :dim]).add(1e-3) + gene_thetas = F.softplus(gene_reconstructions[:, dim:]).add(1e-3) #gene_reconstructions[:, :dim] = torch.clamp(gene_reconstructions[:, :dim], min=1e-4, max=1e4) #gene_reconstructions[:, dim:] = torch.clamp(gene_reconstructions[:, dim:], min=1e-4, max=1e4) - gene_reconstructions = torch.cat([gene_means, gene_vars], dim=1) - + gene_reconstructions = torch.cat([gene_mus, gene_thetas], dim=1) + if return_latent_basal: if return_latent_treated: return gene_reconstructions, latent_basal, latent_treated @@ -430,9 +430,7 @@ def update(self, genes, drugs, covariates): ) dim = gene_reconstructions.size(1) // 2 - gene_means = gene_reconstructions[:, :dim] - gene_vars = gene_reconstructions[:, dim:] - reconstruction_loss = self.loss_autoencoder(gene_means, genes, gene_vars) + reconstruction_loss = self.loss_autoencoder(gene_reconstructions[:, :dim], genes, gene_reconstructions[:, dim:]) adversary_drugs_loss = torch.tensor([0.0], device=self.device) if self.num_drugs > 0: adversary_drugs_predictions = self.adversary_drugs(latent_basal) @@ -456,7 +454,7 @@ def update(self, genes, drugs, covariates): adversary_drugs_penalty = torch.tensor([0.0], device=self.device) adversary_covariates_penalty = torch.tensor([0.0], device=self.device) - if self.iteration % self.hparams["adversary_steps"]: + if self.iteration % self.hparams["adversary_steps"] == 0: def compute_gradients(output, input): grads = torch.autograd.grad(output, input, create_graph=True) diff --git a/cpa/train.py b/cpa/train.py index cc72cdc..daf92cd 100644 --- a/cpa/train.py +++ b/cpa/train.py @@ -114,7 +114,7 @@ def compute_score(labels): return [np.mean(pert_scores), *[np.mean(cov_score) for cov_score in cov_scores]] -def evaluate_r2(autoencoder, dataset, genes_control): +def evaluate_r2(autoencoder, dataset, genes_control, min_samples=30): """ Measures different quality metrics about an CPA `autoencoder`, when tasked to translate some `genes_control` into each of the drug/covariates @@ -128,8 +128,6 @@ def evaluate_r2(autoencoder, dataset, genes_control): mean_score, var_score, mean_score_de, var_score_de = [], [], [], [] num, dim = genes_control.size(0), genes_control.size(1) - total_cells = len(dataset) - for pert_category in np.unique(dataset.pert_categories): # pert_category category contains: 'celltype_perturbation_dose' info de_idx = np.where( @@ -138,7 +136,8 @@ def evaluate_r2(autoencoder, dataset, genes_control): idx = np.where(dataset.pert_categories == pert_category)[0] - if len(idx) > 30: + # estimate metrics only for reasonably-sized drug/cell-type combos + if len(idx) > min_samples: emb_drugs = dataset.drugs[idx][0].view(1, -1).repeat(num, 1).clone() emb_covars = [ covar[idx][0].view(1, -1).repeat(num, 1).clone() @@ -169,14 +168,12 @@ def evaluate_r2(autoencoder, dataset, genes_control): total_count=counts, logits=logits ) - nb_sample = dist.sample().cpu().numpy() - yp_m = nb_sample.mean(0) - yp_v = nb_sample.var(0) + yp_m = dist.mean.mean(0) + yp_v = dist.variance.mean(0) else: # predicted means and variances yp_m = mean_predict.mean(0) yp_v = var_predict.mean(0) - # estimate metrics only for reasonably-sized drug/cell-type combos y_true = dataset.genes[idx, :].numpy() @@ -376,7 +373,7 @@ def parse_arguments(): parser.add_argument("--perturbation_key", type=str, default="condition") parser.add_argument("--control", type=str, default=None) parser.add_argument("--dose_key", type=str, default="dose_val") - parser.add_argument("--covariate_keys", nargs="*", type=str, default="cell_type") + parser.add_argument("--covariate_keys", nargs="*", type=str, default=["cell_type"]) parser.add_argument("--split_key", type=str, default="split") parser.add_argument("--loss_ae", type=str, default="gauss") parser.add_argument("--doser_type", type=str, default="sigm")