Skip to content

Commit

Permalink
Take objective access a bit more seriously. (#291)
Browse files Browse the repository at this point in the history
* take objective access a bit more seriously.
* Test coverage.
* Update docs.
  • Loading branch information
kellertuer authored Sep 14, 2023
1 parent 48e334a commit 6036ba8
Show file tree
Hide file tree
Showing 18 changed files with 551 additions and 149 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Manopt"
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
authors = ["Ronny Bergmann <[email protected]>"]
version = "0.4.34"
version = "0.4.35"

[deps]
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Expand Down
16 changes: 16 additions & 0 deletions docs/src/plans/objective.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ ManifoldCostObjective

```@docs
get_cost
```

and internally

```@docs
get_cost_function
```

Expand All @@ -117,6 +122,11 @@ ManifoldCostGradientObjective
```@docs
get_gradient
get_gradients
```

and internally

```@docs
get_gradient_function
```

Expand Down Expand Up @@ -158,6 +168,12 @@ get_hessian
get_preconditioner
```

and internally

```@docs
get_hessian_function
```

### Primal-Dual based Objectives

```@docs
Expand Down
2 changes: 1 addition & 1 deletion ext/ManoptLineSearchesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function (cs::Manopt.LineSearchesStepsize)(
p_tmp = copy(M, p)
X_tmp = copy(M, p, X)
Y_tmp = copy(M, p, X)
f = get_cost_function(get_objective(mp))
f = Manopt.get_cost_function(get_objective(mp))
dphi_0 = real(inner(M, p, X, η))

# guess initial alpha
Expand Down
11 changes: 5 additions & 6 deletions src/Manopt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,11 @@ export AlternatingGradient
#
# Accessors and helpers for AbstractManoptSolverState
export default_stepsize
export get_cost, get_cost_function
export get_gradient, get_gradient_function, get_gradient!
export get_cost, get_gradient, get_gradient!
export get_subgradient, get_subgradient!
export get_subtrahend_gradient!, get_subtrahend_gradient
export get_proximal_map,
get_proximal_map!,
get_state,
export get_proximal_map, get_proximal_map!
export get_state,
get_initial_stepsize,
get_iterate,
get_gradients,
Expand All @@ -302,7 +300,8 @@ export get_proximal_map,
forward_operator!,
get_objective
export set_manopt_parameter!
export get_hessian, get_hessian!, ApproxHessianFiniteDifference
export get_hessian, get_hessian!
export ApproxHessianFiniteDifference
export is_state_decorator, dispatch_state_decorator
export primal_residual, dual_residual
export get_constraints,
Expand Down
51 changes: 46 additions & 5 deletions src/plans/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ function get_cost(M::AbstractManifold, sco::SimpleManifoldCachedObjective, p)
end
return sco.c
end
get_cost_function(sco::SimpleManifoldCachedObjective) = (M, p) -> get_cost(M, sco, p)
function get_gradient_function(sco::SimpleManifoldCachedObjective)
return (M, p) -> get_gradient(M, sco, p)
# Return a closure that evaluates the cost through the cache; with
# `recursive=true` unwrap instead and return the wrapped objective's cost function.
function get_cost_function(sco::SimpleManifoldCachedObjective, recursive=false)
    if recursive
        return get_cost_function(sco.objective, recursive)
    end
    return (M, q) -> get_cost(M, sco, q)
end

function get_gradient(M::AbstractManifold, sco::SimpleManifoldCachedObjective, p)
Expand All @@ -85,6 +85,19 @@ function get_gradient!(M::AbstractManifold, X, sco::SimpleManifoldCachedObjectiv
return X
end

# Return a closure evaluating the gradient through the cache.
# With `recursive=true` return the wrapped objective's gradient function instead.
# Allocating variant: the closure has signature `(M, p) -> X`.
function get_gradient_function(
    sco::SimpleManifoldCachedObjective{AllocatingEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(sco.objective, recursive)
    end
    return (M, q) -> get_gradient(M, sco, q)
end
# In-place variant: the closure has signature `(M, X, p)` and writes into `X`.
function get_gradient_function(
    sco::SimpleManifoldCachedObjective{InplaceEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(sco.objective, recursive)
    end
    return (M, Y, q) -> get_gradient!(M, Y, sco, q)
end

#
# CostGradImplementation
#
Expand Down Expand Up @@ -297,8 +310,11 @@ function get_cost(M::AbstractManifold, co::ManifoldCachedObjective, p)
get_cost(M, co.objective, p)
end
end
get_cost_function(co::ManifoldCachedObjective) = (M, p) -> get_cost(M, co, p)
get_gradient_function(co::ManifoldCachedObjective) = (M, p) -> get_gradient(M, co, p)

# Return a closure that evaluates the (possibly cached) cost; with
# `recursive=true` unwrap and return the wrapped objective's cost function.
function get_cost_function(co::ManifoldCachedObjective, recursive=false)
    if recursive
        return get_cost_function(co.objective, recursive)
    end
    return (M, q) -> get_cost(M, co, q)
end

function get_gradient(M::AbstractManifold, co::ManifoldCachedObjective, p)
!(haskey(co.cache, :Gradient)) && return get_gradient(M, co.objective, p)
Expand All @@ -325,6 +341,19 @@ function get_gradient!(M::AbstractManifold, X, co::ManifoldCachedObjective, p)
return X
end

# Return a closure evaluating the gradient through the cache; `recursive=true`
# instead returns the wrapped objective's gradient function.
# Allocating variant: closure signature `(M, p) -> X`.
function get_gradient_function(
    co::ManifoldCachedObjective{AllocatingEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(co.objective, recursive)
    end
    return (M, q) -> get_gradient(M, co, q)
end
# In-place variant: closure signature `(M, X, p)`, working in place of `X`.
function get_gradient_function(
    co::ManifoldCachedObjective{InplaceEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(co.objective, recursive)
    end
    return (M, Y, q) -> get_gradient!(M, Y, co, q)
end

#
# CostGradImplementation
function get_cost(
Expand Down Expand Up @@ -574,6 +603,18 @@ function get_hessian!(M::AbstractManifold, Y, co::ManifoldCachedObjective, p, X)
return Y
end

# Return a closure evaluating the Hessian through the cache; `recursive=true`
# instead returns the wrapped objective's Hessian function.
# Allocating variant: closure signature `(M, p, X) -> Y`.
function get_hessian_function(
    co::ManifoldCachedObjective{AllocatingEvaluation}, recursive=false
)
    if recursive
        return get_hessian_function(co.objective, recursive)
    end
    return (M, q, W) -> get_hessian(M, co, q, W)
end
# In-place variant: closure signature `(M, Y, p, X)`, working in place of `Y`.
function get_hessian_function(
    co::ManifoldCachedObjective{InplaceEvaluation}, recursive=false
)
    if recursive
        return get_hessian_function(co.objective, recursive)
    end
    return (M, Y, q, W) -> get_hessian!(M, Y, co, q, W)
end
#
# Preconditioner
function get_preconditioner(M::AbstractManifold, co::ManifoldCachedObjective, p, X)
Expand Down
6 changes: 3 additions & 3 deletions src/plans/cost_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end
return the function to evaluate (just) the cost ``f(p)=c`` as a function `(M,p) -> c`.
"""
get_cost_function(mco::AbstractManifoldCostObjective) = mco.cost
function get_cost_function(admo::AbstractDecoratedManifoldObjective)
return get_cost_function(get_objective(admo, false))
# A plain cost objective stores its cost function directly; `recursive` is
# accepted only for a uniform interface and has no effect here.
function get_cost_function(mco::AbstractManifoldCostObjective, recursive=false)
    return mco.cost
end
# Decorators delegate to the objective they wrap; `recursive=true` unwraps
# all decorators down to the innermost objective first.
function get_cost_function(admo::AbstractDecoratedManifoldObjective, recursive=false)
    inner = get_objective(admo, recursive)
    return get_cost_function(inner)
end
32 changes: 30 additions & 2 deletions src/plans/count.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ function get_cost(
c, _ = get_cost_and_gradient(M, co, p)
return c
end
get_cost_function(co::ManifoldCountObjective) = (M, p) -> get_cost(M, co, p)
# Return a closure that evaluates the cost while counting calls; with
# `recursive=true` return the wrapped (uncounted) objective's cost function.
function get_cost_function(co::ManifoldCountObjective, recursive=false)
    if recursive
        return get_cost_function(co.objective, recursive)
    end
    return (M, q) -> get_cost(M, co, q)
end

function get_cost_and_gradient(M::AbstractManifold, co::ManifoldCountObjective, p)
_count_if_exists(co, :Cost)
Expand All @@ -232,7 +235,19 @@ function get_cost_and_gradient!(M::AbstractManifold, X, co::ManifoldCountObjecti
return get_cost_and_gradient!(M, X, co.objective, p)
end

get_gradient_function(co::ManifoldCountObjective) = (M, p) -> get_gradient(M, co, p)
# Return a closure evaluating the gradient while counting calls; with
# `recursive=true` return the wrapped objective's gradient function instead.
# Allocating variant: closure signature `(M, p) -> X`.
function get_gradient_function(
    co::ManifoldCountObjective{AllocatingEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(co.objective, recursive)
    end
    return (M, q) -> get_gradient(M, co, q)
end
# In-place variant: closure signature `(M, X, p)`, working in place of `X`.
function get_gradient_function(
    co::ManifoldCountObjective{InplaceEvaluation}, recursive=false
)
    if recursive
        return get_gradient_function(co.objective, recursive)
    end
    return (M, Y, q) -> get_gradient!(M, Y, co, q)
end

function get_gradient(M::AbstractManifold, co::ManifoldCountObjective, p)
_count_if_exists(co, :Gradient)
return get_gradient(M, co.objective, p)
Expand Down Expand Up @@ -265,6 +280,19 @@ function get_hessian!(M::AbstractManifold, Y, co::ManifoldCountObjective, p, X)
return Y
end

# Return a closure evaluating the Hessian while counting calls; with
# `recursive=true` return the wrapped objective's Hessian function instead.
# Allocating variant: closure signature `(M, p, X) -> Y`.
function get_hessian_function(
    co::ManifoldCountObjective{AllocatingEvaluation}, recursive=false
)
    if recursive
        return get_hessian_function(co.objective, recursive)
    end
    return (M, q, W) -> get_hessian(M, co, q, W)
end
# In-place variant: closure signature `(M, Y, p, X)`, working in place of `Y`.
function get_hessian_function(
    co::ManifoldCountObjective{InplaceEvaluation}, recursive=false
)
    if recursive
        return get_hessian_function(co.objective, recursive)
    end
    return (M, Y, q, W) -> get_hessian!(M, Y, co, q, W)
end

function get_preconditioner(M::AbstractManifold, co::ManifoldCountObjective, p, X)
_count_if_exists(co, :Preconditioner)
return get_preconditioner(M, co.objective, p, X)
Expand Down
31 changes: 31 additions & 0 deletions src/plans/embedded_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ function get_cost(M::AbstractManifold, emo::EmbeddedManifoldObjective, p)
return get_cost(get_embedding(M), emo.objective, q)
end

# Return a closure evaluating the cost in the embedding; with `recursive=true`
# return the wrapped objective's cost function instead.
function get_cost_function(emo::EmbeddedManifoldObjective, recursive=false)
    if recursive
        return get_cost_function(emo.objective, recursive)
    end
    return (M, q) -> get_cost(M, emo, q)
end
@doc raw"""
get_gradient(M::AbstractManifold, emo::EmbeddedManifoldObjective, p)
get_gradient!(M::AbstractManifold, X, emo::EmbeddedManifoldObjective, p)
Expand Down Expand Up @@ -105,6 +109,19 @@ function get_gradient!(
riemannian_gradient!(M, X, p, emo.X)
return X
end

# Return a closure evaluating the (embedded, then converted) gradient; with
# `recursive=true` return the wrapped objective's gradient function instead.
# Allocating variant: closure signature `(M, p) -> X`.
function get_gradient_function(
    emo::EmbeddedManifoldObjective{P,T,AllocatingEvaluation}, recursive=false
) where {P,T}
    if recursive
        return get_gradient_function(emo.objective, recursive)
    end
    return (M, q) -> get_gradient(M, emo, q)
end
# In-place variant: closure signature `(M, X, p)`, working in place of `X`.
function get_gradient_function(
    emo::EmbeddedManifoldObjective{P,T,InplaceEvaluation}, recursive=false
) where {P,T}
    if recursive
        return get_gradient_function(emo.objective, recursive)
    end
    return (M, Y, q) -> get_gradient!(M, Y, emo, q)
end
#
# Hessian
#
Expand Down Expand Up @@ -163,6 +180,20 @@ function get_hessian!(
)
return Y
end

# Return a closure evaluating the (embedded, then converted) Hessian; with
# `recursive=true` return the wrapped objective's Hessian function instead.
# Allocating variant: closure signature `(M, p, X) -> Y`.
function get_hessian_function(
    emo::EmbeddedManifoldObjective{P,T,AllocatingEvaluation}, recursive=false
) where {P,T}
    if recursive
        return get_hessian_function(emo.objective, recursive)
    end
    return (M, q, W) -> get_hessian(M, emo, q, W)
end
# In-place variant: closure signature `(M, Y, p, X)`, working in place of `Y`.
function get_hessian_function(
    emo::EmbeddedManifoldObjective{P,T,InplaceEvaluation}, recursive=false
) where {P,T}
    if recursive
        return get_hessian_function(emo.objective, recursive)
    end
    return (M, Y, q, W) -> get_hessian!(M, Y, emo, q, W)
end

#
# Constraints
#
Expand Down
17 changes: 12 additions & 5 deletions src/plans/gradient_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,24 @@ abstract type AbstractManifoldGradientObjective{E<:AbstractEvaluationType,TC,TG}
AbstractManifoldCostObjective{E,TC} end

@doc raw"""
get_gradient_function(amgo::AbstractManifoldGradientObjective{E<:AbstractEvaluationType})
get_gradient_function(amgo::AbstractManifoldGradientObjective, recursive=false)
return the function to evaluate (just) the gradient ``\operatorname{grad} f(p)``,
where either the gradient function using the decorator or without the decorator is used.
By default `recursive` is set to `false`, since usually to just pass the gradient function
somewhere, you still want e.g. the cached one or the one that still counts calls.
return the function to evaluate (just) the gradient ``\operatorname{grad} f(p)``.
Depending on the [`AbstractEvaluationType`](@ref) `E` this is a function
* `(M, p) -> X` for the [`AllocatingEvaluation`](@ref) case
* `(M, X, p) -> X` for the [`InplaceEvaluation`](@ref), i.e. working inplace of `X`.
"""
get_gradient_function(amgo::AbstractManifoldGradientObjective) = amgo.gradient!!
function get_gradient_function(admo::AbstractDecoratedManifoldObjective)
return get_gradient_function(get_objective(admo, false))
function get_gradient_function(amgo::AbstractManifoldGradientObjective, recursive=false)
    # a plain gradient objective stores its gradient function directly;
    # `recursive` is accepted only for a uniform interface
    return amgo.gradient!!
end
function get_gradient_function(admo::AbstractDecoratedManifoldObjective, recursive=false)
    # delegate to the wrapped objective; `recursive=true` unwraps all decorators
    inner = get_objective(admo, recursive)
    return get_gradient_function(inner)
end

@doc raw"""
Expand Down
14 changes: 14 additions & 0 deletions src/plans/hessian_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ function get_hessian!(
return Y
end

@doc raw"""
    get_hessian_function(mho::ManifoldHessianObjective, recursive=false)

return the function to evaluate (just) the Hessian ``\operatorname{Hess} f(p)``.
Depending on the [`AbstractEvaluationType`](@ref) `E` this is a function

* `(M, p, X) -> Y` for the [`AllocatingEvaluation`](@ref) case
* `(M, Y, p, X) -> Y` for the [`InplaceEvaluation`](@ref), i.e. working inplace of `Y`.

By default `recursive` is `false`; for a decorated objective `recursive=true`
unwraps all decorators and returns the innermost objective's Hessian function.
"""
get_hessian_function(mho::ManifoldHessianObjective, recursive=false) = mho.hessian!!
# Decorators delegate to the objective they wrap.
function get_hessian_function(admo::AbstractDecoratedManifoldObjective, recursive=false)
    return get_hessian_function(get_objective(admo, recursive))
end

@doc raw"""
get_preconditioner(amp::AbstractManoptProblem, p, X)
Expand Down
7 changes: 4 additions & 3 deletions src/solvers/trust_regions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -551,8 +551,9 @@ function step_solver!(mp::AbstractManoptProblem, trs::TrustRegionsState, i)
ρ = (abs(ρnum / fx) < sqrt(eps(Float64))) ? 1 : ρnum / ρden # stability for small absolute relative model change

model_decreased = ρden ≥ 0
# Update the Hessian approximation
update_hessian!(M, mho.hessian!!, trs.p, trs.p_proposal, trs.η)
# Update the Hessian approximation - i.e. really unwrap the original Hessian function
# and update it if it is an approximate Hessian.
update_hessian!(M, get_hessian_function(mho, true), trs.p, trs.p_proposal, trs.η)
# Choose the new TR radius based on the model performance.
# If the actual decrease is smaller than reduction_threshold of the predicted decrease,
# then reduce the TR radius.
Expand All @@ -569,7 +570,7 @@ function step_solver!(mp::AbstractManoptProblem, trs::TrustRegionsState, i)
if model_decreased &&
(ρ > trs.ρ_prime || (abs((ρnum) / (abs(fx) + 1)) < sqrt(eps(Float64)) && 0 < ρnum))
copyto!(trs.p, trs.p_proposal)
update_hessian_basis!(M, mho.hessian!!, trs.p)
update_hessian_basis!(M, get_hessian_function(mho, true), trs.p)
end
return trs
end
Loading

2 comments on commit 6036ba8

@kellertuer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91411

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.35 -m "<description of version>" 6036ba8d2779907848acd1b11c8185bc1aac1649
git push origin v0.4.35

Please sign in to comment.