diff --git a/hidimstat/loco.py b/hidimstat/loco.py index a5ae3c3..7c84644 100644 --- a/hidimstat/loco.py +++ b/hidimstat/loco.py @@ -116,6 +116,8 @@ def predict(self, X, y): - 'importance': the importance scores for each group. """ check_is_fitted(self.estimator) + if len(self._list_estimators) == 0: + raise ValueError("fit must be called before predict") for m in self._list_estimators: check_is_fitted(m) @@ -168,6 +170,7 @@ def score(self, X, y): the permuted data for each group. - 'importance': the importance scores for each group. """ + check_is_fitted(self.estimator) if len(self._list_estimators) == 0: raise ValueError("fit must be called before predict") for m in self._list_estimators: