diff --git a/docs/api.rst b/docs/api.rst index a685482c..167743b1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -125,3 +125,20 @@ Line search :toctree: _autosummary jaxopt.BacktrackingLineSearch + +Tree utilities +-------------- + +.. autosummary:: + :toctree: _autosummary + + jaxopt.tree_util.tree_add + jaxopt.tree_util.tree_sub + jaxopt.tree_util.tree_mul + jaxopt.tree_util.tree_div + jaxopt.tree_util.tree_scalar_mul + jaxopt.tree_util.tree_add_scalar_mul + jaxopt.tree_util.tree_vdot + jaxopt.tree_util.tree_sum + jaxopt.tree_util.tree_l2_norm + jaxopt.tree_util.tree_zeros_like diff --git a/docs/changelog.rst b/docs/changelog.rst index b781d07f..74916a15 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,9 +7,10 @@ Version 0.3 New features ~~~~~~~~~~~~ -- :class:`jaxopt.LBFGS`. -- :class:`jaxopt.BacktrackingLineSearch`. -- :class:`jaxopt.GaussNewton`. +- :class:`jaxopt.LBFGS` +- :class:`jaxopt.BacktrackingLineSearch` +- :class:`jaxopt.GaussNewton` +- :class:`jaxopt.NonlinearCG` Bug fixes and enhancements ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -17,6 +18,11 @@ Bug fixes and enhancements - `Support implicit AD in higher-order differentiation `_. +Contributors +~~~~~~~~~~~~ + +Amir Saadat, Fabian Pedregosa, Geoffrey NĂ©giar, Hyunsung Lee, Mathieu Blondel, Roy Frostig. + Version 0.2 ----------- @@ -48,7 +54,7 @@ Bug fixes and enhancements Deprecations ~~~~~~~~~~~~ -- :class:`jaxopt.QuadraticProgramming` is deprecated and will be removed in v0.3. Use +- :class:`jaxopt.QuadraticProgramming` is deprecated and will be removed in v0.4. Use :class:`jaxopt.CvxpyQP`, :class:`jaxopt.OSQP`, :class:`jaxopt.BoxOSQP` and :class:`jaxopt.EqualityConstrainedQP` instead. - ``params, state = solver.init(...)`` is deprecated. Use ``state = solver.init_state(...)`` instead. diff --git a/docs/unconstrained.rst b/docs/unconstrained.rst index e1a3b1b1..6ba29fcf 100644 --- a/docs/unconstrained.rst +++ b/docs/unconstrained.rst @@ -53,7 +53,8 @@ instantiated and run as follows:: # Alternatively, we could have used one of these solvers as well: # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500) # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500) - # cg_model = jaxopt.NonlinearCG(fun=ridge_reg_objective, maxiter=300, method="polak-ribiere") + # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500) + Unpacking results ~~~~~~~~~~~~~~~~~ diff --git a/jaxopt/_src/quadratic_prog.py b/jaxopt/_src/quadratic_prog.py index e63166b7..3d279c6f 100644 --- a/jaxopt/_src/quadratic_prog.py +++ b/jaxopt/_src/quadratic_prog.py @@ -158,7 +158,7 @@ def ineq_fun(primal_var, params_ineq): @dataclass(eq=False) class QuadraticProgramming(base.Solver): - """Deprecated: will be removed in v0.3. + """Deprecated: will be removed in v0.4. Use :class:`jaxopt.CvxpyQP`, :class:`jaxopt.OSQP`, :class:`jaxopt.BoxOSQP` and :class:`jaxopt.EqualityConstrainedQP` instead. @@ -233,7 +233,7 @@ def l2_optimality_error( return tree_util.tree_l2_norm(pytree) def __post_init__(self): - warnings.warn("Class 'QuadraticProgramming' will be removed in v0.3. " + warnings.warn("Class 'QuadraticProgramming' will be removed in v0.4. " "Use 'EqualityConstraintsQP' if you want the same behavior as " "'QuadraticProgramming' for QPs with equality constraints only. " "Use 'CVXPY_QP' if you want the same behavior as " diff --git a/jaxopt/_src/tree_util.py b/jaxopt/_src/tree_util.py index 3521a643..393056cd 100644 --- a/jaxopt/_src/tree_util.py +++ b/jaxopt/_src/tree_util.py @@ -31,9 +31,16 @@ tree_unflatten = tu.tree_unflatten tree_add = functools.partial(tree_multimap, operator.add) +tree_add.__doc__ = "Tree addition." + tree_sub = functools.partial(tree_multimap, operator.sub) +tree_sub.__doc__ = "Tree subtraction." + tree_mul = functools.partial(tree_multimap, operator.mul) +tree_mul.__doc__ = "Tree multiplication." + tree_div = functools.partial(tree_multimap, operator.truediv) +tree_div.__doc__ = "Tree division." def tree_scalar_mul(scalar, tree_x):