Skip to content

Commit

Permalink
rework to save bracket (#173)
Browse files Browse the repository at this point in the history
* introduce `find_bracket` (which required adding a new field  to the state)

* squeeze tolerance  for  AlefeldPotraShi models for smaller brackets before termination

* fix a few bugs exposed in the process

* version bump
  • Loading branch information
jverzani authored Feb 28, 2020
1 parent bbab77c commit a6dc315
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name = "Roots"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "0.8.4"
version = "1.0.0"

[deps]
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
SpecialFunctions = "0.8, 0.9"
SpecialFunctions = "0.8, 0.9, 0.10"
julia = "1.0"

[extras]
Expand Down
102 changes: 76 additions & 26 deletions src/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ function _init_state(method::AbstractBisection, fs, xs, fxs)

x0, x1 = xs
fx0, fx1 = fxs
state = UnivariateZeroState(x1, x0, [x1],
fx1, fx0, [fx1],

state = UnivariateZeroState(x1, x0, zero(x1)/zero(x1)*oneunit(x1), [x1],
fx1, fx0, fx1, [fx1],
0, 2,
false, false, false, false,
"")
Expand Down Expand Up @@ -124,8 +125,10 @@ function init_state!(state::UnivariateZeroState{T,S}, M::AbstractBisection, fs,
m, fm = x1, fx1
end
state.f_converged = true
state.xn1 = m
state.fxn1 = fm
state.xstar = m
state.fxstar = fm
#state.xn1 = m
#state.fxn1 = fm
return state
end

Expand All @@ -137,8 +140,15 @@ function init_state!(state::UnivariateZeroState{T,S}, M::AbstractBisection, fs,
incfn(state)
if sign(fx0) * sign(fm) < 0
x1, fx1 = m, fm
else
elseif sign(fx0) * sign(fm) > 0
x0, fx0 = m, fm
else
state.message = "Exact zero found"
state.xstar = m
state.fxstar = fm
state.f_converged = true
state.x_converged = true
return state
end
end

Expand Down Expand Up @@ -234,7 +244,7 @@ function assess_convergence(M::Bisection, state::UnivariateZeroState{T,S}, optio

if x_converged
state.message=""
state.xn1 = xm
state.xstar = xm
state.x_converged = true
return true
end
Expand All @@ -246,8 +256,8 @@ function assess_convergence(M::Bisection, state::UnivariateZeroState{T,S}, optio

if f_converged
state.message = ""
state.xn1 = xm
state.fxn1 = fm
state.xstar = xm
state.fxstar = fm
state.f_converged = f_converged
return true
end
Expand All @@ -266,14 +276,17 @@ function assess_convergence(M::BisectionExact, state::UnivariateZeroState{T,S},
for (c,fc) in ((x0,y0), (xm,ym), (x1, y1))
if iszero(fc) || isnan(fc) #|| isinf(fc)
state.f_converged = true
state.xn1 = c
state.fxn1 = fc
state.x_converged = true
state.xstar = c
state.fxstar = fc
return true
end
end

x0 < xm < x1 && return false

state.xstar = x1
state.fxstar = ym
state.x_converged = true
return true
end
Expand All @@ -288,7 +301,7 @@ function update_state(method::Union{Bisection, BisectionExact}, fs, o::Univariat
m::T = o.m[1]
ym::S = o.fm[1] #fs(m)

if y0 * ym < 0
if sign(y0) * sign(ym) < 0
o.xn1, o.fxn1 = m, ym
else
o.xn0, o.fxn0 = m, ym
Expand Down Expand Up @@ -336,7 +349,7 @@ function find_zero(fs, x0, method::M;

verbose && show_trace(method, nothing, state, l)

state.xn1
state.xstar

end

Expand Down Expand Up @@ -456,8 +469,8 @@ function init_state(M::AbstractAlefeldPotraShi, f, xs)
end
fu, fv = promote(f(u), f(v))
isbracket(fu, fv) || throw(ArgumentError(bracketing_error))
state = UnivariateZeroState(v, u, [v, v], ## x1, x0, d, [ee]
fv, fu, [fv,fv], ## fx1, fx0, d, [fe]
state = UnivariateZeroState(v, u, zero(v)/zero(v)*oneunit(v), [v, v], ## x1, x0, d, [ee]
fv, fu, fv, [fv,fv], ## fx1, fx0, d, [fe]
0, 2,
false, false, false, false,
"")
Expand Down Expand Up @@ -502,14 +515,14 @@ end
default_tolerances(::AbstractAlefeldPotraShi, T, S)
The default tolerances for Alefeld, Potra, and Shi methods are
`xatol=zero(T)`, `xrtol=2eps(T)`, `atol= zero(S), and rtol=zero(S)`, with
`xatol=zero(T)`, `xrtol=eps(T)/2`, `atol= zero(S), and rtol=zero(S)`, with
appropriate units; `maxevals=45`, `maxfnevals = Inf`; and `strict=true`.
"""
default_tolerances(M::AbstractAlefeldPotraShi) = default_tolerances(M, Float64, Float64)
function default_tolerances(::AbstractAlefeldPotraShi, ::Type{T}, ::Type{S}) where {T,S}
xatol = zero(T)
xrtol = 2 * eps(one(T))
xrtol = eps(one(T))/2
atol = zero(float(one(S))) * oneunit(S)
rtol = zero(float(one(S))) * one(S)
maxevals = 45
Expand All @@ -536,16 +549,23 @@ function check_zero(::AbstractBracketing, state, c, fc)
elseif iszero(fc)
state.f_converged=true
state.message *= "Exact zero found. "
state.xn1 = c
state.fxn1 = fc
state.xstar = c
state.fxstar = fc
# state.xn1 = c
# state.fxn1 = fc
return true
end
return false
end

function assess_convergence(method::AbstractAlefeldPotraShi, state::UnivariateZeroState{T,S}, options) where {T,S}

(state.stopped || state.x_converged || state.f_converged) && return true
if state.stopped || state.x_converged || state.f_converged
if isnan(state.xstar)
state.xstar, state.fxstar = state.xn1, state.fxn1
end
return true
end

if state.steps > options.maxevals
state.stopped = true
Expand All @@ -565,21 +585,24 @@ function assess_convergence(method::AbstractAlefeldPotraShi, state::UnivariateZe

if abs(fu) <= maximum(promote(options.abstol, abs(u) * oneunit(fu) / oneunit(u) * options.reltol))
state.f_converged = true
state.xn1=u
state.fxn1=fu
state.xstar=u
state.fxstar=fu
if iszero(fu)

state.x_converged
state.message *= "Exact zero found. "
end
return true
end

a,b = state.xn0, state.xn1
tol = maximum(promote(options.xabstol, max(abs(a),abs(b)) * options.xreltol))
mx = max(abs(a), abs(b))
tol = maximum(promote(options.xabstol, mx * options.xreltol, sign(options.xreltol) * eps(mx)))

if abs(b-a) <= 2tol
# use smallest of a,b,m
state.xn1 = u
state.fxn1 = fu
state.xstar = u
state.fxstar = fu
state.x_converged = true
return true
end
Expand Down Expand Up @@ -755,8 +778,8 @@ function init_state(M::Brent, f, xs)



state = UnivariateZeroState(b, a, [a, a], ## x1, x0, c, d
fb, fa, [fa, one(fa)], ## fx1, fx0, fc, mflag
state = UnivariateZeroState(b, a, zero(b)/zero(b)*oneunit(b), [a, a], ## x1, x0, c, d
fb, fa, fb, [fa, one(fa)], ## fx1, fx0, fc, mflag
0, 2,
false, false, false, false,
"")
Expand Down Expand Up @@ -944,3 +967,30 @@ end
$f(fa, fb, fx)
end
end




"""
find_bracket(f, x0, method=A42(); kwargs...)
For bracketing methods returns an approximate root, the last bracketing interval used, and a flag indicating if an exact zero was found as a named tuple.
With the default tolerances, one of these should be the case: `exact` is `true` (indicating termination of the algorithm due to an exact zero being identified) or the length of `bracket` is less or equal than `2eps(maximum(abs.(bracket)))`. In the `BisectionExact` case, the 2 could be replaced by 1, as the bracket, `(a,b)` will satisfy `nextfloat(a) == b `; the Alefeld, Potra, and Shi algorithms don't quite have that promise.
"""
function find_bracket(fs, x0, method::M=A42(); kwargs...) where {M <: Union{AbstractAlefeldPotraShi, BisectionExact}}
x = adjust_bracket(x0)
T = eltype(x[1])
F = callable_function(fs)
state = init_state(method, F, x)
options = init_options(method, state; kwargs...)

# check if tolerances are exactly 0
iszero_tol = iszero(options.xabstol) && iszero(options.xreltol) && iszero(options.abstol) && iszero(options.reltol)

find_zero(method, F, options, state, NullTracks())

(xstar=state.xstar, bracket=(state.xn0, state.xn1), exact=iszero(state.fxstar))

end
4 changes: 2 additions & 2 deletions src/derivative_free.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ end
function init_state(method::AbstractSecant, fs, x::Union{Tuple, Vector})
x0, x1 = promote(float(x[1]), float(x[2]))
fx0, fx1 = fs(x0), fs(x1)
state = UnivariateZeroState(x1, x0, eltype(x1)[],
fx1, fx0, eltype(fx1)[],
state = UnivariateZeroState(x1, x0, zero(x1)/zero(x1)*oneunit(x1), eltype(x1)[],
fx1, fx0, fx1, eltype(fx1)[],
0, 2,
false, false, false, false, "")

Expand Down
39 changes: 30 additions & 9 deletions src/find_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ abstract type AbstractUnivariateZeroState end
mutable struct UnivariateZeroState{T,S} <: AbstractUnivariateZeroState where {T,S}
xn1::T
xn0::T
xstar::T
m::Vector{T}
fxn1::S
fxn0::S
fxstar::S
fm::Vector{S}
steps::Int
fnevals::Int
Expand All @@ -61,9 +63,9 @@ function init_state(method::Any, fs, x)
x1 = float(x)
fx1 = fs(x1); fnevals = 1
T, S = eltype(x1), eltype(fx1)

state = UnivariateZeroState(x1, oneunit(x1) * (0*x1)/(0*x1), T[],
fx1, oneunit(fx1) * (0*fx1)/(0*fx1), S[],
zT, zS = oneunit(x1) * (0*x1)/(0*x1), oneunit(fx1) * (0*fx1)/(0*fx1)
state = UnivariateZeroState(x1, zT, zT/Zt*oneunit(x1), T[],
fx1, zS, zS, S[],
0, fnevals,
false, false, false, false,
"")
Expand Down Expand Up @@ -112,8 +114,8 @@ end


function Base.copy(state::UnivariateZeroState{T,S}) where {T, S}
UnivariateZeroState(state.xn1, state.xn0, copy(state.m),
state.fxn1, state.fxn0, copy(state.fm),
UnivariateZeroState(state.xn1, state.xn0, state.xstar, copy(state.m),
state.fxn1, state.fxn0, state.fxstar, copy(state.fm),
state.steps, state.fnevals,
state.stopped, state.x_converged,
state.f_converged, state.convergence_failed,
Expand Down Expand Up @@ -348,6 +350,9 @@ function assess_convergence(method::Any, state::UnivariateZeroState{T,S}, option
fxn1 = state.fxn1

if (state.x_converged || state.f_converged || state.stopped)
if isnan(state.xstar)
state.xstar, state.fxstar = xn1, fxn1
end
return true
end

Expand All @@ -365,13 +370,15 @@ function assess_convergence(method::Any, state::UnivariateZeroState{T,S}, option

# f(xstar) ≈ xstar * f'(xstar)*eps(), so we pass in lambda
if _is_f_approx_0(fxn1, xn1, options.abstol, options.reltol)
state.xstar, state.fxstar = xn1, fxn1
state.f_converged = true
return true
end

# stop when xn1 ~ xn.
# in find_zeros there is a check that f could be a zero with a relaxed tolerance
if abs(xn1 - xn0) < max(options.xabstol, max(abs(xn1), abs(xn0)) * options.xreltol)
state.xstar, state.fxstar = xn1, fxn1
state.message *= "x_n ≈ x_{n-1}. "
state.x_converged = true
return true
Expand Down Expand Up @@ -580,10 +587,22 @@ function find_zero(M::AbstractUnivariateZeroMethod,
return decide_convergence(M, F, state, options)
end

function find_zero(M::AbstractUnivariateZeroMethod,
F,
state::AbstractUnivariateZeroState,
l::AbstractTracks=NullTracks()
) #where {T<:Number, S<:Number}

options = init_options(M, state)
find_zero(M, F, options, state, l)
end



# state has stopped, this identifies if it has converged
function decide_convergence(M::AbstractUnivariateZeroMethod, F, state::UnivariateZeroState{T,S}, options) where {T,S}
xn1 = state.xn1
fxn1 = state.fxn1
xn1 = state.xstar
fxn1 = state.fxstar

if (state.stopped || state.x_converged) && !(state.f_converged)
## stopped is a heuristic, x_converged can mask issues
Expand Down Expand Up @@ -611,6 +630,7 @@ function decide_convergence(M::AbstractUnivariateZeroMethod, F, state::Univaria
else
xstar, fxstar = state.xn1, state.fxn1
if _is_f_approx_0(fxstar, xstar, options.abstol, options.reltol, :relaxed)
state.xstar, state.fxstar = xstar, fxstar
msg = "Algorithm stopped early, but |f(xn)| < ϵ^(1/3), where ϵ depends on xn, rtol, and atol. "
state.message = state.message == "" ? msg : state.message * "\n\t" * msg
state.f_converged = true
Expand All @@ -621,10 +641,10 @@ function decide_convergence(M::AbstractUnivariateZeroMethod, F, state::Univaria
end

if state.f_converged
return state.xn1
return state.xstar
end

nan = NaN * state.xn1
nan = NaN * xn1
if state.convergence_failed
return nan
end
Expand Down Expand Up @@ -685,6 +705,7 @@ function find_zero(M::AbstractUnivariateZeroMethod,
## did we find a zero or a bracketing interval?
if iszero(state0.fxn1)
copy!(state, state0)
state.xstar, state.fxstar = state.xn1, state.fxn1
state.f_converged = true
break
elseif sign(state0.fxn0) * sign(state0.fxn1) < 0
Expand Down
10 changes: 6 additions & 4 deletions src/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ function init_state(method::AbstractNewtonLikeMethod, fs, x)
fnevals = 1
S = eltype(fx1)

state = UnivariateZeroState(x1, oneunit(x1) * (0*x1)/(0*x1), [Δ],
fx1, oneunit(fx1) * (0*fx1)/(0*fx1), S[],
zT, zS = oneunit(x1) * (0*x1)/(0*x1), oneunit(fx1) * (0*fx1)/(0*fx1)
state = UnivariateZeroState(x1, zT, zT/zT*oneunit(x1), [Δ],
fx1, zS, zS, S[],
0, fnevals,
false, false, false, false,
"")
Expand Down Expand Up @@ -163,8 +164,9 @@ function init_state(method::AbstractHalleyLikeMethod, fs, x)
S = eltype(fx1)
fnevals = 3

state = UnivariateZeroState(x1, oneunit(x1) * (0*x1)/(0*x1), [Δ,ΔΔ],
fx1, oneunit(fx1) * (0*fx1)/(0*fx1), S[],
zT, zS = oneunit(x1) * (0*x1)/(0*x1), oneunit(fx1) * (0*fx1)/(0*fx1)
state = UnivariateZeroState(x1, zT, zT/zT*oneunit(x1), [Δ,ΔΔ],
fx1, zS, zS, S[],
0, fnevals,
false, false, false, false,
"")
Expand Down
Loading

2 comments on commit a6dc315

@jverzani
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/10254

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v1.0.0 -m "<description of version>" a6dc3154fad3107c753ca2e98d6b72ff54832153
git push origin v1.0.0

Please sign in to comment.