From ce37632ba4c4efbb449621a2dcc711642614b33a Mon Sep 17 00:00:00 2001 From: jpaillard Date: Thu, 14 Nov 2024 15:14:02 +0100 Subject: [PATCH] rename variables --- hidimstat/cpi.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/hidimstat/cpi.py b/hidimstat/cpi.py index 19c7bf3..a0a0224 100644 --- a/hidimstat/cpi.py +++ b/hidimstat/cpi.py @@ -150,7 +150,7 @@ def _joblib_predict_one_group(imputation_model, X, j): Compute the prediction of the model with the permuted data for a single group of covariates. """ - list_y_pred_perm = [] + if isinstance(X, pd.DataFrame): X_j = X[self.groups[j]].copy().values X_minus_j = X.drop(columns=self.groups[j]).values @@ -169,15 +169,17 @@ def _joblib_predict_one_group(imputation_model, X, j): X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape) residual_j = X_j - X_j_hat - X_perm_all = np.empty((self.n_permutations, X.shape[0], X.shape[1])) - X_perm_all[:, :, non_group_ids] = X_minus_j - # n_permutations x n_samples x n_features_j + # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) + # where the j-th group of covariates is (conditionally) permuted + X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1])) + X_perm_j[:, :, non_group_ids] = X_minus_j + # Create the permuted data for the j-th group of covariates residual_j_perm = np.array( [self.rng.permutation(residual_j) for _ in range(self.n_permutations)] ) - X_perm_all[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm - - X_perm_batch = X_perm_all.reshape(-1, X.shape[1]) + X_perm_j[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm + # Reshape X_perm_j to allow for batch prediction + X_perm_batch = X_perm_j.reshape(-1, X.shape[1]) if isinstance(X, pd.DataFrame): X_perm_batch = pd.DataFrame( X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns @@ -193,19 +195,6 @@ def _joblib_predict_one_group(imputation_model, X, j): ) return y_pred_perm - # for _ in range(self.n_permutations): - # X_j_perm = X_j_hat + self.rng.permutation(residual_j) - # X_perm = np.empty_like(X) - # X_perm[:, non_group_ids] = X_minus_j - # X_perm[:, group_ids] = X_j_perm - # if isinstance(X, pd.DataFrame): - # X_perm = pd.DataFrame(X_perm, columns=X.columns) - - # y_pred_perm = getattr(self.estimator, self.method)(X_perm) - # list_y_pred_perm.append(y_pred_perm) - - # return np.array(list_y_pred_perm) - # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)( delayed(_joblib_predict_one_group)(imputation_model, X, j)