From e4396001f1cc4f3e2b8a4657c753095a9c7f4c1c Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Thu, 11 May 2023 11:36:35 +0200 Subject: [PATCH] Change defaults for checking vectors/linearity/symmetry in check_ functions (#247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Set the vector check default to false – at least until we have a nicer control about tolerances. * Change the format a bit, such that the kwargs are passed down to `is_vector` and tolerances can be set. * bump version. --- Project.toml | 2 +- src/helpers/checks.jl | 44 +++++++++++++++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index a1ecadf835..2a78a4d3d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Manopt" uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" authors = ["Ronny Bergmann "] -version = "0.4.19" +version = "0.4.20" [deps] ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" diff --git a/src/helpers/checks.jl b/src/helpers/checks.jl index 678b4eec5c..258e88d59f 100644 --- a/src/helpers/checks.jl +++ b/src/helpers/checks.jl @@ -67,8 +67,9 @@ function prepare_check_result( throw_error && throw(ErrorException(msg)) return false end + @doc raw""" - check_differential(M, F, dF, p=rand(M), X=rand(M; vector_at=p)) + check_differential(M, F, dF, p=rand(M), X=rand(M; vector_at=p); kwargs...) Check numerivcally whether the differential `dF(M,p,X)` of `F(M,p)` is correct. @@ -167,10 +168,13 @@ no plot will be generated. * `plot` - (`false`) whether to plot the resulting check (if `Plots.jl` is loaded). The plot is in log-log-scale. This is returned and can then also be saved. * `retraction_method` - (`default_retraction_method(M, typeof(p))`) retraction method to use for the check * `slope_tol` – (`0.1`) tolerance for the slope (global) of the approximation +* `atol`, `rtol` – (same defaults as `isapprox`) tolerances that are passed down to `is_vector` if `check_vector` is set to `true` * `throw_error` - (`false`) throw an error message if the gradient is wrong * `window` – (`nothing`) specify window sizes within the `log_range` that are used for the slope estimation. the default is, to use all window sizes `2:N`. +The `kwargs...` are also passed down to the `check_vector` call, such that tolerances can +easily be set. """ function check_gradient( @@ -180,11 +184,14 @@ function check_gradient( p=rand(M), X=rand(M; vector_at=p); gradient=grad_f(M, p), - check_vector=true, + check_vector=false, throw_error=false, + atol::Real=0, + rtol::Real=atol > 0 ? 0 : sqrt(eps(eltype(p))), kwargs..., ) - check_vector && (!is_vector(M, p, gradient, throw_error;) && return false) + check_vector && + (!is_vector(M, p, gradient, throw_error; atol=atol, rtol=rtol) && return false) # function for the directional derivative - real so it also works on complex manifolds df(M, p, Y) = real(inner(M, p, gradient, Y)) return check_differential( @@ -214,15 +221,16 @@ no plot will be generated. # Keyword arguments -* `check_grad` – (`true`) check whether ``\operatorname{grad} f(p) \in T_p\mathcal M``. +* `check_grad` – (`true`) check whether ``\operatorname{grad} f(p) \in T_p\mathcal M``. * `check_linearity` – (`true`) check whether the Hessian is linear, see [`is_Hessian_linear`](@ref) using `a`, `b`, `X`, and `Y` * `check_symmetry` – (`true`) check whether the Hessian is symmetric, see [`is_Hessian_symmetric`](@ref) -* `check_vector` – (`true`) check whether ``\operatorname{Hess} f(p)[X] \in T_p\mathcal M`` using `is_vector`. +* `check_vector` – (`false`) check whether ``\operatorname{Hess} f(p)[X] \in T_p\mathcal M`` using `is_vector`. * `mode` - (`:Default`) specify the mode, by default we assume to have a second order retraction given by `retraction_method=` you can also this method if you already _have_ a cirtical point `p`. Set to `:CritalPoint` to use [`gradient_descent`](@ref) to find a critical point. Note: This requires (and evaluates) new tangent vectors `X` and `Y` +* `atol`, `rtol` – (same defaults as `isapprox`) tolerances that are passed down to all checks * `a`, `b` – two real values to check linearity of the Hessian (if `check_linearity=true`) * `N` - (`101`) number of points to check within the `log_range` default range ``[10^{-8},10^{0}]`` * `exactness_tol` - (`1e-12`) if all errors are below this tolerance, the check is considered to be exact @@ -239,6 +247,9 @@ no plot will be generated. * `throw_error` - (`false`) throw an error message if the Hessian is wrong * `window` – (`nothing`) specify window sizes within the `log_range` that are used for the slope estimation. the default is, to use all window sizes `2:N`. + +The `kwargs...` are also passed down to the `check_vector` call, such that tolerances can +easily be set. """ function check_Hessian( M::AbstractManifold, @@ -249,9 +260,10 @@ function check_Hessian( X=rand(M; vector_at=p), Y=rand(M; vector_at=p); a=randn(), + atol::Real=0, b=randn(), check_grad=true, - check_vector=true, + check_vector=false, check_symmetry=true, check_linearity=true, exactness_tol=1e-12, @@ -264,6 +276,7 @@ function check_Hessian( log_range=range(limits[1], limits[2]; length=N), plot=false, retraction_method=default_retraction_method(M, typeof(p)), + rtol::Real=atol > 0 ? 0 : sqrt(eps(eltype(p))), slope_tol=0.1, throw_error=false, window=nothing, @@ -271,22 +284,33 @@ function check_Hessian( ) if check_grad if !check_gradient( - M, f, grad_f, p, X; gradient=gradient, throw_error=throw_error, io=io, kwargs... + M, + f, + grad_f, + p, + X; + gradient=gradient, + throw_error=throw_error, + io=io, + atol=atol, + rtol=rtol, + kwargs..., ) return false end end - check_vector && (!is_vector(M, p, Hessian, throw_error) && return false) + check_vector && + (!is_vector(M, p, Hessian, throw_error; atol=atol, rtol=rtol) && return false) if check_linearity if !is_Hessian_linear( - M, Hess_f, p, X, Y, a, b; throw_error=throw_error, io=io, kwargs... + M, Hess_f, p, X, Y, a, b; throw_error=throw_error, io=io, atol=atol, rtol=rtol ) return false end end if check_symmetry if !is_Hessian_symmetric( - M, Hess_f, p, X, Y; throw_error=throw_error, io=io, kwargs... + M, Hess_f, p, X, Y; throw_error=throw_error, io=io, atol=atol, rtol=rtol ) return false end