diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 45fefcf9..83663e3d 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -7,7 +7,9 @@ from sklearn.base import TransformerMixin from scipy.linalg import pinvh try: - from sklearn.covariance import _graphical_lasso as graphical_lasso + from sklearn.covariance._graph_lasso import ( + _graphical_lasso as graphical_lasso + ) except ImportError: from sklearn.covariance import graphical_lasso @@ -83,7 +85,7 @@ def _fit(self, pairs, y): msg=self.verbose, Theta0=theta0, Sigma0=sigma0) else: - _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param, + _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param, verbose=self.verbose, cov_init=sigma0) raised_error = None