diff --git a/hidimstat/cpi.py b/hidimstat/cpi.py index 22fa8ce..19c7bf3 100644 --- a/hidimstat/cpi.py +++ b/hidimstat/cpi.py @@ -169,18 +169,42 @@ def _joblib_predict_one_group(imputation_model, X, j): X_j_hat = imputation_model.predict(X_minus_j).reshape(X_j.shape) residual_j = X_j - X_j_hat - for _ in range(self.n_permutations): - X_j_perm = X_j_hat + self.rng.permutation(residual_j) - X_perm = np.empty_like(X) - X_perm[:, non_group_ids] = X_minus_j - X_perm[:, group_ids] = X_j_perm - if isinstance(X, pd.DataFrame): - X_perm = pd.DataFrame(X_perm, columns=X.columns) - - y_pred_perm = getattr(self.estimator, self.method)(X_perm) - list_y_pred_perm.append(y_pred_perm) - - return np.array(list_y_pred_perm) + X_perm_all = np.empty((self.n_permutations, X.shape[0], X.shape[1])) + X_perm_all[:, :, non_group_ids] = X_minus_j + # n_permutations x n_samples x n_features_j + residual_j_perm = np.array( + [self.rng.permutation(residual_j) for _ in range(self.n_permutations)] + ) + X_perm_all[:, :, group_ids] = X_j_hat[np.newaxis, :, :] + residual_j_perm + + X_perm_batch = X_perm_all.reshape(-1, X.shape[1]) + if isinstance(X, pd.DataFrame): + X_perm_batch = pd.DataFrame( + X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns + ) + y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch) + + # In case of classification, the output is a 2D array. Reshape accordingly + if y_pred_perm.ndim == 1: + y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0]) + else: + y_pred_perm = y_pred_perm.reshape( + self.n_permutations, X.shape[0], y_pred_perm.shape[1] + ) + return y_pred_perm + + # for _ in range(self.n_permutations): + # X_j_perm = X_j_hat + self.rng.permutation(residual_j) + # X_perm = np.empty_like(X) + # X_perm[:, non_group_ids] = X_minus_j + # X_perm[:, group_ids] = X_j_perm + # if isinstance(X, pd.DataFrame): + # X_perm = pd.DataFrame(X_perm, columns=X.columns) + + # y_pred_perm = getattr(self.estimator, self.method)(X_perm) + # list_y_pred_perm.append(y_pred_perm) + + # return np.array(list_y_pred_perm) # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)(