diff --git a/benchmarks/lbfgs_benchmark.py b/benchmarks/lbfgs_benchmark.py index 0c1d05a8..f104125e 100644 --- a/benchmarks/lbfgs_benchmark.py +++ b/benchmarks/lbfgs_benchmark.py @@ -14,14 +14,12 @@ """Benchmark LBFGS implementation.""" -import time from absl import app from absl import flags from sklearn import datasets -import jax import jax.numpy as jnp import jaxopt diff --git a/benchmarks/proximal_gradient_benchmark.py b/benchmarks/proximal_gradient_benchmark.py index 0a9b54a2..7cffb333 100644 --- a/benchmarks/proximal_gradient_benchmark.py +++ b/benchmarks/proximal_gradient_benchmark.py @@ -16,7 +16,6 @@ import time -from typing import NamedTuple from typing import Sequence from absl import app diff --git a/examples/deep_learning/plot_sgd_solvers.py b/examples/deep_learning/plot_sgd_solvers.py index e9e711b0..679c8c62 100644 --- a/examples/deep_learning/plot_sgd_solvers.py +++ b/examples/deep_learning/plot_sgd_solvers.py @@ -38,7 +38,6 @@ from absl import flags -import logging import sys from timeit import default_timer as timer @@ -49,7 +48,6 @@ from jaxopt import ArmijoSGD from jaxopt import PolyakSGD from jaxopt import OptaxSolver -from jaxopt.tree_util import tree_l2_norm, tree_sub import optax from flax import linen as nn diff --git a/examples/fixed_point/deep_equilibrium_model.py b/examples/fixed_point/deep_equilibrium_model.py index 0faf2455..40c8d01d 100644 --- a/examples/fixed_point/deep_equilibrium_model.py +++ b/examples/fixed_point/deep_equilibrium_model.py @@ -34,19 +34,16 @@ """ from functools import partial -from typing import Any, Mapping, Tuple, Callable +from typing import Any, Tuple, Callable from absl import app from absl import flags -import flax from flax import linen as nn import jax import jax.numpy as jnp -from jax.tree_util import tree_structure -import jaxopt from jaxopt import loss from jaxopt import OptaxSolver from jaxopt import FixedPointIteration @@ -58,7 +55,6 @@ import tensorflow_datasets as tfds import tensorflow as tf -from collections import namedtuple dataset_names = [ diff --git a/examples/fixed_point/plot_picard_ode.py b/examples/fixed_point/plot_picard_ode.py index f9289f66..ea2ceb85 100644 --- a/examples/fixed_point/plot_picard_ode.py +++ b/examples/fixed_point/plot_picard_ode.py @@ -58,12 +58,9 @@ from jaxopt import AndersonAcceleration -from jaxopt import objective -from jaxopt.tree_util import tree_scalar_mul, tree_sub import numpy as np import matplotlib.pyplot as plt -from sklearn import datasets from matplotlib.pyplot import cm import scipy.integrate diff --git a/examples/implicit_diff/lasso_implicit_diff.py b/examples/implicit_diff/lasso_implicit_diff.py index a66c34cd..f6da6164 100644 --- a/examples/implicit_diff/lasso_implicit_diff.py +++ b/examples/implicit_diff/lasso_implicit_diff.py @@ -20,7 +20,6 @@ from absl import app from absl import flags -import jax import jax.numpy as jnp from jaxopt import BlockCoordinateDescent diff --git a/jaxopt/_src/anderson.py b/jaxopt/_src/anderson.py index 15f0c114..ef153bb2 100644 --- a/jaxopt/_src/anderson.py +++ b/jaxopt/_src/anderson.py @@ -17,7 +17,6 @@ from typing import Any from typing import Callable from typing import NamedTuple -from typing import List from typing import Union from typing import Optional @@ -27,12 +26,10 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src import linear_solve -from jaxopt._src.tree_util import tree_l2_norm, tree_sub -from jaxopt._src.tree_util import tree_vdot, tree_add -from jaxopt._src.tree_util import tree_mul, tree_scalar_mul +from jaxopt._src.tree_util import tree_sub +from jaxopt._src.tree_util import tree_vdot from jaxopt._src.tree_util import tree_average, tree_add_scalar_mul -from jaxopt._src.tree_util import tree_map, tree_gram +from jaxopt._src.tree_util import tree_map def minimize_residuals(residual_gram, ridge): diff --git a/jaxopt/_src/anderson_wrapper.py b/jaxopt/_src/anderson_wrapper.py index dea00a0e..8f3dbe4b 100644 --- a/jaxopt/_src/anderson_wrapper.py +++ b/jaxopt/_src/anderson_wrapper.py @@ -26,8 +26,7 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src.tree_util import tree_l2_norm, tree_sub, tree_map -from jaxopt._src.anderson import AndersonAcceleration +from jaxopt._src.tree_util import tree_sub, tree_map from jaxopt._src.anderson import anderson_step, update_history diff --git a/jaxopt/_src/backtracking_linesearch.py b/jaxopt/_src/backtracking_linesearch.py index 72db4722..03ebc7e6 100644 --- a/jaxopt/_src/backtracking_linesearch.py +++ b/jaxopt/_src/backtracking_linesearch.py @@ -22,7 +22,6 @@ from dataclasses import dataclass -import jax import jax.numpy as jnp from jaxopt._src import base diff --git a/jaxopt/_src/base.py b/jaxopt/_src/base.py index e3b7a507..d5130135 100644 --- a/jaxopt/_src/base.py +++ b/jaxopt/_src/base.py @@ -35,7 +35,6 @@ # jaxopt._src.linear_solve instead. # This allows to define linear solver with base.Solver interface, # and then exporting them in jaxopt.linear_solve. -from jaxopt._src import linear_solve from jaxopt import loop from jaxopt import tree_util diff --git a/jaxopt/_src/bisection.py b/jaxopt/_src/bisection.py index 943a20d3..97f259ca 100644 --- a/jaxopt/_src/bisection.py +++ b/jaxopt/_src/bisection.py @@ -25,8 +25,6 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src import implicit_diff as idf -from jaxopt._src import loop class BisectionState(NamedTuple): diff --git a/jaxopt/_src/block_cd.py b/jaxopt/_src/block_cd.py index d77eb4d9..f1616ff3 100644 --- a/jaxopt/_src/block_cd.py +++ b/jaxopt/_src/block_cd.py @@ -28,10 +28,7 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt._src import implicit_diff as idf -from jaxopt._src import loop from jaxopt._src import objective -from jaxopt._src import tree_util class BlockCDState(NamedTuple): diff --git a/jaxopt/_src/cd_qp.py b/jaxopt/_src/cd_qp.py index bd994522..e1598796 100644 --- a/jaxopt/_src/cd_qp.py +++ b/jaxopt/_src/cd_qp.py @@ -26,7 +26,6 @@ from jaxopt._src import base from jaxopt._src import projection -from jaxopt._src import tree_util class BoxCDQPState(NamedTuple): diff --git a/jaxopt/_src/cvxpy_wrapper.py b/jaxopt/_src/cvxpy_wrapper.py index 4a0a246e..cbe99449 100644 --- a/jaxopt/_src/cvxpy_wrapper.py +++ b/jaxopt/_src/cvxpy_wrapper.py @@ -13,18 +13,14 @@ # limitations under the License. """CVXPY wrappers.""" -from typing import Any from typing import Callable from typing import Optional -from typing import Tuple from dataclasses import dataclass -import jax import jax.numpy as jnp from jaxopt._src import base from jaxopt._src import implicit_diff as idf -from jaxopt._src import linear_solve from jaxopt._src import tree_util diff --git a/jaxopt/_src/fixed_point_iteration.py b/jaxopt/_src/fixed_point_iteration.py index 24190449..ede40516 100644 --- a/jaxopt/_src/fixed_point_iteration.py +++ b/jaxopt/_src/fixed_point_iteration.py @@ -23,7 +23,6 @@ from dataclasses import dataclass import jax.numpy as jnp -from jax.tree_util import tree_leaves, tree_structure from jaxopt._src import base from jaxopt._src.tree_util import tree_l2_norm, tree_sub diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index bbaf100f..e17d4db6 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -15,9 +15,7 @@ """Implementation of gradient descent in JAX.""" from typing import Any -from typing import Callable from typing import NamedTuple -from typing import Union from dataclasses import dataclass diff --git a/jaxopt/_src/iterative_refinement.py b/jaxopt/_src/iterative_refinement.py index 70667aad..cb32bfb6 100644 --- a/jaxopt/_src/iterative_refinement.py +++ b/jaxopt/_src/iterative_refinement.py @@ -28,15 +28,11 @@ from dataclasses import dataclass from functools import partial -import jax import jax.numpy as jnp -from jaxopt._src import loop from jaxopt._src import base -from jaxopt._src import implicit_diff as idf from jaxopt._src.tree_util import tree_zeros_like, tree_add, tree_sub -from jaxopt._src.tree_util import tree_add_scalar_mul, tree_scalar_mul -from jaxopt._src.tree_util import tree_vdot, tree_negative, tree_l2_norm +from jaxopt._src.tree_util import tree_l2_norm from jaxopt._src.linear_operator import _make_linear_operator import jaxopt._src.linear_solve as linear_solve diff --git a/jaxopt/_src/lbfgs.py b/jaxopt/_src/lbfgs.py index 1be37e1b..78c3b54b 100644 --- a/jaxopt/_src/lbfgs.py +++ b/jaxopt/_src/lbfgs.py @@ -17,7 +17,6 @@ import warnings from dataclasses import dataclass -from functools import partial from typing import Any, Callable, NamedTuple, Optional, Union import jax diff --git a/jaxopt/_src/levenberg_marquardt.py b/jaxopt/_src/levenberg_marquardt.py index 1a3646ac..b02f5520 100644 --- a/jaxopt/_src/levenberg_marquardt.py +++ b/jaxopt/_src/levenberg_marquardt.py @@ -14,7 +14,6 @@ """Levenberg-Marquardt algorithm in JAX.""" -import math from typing import Any from typing import Callable from typing import Literal @@ -35,7 +34,7 @@ from jaxopt._src.linear_solve import solve_inv from jaxopt._src.linear_solve import solve_lu from jaxopt._src.linear_solve import solve_qr -from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm, tree_sub, tree_add, tree_mul +from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm class LevenbergMarquardtState(NamedTuple): diff --git a/jaxopt/_src/linear_operator.py b/jaxopt/_src/linear_operator.py index a7d20313..72379e20 100644 --- a/jaxopt/_src/linear_operator.py +++ b/jaxopt/_src/linear_operator.py @@ -16,9 +16,8 @@ import functools import jax import jax.numpy as jnp -import numpy as onp -from jaxopt.tree_util import tree_map, tree_sum, tree_mul +from jaxopt.tree_util import tree_map class DenseLinearOperator: diff --git a/jaxopt/_src/nonlinear_cg.py b/jaxopt/_src/nonlinear_cg.py index 398f3db2..fa9f412e 100644 --- a/jaxopt/_src/nonlinear_cg.py +++ b/jaxopt/_src/nonlinear_cg.py @@ -25,7 +25,6 @@ from jaxopt._src.linesearch_util import _setup_linesearch from jaxopt._src.tree_util import tree_single_dtype, get_real_dtype from jaxopt.tree_util import tree_add_scalar_mul -from jaxopt.tree_util import tree_div from jaxopt.tree_util import tree_l2_norm from jaxopt.tree_util import tree_scalar_mul from jaxopt.tree_util import tree_sub diff --git a/jaxopt/_src/osqp.py b/jaxopt/_src/osqp.py index a62c7f63..9d15e5bf 100644 --- a/jaxopt/_src/osqp.py +++ b/jaxopt/_src/osqp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """GPU-friendly implementation of OSQP.""" -from abc import ABC, abstractmethod +from abc import abstractmethod from dataclasses import dataclass from functools import partial @@ -24,7 +24,6 @@ from typing import Union import jax -import jax.nn as nn import jax.numpy as jnp from jax.tree_util import tree_reduce diff --git a/jaxopt/_src/scipy_wrappers.py b/jaxopt/_src/scipy_wrappers.py index d7c57825..fcd116e5 100644 --- a/jaxopt/_src/scipy_wrappers.py +++ b/jaxopt/_src/scipy_wrappers.py @@ -21,8 +21,6 @@ # currently only ScipyMinimize exposes this option. """ -import abc -import dataclasses from dataclasses import dataclass from typing import Any from typing import Callable diff --git a/jaxopt/_src/test_util.py b/jaxopt/_src/test_util.py index 799cfa38..5b41fe33 100644 --- a/jaxopt/_src/test_util.py +++ b/jaxopt/_src/test_util.py @@ -14,7 +14,6 @@ """Test utilities.""" -from absl.testing import absltest from absl.testing import parameterized import functools @@ -22,8 +21,6 @@ import jax import jax.numpy as jnp -from jaxopt._src import base -from jaxopt._src import loss import numpy as onp import scipy as osp diff --git a/jaxopt/_src/zoom_linesearch.py b/jaxopt/_src/zoom_linesearch.py index de732e67..96b790c9 100644 --- a/jaxopt/_src/zoom_linesearch.py +++ b/jaxopt/_src/zoom_linesearch.py @@ -34,7 +34,6 @@ from jaxopt.tree_util import tree_scalar_mul from jaxopt.tree_util import tree_vdot_real from jaxopt.tree_util import tree_conj -from jaxopt.tree_util import tree_l2_norm # pylint: disable=g-bare-generic # pylint: disable=invalid-name diff --git a/tests/anderson_test.py b/tests/anderson_test.py index ec43dade..cf960f2f 100644 --- a/tests/anderson_test.py +++ b/tests/anderson_test.py @@ -19,18 +19,14 @@ import jax import jax.numpy as jnp from jax import scipy as jsp -from jax.tree_util import tree_map, tree_all from jax.test_util import check_grads import jaxopt -from jaxopt.tree_util import tree_l2_norm, tree_scalar_mul -from jaxopt._src.tree_util import tree_average, tree_sub -from jaxopt import objective +from jaxopt.tree_util import tree_l2_norm from jaxopt import AndersonAcceleration from jaxopt._src import test_util import numpy as onp -from sklearn import datasets class AndersonAccelerationTest(test_util.JaxoptTestCase): diff --git a/tests/anderson_wrapper_test.py b/tests/anderson_wrapper_test.py index f9a58fa6..19e11b14 100644 --- a/tests/anderson_wrapper_test.py +++ b/tests/anderson_wrapper_test.py @@ -14,31 +14,22 @@ from absl.testing import absltest -from absl.testing import parameterized -import jax import jax.numpy as jnp -from jax import config -from jax.tree_util import tree_map, tree_all from jax.test_util import check_grads import optax -from jaxopt.tree_util import tree_l2_norm, tree_scalar_mul, tree_sub from jaxopt import objective -from jaxopt import projection from jaxopt import prox from jaxopt._src import test_util from jaxopt import AndersonWrapper from jaxopt import BlockCoordinateDescent -from jaxopt import GradientDescent from jaxopt import OptaxSolver from jaxopt import PolyakSGD from jaxopt import ProximalGradient -import numpy as onp -import scipy from sklearn import datasets diff --git a/tests/base_test.py b/tests/base_test.py index 42305c71..fca39ba7 100644 --- a/tests/base_test.py +++ b/tests/base_test.py @@ -25,7 +25,6 @@ from typing import Any from typing import Callable from typing import NamedTuple -from typing import Optional import dataclasses @@ -33,12 +32,6 @@ import jax.numpy as jnp from jaxopt._src import base -from jaxopt.tree_util import tree_add -from jaxopt.tree_util import tree_add_scalar_mul -from jaxopt.tree_util import tree_l2_norm -from jaxopt.tree_util import tree_scalar_mul -from jaxopt.tree_util import tree_sub -from jaxopt.tree_util import tree_zeros_like class DummySolverState(NamedTuple): diff --git a/tests/common_test.py b/tests/common_test.py index 45d71f01..4412e031 100644 --- a/tests/common_test.py +++ b/tests/common_test.py @@ -25,7 +25,6 @@ import jax import jax.numpy as jnp -import numpy as onp import optax from sklearn import datasets from sklearn import preprocessing @@ -39,7 +38,6 @@ from jaxopt._src.base import LinearOperator from jaxopt._src.objective import LeastSquares, MulticlassLinearSvmDual from jaxopt._src.osqp import BoxOSQP -from jaxopt._src.osqp import OSQP N_CALLS = 0 diff --git a/tests/cvxpy_wrapper_test.py b/tests/cvxpy_wrapper_test.py index 9d325bc9..85df5bfd 100644 --- a/tests/cvxpy_wrapper_test.py +++ b/tests/cvxpy_wrapper_test.py @@ -15,7 +15,6 @@ """CVXPY tests.""" from absl.testing import absltest -from absl.testing import parameterized import jax import jax.numpy as jnp diff --git a/tests/eq_qp_test.py b/tests/eq_qp_test.py index c103e96a..b209c21f 100644 --- a/tests/eq_qp_test.py +++ b/tests/eq_qp_test.py @@ -22,7 +22,6 @@ from jaxopt.base import KKTSolution from jaxopt import EqualityConstrainedQP from jaxopt._src import test_util -from jaxopt._src.tree_util import tree_negative import numpy as onp diff --git a/tests/fixed_point_iteration_test.py b/tests/fixed_point_iteration_test.py index 630739a0..b6da9012 100644 --- a/tests/fixed_point_iteration_test.py +++ b/tests/fixed_point_iteration_test.py @@ -17,16 +17,12 @@ import jax import jax.numpy as jnp -from jax.tree_util import tree_map, tree_all from jax.test_util import check_grads -from jaxopt.tree_util import tree_l2_norm, tree_scalar_mul, tree_sub -from jaxopt import objective +from jaxopt.tree_util import tree_l2_norm from jaxopt import FixedPointIteration from jaxopt._src import test_util -import numpy as onp -from sklearn import datasets class FixedPointIterationTest(test_util.JaxoptTestCase): diff --git a/tests/implicit_diff_test.py b/tests/implicit_diff_test.py index ac928b0c..ff204984 100644 --- a/tests/implicit_diff_test.py +++ b/tests/implicit_diff_test.py @@ -15,7 +15,6 @@ import functools from absl.testing import absltest -from absl.testing import parameterized import jax import jax.numpy as jnp diff --git a/tests/isotonic_test.py b/tests/isotonic_test.py index 0fc027e7..e477dbaf 100644 --- a/tests/isotonic_test.py +++ b/tests/isotonic_test.py @@ -19,7 +19,6 @@ import jax import jax.numpy as jnp -import numpy as onp from jax.test_util import check_grads from jaxopt.isotonic import isotonic_l2_pav diff --git a/tests/linear_operator_test.py b/tests/linear_operator_test.py index 599b888e..2e24007f 100644 --- a/tests/linear_operator_test.py +++ b/tests/linear_operator_test.py @@ -15,7 +15,6 @@ from absl.testing import absltest -import jax import jax.numpy as jnp import numpy as onp diff --git a/tests/projected_gradient_test.py b/tests/projected_gradient_test.py index 8737b8ac..4e16164f 100644 --- a/tests/projected_gradient_test.py +++ b/tests/projected_gradient_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools from absl.testing import absltest from absl.testing import parameterized diff --git a/tests/proximal_gradient_test.py b/tests/proximal_gradient_test.py index 2c0279da..ef393da3 100644 --- a/tests/proximal_gradient_test.py +++ b/tests/proximal_gradient_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools from typing import Callable from absl.testing import absltest @@ -21,16 +20,12 @@ import jax import jax.numpy as jnp -from jaxopt import implicit_diff from jaxopt import objective -from jaxopt import projection from jaxopt import prox from jaxopt import ProximalGradient from jaxopt._src import test_util -from jaxopt import tree_util as tu from sklearn import datasets -from sklearn import preprocessing def make_stepsize_schedule(max_stepsize, n_steps, power=1.0) -> Callable: