-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial sketch of the count objective * Starts a count objective. * Add tests for the first sketch. * first sketches for the new cache. * Starts the LRU Cache * Sketiching Counters further. * Rework decorators to parametrise the most inner as well to allow for a few tricky dispatches. Now the counts even with caches work superb. * 🚀 Tests * Sketch an option to also return the objective. * Work on the docs. * Test Coverage * Adds all remaining things that can be counted on the objective * Apply suggestions from code review * Update Changelog.md --------- Co-authored-by: Mateusz Baran <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b6880f8
commit b3c826e
Showing
91 changed files
with
3,768 additions
and
959 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Manopt" | ||
uuid = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5" | ||
authors = ["Ronny Bergmann <[email protected]>"] | ||
version = "0.4.20" | ||
version = "0.4.21" | ||
|
||
[deps] | ||
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" | ||
|
@@ -17,7 +17,6 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" | |
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Requires = "ae029012-a4dd-5104-9daa-d747884805df" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
|
@@ -26,19 +25,32 @@ ColorSchemes = "3.5.0" | |
ColorTypes = "0.9.1, 0.10, 0.11" | ||
Colors = "0.11.2, 0.12" | ||
DataStructures = "0.17, 0.18" | ||
LRUCache = "1.4" | ||
ManifoldDiff = "0.2, 0.3" | ||
Manifolds = "0.8.57" | ||
ManifoldsBase = "0.14.4" | ||
Requires = "0.5, 1" | ||
StaticArrays = "0.12, 1.0" | ||
julia = "1.6" | ||
|
||
[extensions] | ||
ManoptLRUCacheExt = "LRUCache" | ||
ManoptLineSearchesExt = "LineSearches" | ||
ManoptManifoldsExt = ["Manifolds"] | ||
ManoptPlotsExt = "Plots" | ||
|
||
[extras] | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" | ||
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" | ||
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" | ||
|
||
[targets] | ||
test = ["Test", "ForwardDiff", "Manifolds", "Plots", "LineSearches"] | ||
test = ["Test", "ForwardDiff", "Manifolds", "Plots", "LineSearches", "LRUCache"] | ||
|
||
[weakdeps] | ||
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" | ||
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" | ||
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,4 +23,4 @@ format: | |
variant: -raw_html | ||
wrap: none | ||
|
||
jupyter: julia-1.8 | ||
jupyter: julia-1.9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# How to Count and Cache Function Calls | ||
Ronny Bergmann | ||
|
||
In this tutorial, we want to investigate the caching and counting (i.e. statistics) features | ||
of [Manopt.jl](https://manoptjl.org). We will reuse the optimization tasks from the | ||
introductionary tutorial [Get Started: Optimize!](https://manoptjl.org/stable/tutorials/Optimize!.html). | ||
|
||
## Introduction | ||
|
||
There are surely many ways to keep track for example of how often the cost function is called, | ||
for example with a [functor](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects), as we used in an example in [How to Record Data](https://manoptjl.org/stable/tutorials/HowtoRecord.html) | ||
|
||
``` julia | ||
mutable struct MyCost{I<:Integer} | ||
count::I | ||
end | ||
MyCost() = MyCost{Int64}(0) | ||
function (c::MyCost)(M, x) | ||
c.count += 1 | ||
# [ .. Actual implementation of the cost here ] | ||
end | ||
``` | ||
|
||
This still leaves a bit of work to the user, especially for tracking more than just the number of cost function evaluations. | ||
|
||
When the a function like objective or gradient is expensive to compute, it may make sense to cache its results. | ||
Manopt.jl tries to minimize the number of repeated calls but sometimes they are necessary and harmless when the function is cheap to compute. | ||
Caching of expensive function calls can for example be added using [Memoize.jl](https://github.com/JuliaCollections/Memoize.jl) by the user. | ||
The approach in the solvers of [Manopt.jl](https://manoptjl.org) aims to simplify adding | ||
both these capabilities on the level of calling a solver. | ||
|
||
## Technical Background | ||
|
||
The two ingdredients for a solver in [Manopt.jl](https://manoptjl.org) | ||
are the [`AbstractManoptProblem`](@ref) and the [`AbstractManoptSolverState`](@ref), where the | ||
former consists of the domain, that is the [manifold](https://juliamanifolds.github.io/ManifoldsBase.jl/stable/types.html#The-AbstractManifold) and [`AbstractManifoldObjective`](@ref). | ||
|
||
Both recording and debug capabilities are implemented in a decorator pattern to the solver state. | ||
They can be easily added using the `record=` and `debug=` in any solver call. | ||
This pattern was recently extended, such that also the objective can be decorated. | ||
This is how both caching and counting are implemented, as decorators of the [`AbstractManifoldObjective`](@ref) | ||
and hence for example changing/extending the behaviour of a call to [`get_cost`](@ref). | ||
|
||
Let’s finish off the technical background by loading the necessary packages. | ||
Besides [Manopt.jl](https://manoptjl.org) and [Manifolds.jl](https://juliamanifolds.github.io/Manifolds.jl/latest/) we also need | ||
[LRUCaches.jl](https://github.com/JuliaCollections/LRUCache.jl) which are (since Julia 1.9) a weak dependency and provide | ||
the *least recently used* strategy for our caches. | ||
|
||
``` julia | ||
using Manopt, Manifolds, Random, LRUCache | ||
``` | ||
|
||
## Counting | ||
|
||
We first define our task, the Riemannian Center of Mass from the [Get Started: Optimize!](https://manoptjl.org/stable/tutorials/Optimize!.html) tutorial. | ||
|
||
``` julia | ||
n = 100 | ||
σ = π / 8 | ||
M = Sphere(2) | ||
p = 1 / sqrt(2) * [1.0, 0.0, 1.0] | ||
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n]; | ||
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2) | ||
grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p))); | ||
``` | ||
|
||
to now count how often the cost and the gradient are called, we use the `count=` keyword | ||
argument that works in any solver to specify the elements of the objective whose calls we | ||
want to count calls to. A full list is available in the documentation of the | ||
[`AbstractManifoldObjective`](@ref). | ||
To also see the result, we have to set `return_objective=true`. | ||
This returns `(objective, p)` instead of just the solver result `p`. | ||
We can further also set `return_state=true` to get even more information about the solver run. | ||
|
||
``` julia | ||
gradient_descent(M, f, grad_f, data[1]; count=[:Cost, :Gradient], return_objective=true, return_state=true) | ||
``` | ||
|
||
# Solver state for `Manopt.jl`s Gradient Descent | ||
After 72 iterations | ||
|
||
## Parameters | ||
* retraction method: ExponentialRetraction() | ||
|
||
## Stepsize | ||
ArmijoLineseach() with keyword parameters | ||
* initial_stepsize = 1.0 | ||
* retraction_method = ExponentialRetraction() | ||
* contraction_factor = 0.95 | ||
* sufficient_decrease = 0.1 | ||
|
||
## Stopping Criterion | ||
Stop When _one_ of the following are fulfilled: | ||
Max Iteration 200: not reached | ||
|grad f| < 1.0e-9: reached | ||
Overall: reached | ||
This indicates convergence: Yes | ||
|
||
## Statistics on function calls | ||
* :Gradient : 217 | ||
* :Cost : 298 | ||
on a ManifoldGradientObjective{AllocatingEvaluation} | ||
|
||
And we see that statistics are shown in the end. To now also cache these calls, | ||
we can use the `cache=` keyword argument. | ||
Since now both the cache and the count “extend” the functionality of the objective, | ||
the order is important: On the high-level interface, the `count` is treated first, which | ||
means that only actual function calls and not cache look-ups are counted. | ||
With the proper initialisation, you can use any caches here that support the | ||
`get!(function, cache, key)!` update. All parts of the objective that can currently be cached are listed at [`ManifoldCachedObjective`](@ref). The solver call has a keyword `cache` that takes a tuple`(c, vs, n)` of three arguments, where `c` is a symbol for the type of cache, `vs` is a vector of symbols, which calls to cache and `n` is the size of the cache. If the last element is not provided, a suitable default (currently`n=10`) is used. | ||
|
||
Here we want to use `c=:LRU` caches for `vs=[Cost, :Gradient]` with a size of `n=25`. | ||
|
||
``` julia | ||
r = gradient_descent(M, f, grad_f, data[1]; | ||
count=[:Cost, :Gradient], | ||
cache=(:LRU, [:Cost, :Gradient], 25), | ||
return_objective=true, return_state=true) | ||
``` | ||
|
||
# Solver state for `Manopt.jl`s Gradient Descent | ||
After 72 iterations | ||
|
||
## Parameters | ||
* retraction method: ExponentialRetraction() | ||
|
||
## Stepsize | ||
ArmijoLineseach() with keyword parameters | ||
* initial_stepsize = 1.0 | ||
* retraction_method = ExponentialRetraction() | ||
* contraction_factor = 0.95 | ||
* sufficient_decrease = 0.1 | ||
|
||
## Stopping Criterion | ||
Stop When _one_ of the following are fulfilled: | ||
Max Iteration 200: not reached | ||
|grad f| < 1.0e-9: reached | ||
Overall: reached | ||
This indicates convergence: Yes | ||
|
||
## Statistics on function calls | ||
* :Gradient : 72 | ||
* :Cost : 164 | ||
on a ManifoldGradientObjective{AllocatingEvaluation} | ||
|
||
Since the default setup with [`ArmijoLinesearch`](@ref) needs the gradient and the | ||
cost, and similarly the stopping criterion might (independently) evaluate the gradient, | ||
the caching is quite helpful here. | ||
|
||
And of course also for this advanced return value of the solver, we can still access the | ||
result as usual: | ||
|
||
``` julia | ||
get_solver_result(r) | ||
``` | ||
|
||
3-element Vector{Float64}: | ||
0.7298774364923435 | ||
0.047665824852873 | ||
0.6819141418393224 |
Oops, something went wrong.
b3c826e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
b3c826e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/84009
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: