diff --git a/hidimstat/permutation_importance.py b/hidimstat/permutation_importance.py index 191db45..59920b6 100644 --- a/hidimstat/permutation_importance.py +++ b/hidimstat/permutation_importance.py @@ -69,7 +69,7 @@ def fit(self, X, y=None, groups=None): self.groups = groups return self - def predict(self, X, y): + def predict(self, X, y=None): """ Compute the prediction of the model with permuted data for each group. @@ -114,17 +114,31 @@ def _joblib_predict_one_group(X, j): group_ids = self.groups[j] non_group_ids = np.delete(np.arange(X.shape[1]), group_ids) - for _ in range(self.n_permutations): - X_j_perm = self.rng.permutation(X_j) - X_perm = np.empty_like(X) - X_perm[:, non_group_ids] = X_minus_j - X_perm[:, group_ids] = X_j_perm - if isinstance(X, pd.DataFrame): - X_perm = pd.DataFrame(X_perm, columns=X.columns) - - y_pred_perm = getattr(self.estimator, self.method)(X_perm) - list_y_pred_perm.append(y_pred_perm) - return np.array(list_y_pred_perm) + # Create an array X_perm_j of shape (n_permutations, n_samples, n_features) + # where the j-th group of covariates is permuted + X_perm_j = np.empty((self.n_permutations, X.shape[0], X.shape[1])) + X_perm_j[:, :, non_group_ids] = X_minus_j + # Create the permuted data for the j-th group of covariates + group_j_permuted = np.array( + [self.rng.permutation(X_j) for _ in range(self.n_permutations)] + ) + X_perm_j[:, :, group_ids] = group_j_permuted + # Reshape X_perm_j to allow for batch prediction + X_perm_batch = X_perm_j.reshape(-1, X.shape[1]) + if isinstance(X, pd.DataFrame): + X_perm_batch = pd.DataFrame( + X_perm_batch.reshape(-1, X.shape[1]), columns=X.columns + ) + y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch) + + # In case of classification, the output is a 2D array. Reshape accordingly + if y_pred_perm.ndim == 1: + y_pred_perm = y_pred_perm.reshape(self.n_permutations, X.shape[0]) + else: + y_pred_perm = y_pred_perm.reshape( + self.n_permutations, X.shape[0], y_pred_perm.shape[1] + ) + return y_pred_perm # Parallelize the computation of the importance scores for each group out_list = Parallel(n_jobs=self.n_jobs)(