Skip to content

Commit

Permalink
Signbit (#137)
Browse files Browse the repository at this point in the history
* speedup bisection

* avoid Float64 conversion
  • Loading branch information
jverzani authored Sep 18, 2018
1 parent 1b81d22 commit 72050cb
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 75 deletions.
4 changes: 2 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
CHANGES in v0.7.3

* fix bug with find_zeros and Float32
* speeds up bisection function

CHANGES in v0.7.2

* speed up bisection
* speed up bisection

CHANGES in v0.7.1

Expand All @@ -13,4 +14,3 @@ CHANGES in v0.7.1
* took algorithm from Order0, and made it an alternative for find_zero allowing other non-bracketing methods to be more robust

* In FalsePosition there is a parameter to adjust when a bisection step should be used. This was changed in v0.7.0; the old value is restored. (This method is too sensitive to this parameter. It is recommended that either A42 or AlefeldPotraShi be used as alternatives.)

24 changes: 12 additions & 12 deletions src/find_zeros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function _fz!(zs, f, a::T, b, no_pts, k=4) where {T}

for (i,x) in enumerate(pts[1:end])
q,r = divrem(i-1, k)

if i > 1 && iszero(r)
v::T = x
if !found_bisection_zero
Expand Down Expand Up @@ -54,7 +54,7 @@ function _fz!(zs, f, a::T, b, no_pts, k=4) where {T}
end
end
end

sort!(zs)
end

Expand All @@ -78,7 +78,7 @@ end
function find_non_zero(f, a::T, barrier, xatol, xrtol, atol, rtol) where {T}
nan = (0*a)/(0*a) # try to get typed NaN
xtol = max(xatol, abs(a) * xrtol, oneunit(a) * eps(T))
sgn = barrier > a ? 1 : -1
sgn = barrier > a ? 1 : -1
ctr = 0
x = a + 2^ctr*sgn*xtol
while !_non_zero(f(x), x, atol, rtol)
Expand Down Expand Up @@ -151,9 +151,9 @@ find_zeros(x -> sin(x^2) + cos(x)^2, 0, 10) # many zeros
find_zeros(x -> cos(x) + cos(2x), 0, 4pi) # mix of simple, non-simple zeros
f(x) = (x-0.5) * (x-0.5001) * (x-1) # nearby zeros
find_zeros(f, 0, 2)
f(x) = (x-0.5) * (x-0.5001) * (x-4) * (x-4.001) * (x-4.2)
find_zeros(f, 0, 10)
f(x) = (x-0.5)^2 * (x-0.5001)^3 * (x-4) * (x-4.001) * (x-4.2)^2 # hard to identify
f(x) = (x-0.5) * (x-0.5001) * (x-4) * (x-4.001) * (x-4.2)
find_zeros(f, 0, 10)
f(x) = (x-0.5)^2 * (x-0.5001)^3 * (x-4) * (x-4.001) * (x-4.2)^2 # hard to identify
find_zeros(f, 0, 10, no_pts=21) # too hard for default
```
Expand Down Expand Up @@ -192,7 +192,7 @@ compares `|f(x)| <= 8*eps(x)` to identify a zero. The algorithm might
identify more than one value for a zero, due to floating point
approximations. If a potential pair of zeros satisfy
`isapprox(a,b,atol=sqrt(xatol), rtol=sqrt(xrtol))` then they are
consolidated.
consolidated.
The algorithm can make many function calls. When zeros are found in an
interval, the naive search is carried out on each subinterval. To cut
Expand All @@ -216,20 +216,20 @@ function find_zeros(f, a, b; no_pts = 12, k=8,
a0, b0 = promote(float(a), float(b))
a0 = isinf(a0) ? nextfloat(a0) : a0
b0 = isinf(b0) ? prevfloat(b0) : b0

# set tolerances if not specified
fa0 = f(a0)
d = Dict(kwargs)
T, S = eltype(a0), eltype(fa0)
xatol::T = get(d, :xatol, eps(one(T))^(4/5) * oneunit(T))
xatol::T = get(d, :xatol, eps(one(T))^(4//5) * oneunit(T))
xrtol = get(d, :xrtol, eps(one(T)) * one(T))
atol::S = get(d, :atol, eps(float(S)) * oneunit(S))
rtol = get(d, :rtol, eps(float(S)) * one(S))

zs = T[] # collect zeros

_fz!(zs, f, a0, b0, no_pts,k) # initial zeros

ints = Interval{T}[] # collect subintervals
!naive && !isempty(zs) && make_intervals!(ints, f, a0, b0, zs, 1, xatol, xrtol, atol, rtol)

Expand All @@ -250,7 +250,7 @@ function find_zeros(f, a, b; no_pts = 12, k=8,
#sub_no_pts <= 2 && continue # stop on depth, always divide if roots
#sub_no_pts = max(3, floor(Int, no_pts / (2.0)^(i.depth)))
sub_no_pts = floor(Int, no_pts / (2.0)^(i.depth))

empty!(nzs)
if sub_no_pts >= 2
_fz!(nzs, f, i.a, i.b, sub_no_pts, k)
Expand Down
138 changes: 77 additions & 61 deletions src/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,101 @@
# These avoid the setup costs of the `find_zero` method, so should be faster,
# though they will take a similar number of function calls.
#
# `Roots.bisection(f, a, b)` (Bisection).
# `Roots.bisection(f, a, b)` (Bisection).
# `Roots.secant_method(f, xs)` (Order1) secant method
# `Roots.dfree(f, xs)` (Order0) more robust secant method
#

## Bisection
## Bisection
##
## Essentially from Jason Merrill https://gist.github.com/jwmerrill/9012954
## cf. http://squishythinking.com/2014/02/22/bisecting-floats/
## This also borrows a trick from https://discourse.julialang.org/t/simple-and-fast-bisection/14886
## where we keep x1 so that y1 is negative, and x2 so that y2 is positive.
## This allows the use of signbit instead of y1*y2 < 0, which avoids a < comparison and a multiplication.
## This has a small, but noticeable, impact on performance.
"""
bisection(f, a, b; [xatol, xrtol])
Performs bisection method to find a zero of a continuous
function.
function.
It is assumed that (a,b) is a bracket, that is, the function has
different signs at a and b. The interval (a,b) is converted to floating point
and shrunk when a or b is infinite. The function f may be infinite for
the typical case. If f is not continuous, the algorithm may find
jumping points over the x axis, not just zeros.
If non-trivial tolerances are specified, the process will terminate
when the bracket (a,b) satisfies `isapprox(a, b, atol=xatol,
rtol=xrtol)`. For zero tolerances, the default, for Float64, Float32,
or Float16 values, the process will terminate at a value `x` with
`f(x)=0` or `f(x)*f(prevfloat(x)) < 0 ` or `f(x) * f(nextfloat(x)) <
0`. For other number types, the A42 method is used.
"""
function bisection(fn, a::Number, b::Number; xatol=nothing, xrtol=nothing)
function bisection(f, a::Number, b::Number; xatol=nothing, xrtol=nothing)

x1, x2 = adjust_bracket(float.((a,b)))
T = eltype(x1)

s1 = sign(fn(x1))
s2 = sign(fn(x2))
s1 * s2 < 0 || throw(ArgumentError(bracketing_error))

atol = xatol == nothing ? zero(T) : xatol
rtol = xrtol == nothing ? zero(one(T)) : xrtol

# will converge with zero tolerance specified
iszero(atol) && iszero(rtol) && !(T <: FloatNN) && find_zero(fn, (a,b), A42())

xm = _middle(x1, x2)
if iszero(xm)
sm = sign(fn(xm))
if s1 * sm < 0
x2, s2 = xm, sm
elseif s2 * sm < 0
x1, s1 = xm, sm
else


atol = xatol == nothing ? zero(T) : abs(xatol)
rtol = xrtol == nothing ? zero(one(T)) : abs(xrtol)
CT = iszero(atol) && iszero(rtol) ? Val(:exact) : Val(:inexact)

x1, x2 = float(x1), float(x2)
y1, y2 = f(x1), f(x2)

_unitless(y1 * y2) >= 0 && error("the interval provided does not bracket a root")

if isneg(y2)
x1, x2, y1, y2 = x2, x1, y2, y1
end

xm = Roots._middle(x1, x2) # for possibly mixed sign x1, x2
ym = f(xm)

while true

if has_converged(CT, x1, x2, xm, ym, atol, rtol)
return xm
end
xm = __middle(x1, x2)
end

while x1 < xm < x2
isapprox(x1, x2, atol=atol, rtol=rtol) && break

sm = sign(fn(xm))

if s1 * sm < 0
x2 = xm
s2 = sm
elseif s2 * sm < 0
x1 = xm
s1 = sm
if isneg(ym)
x1, y1 = xm, ym
else
return xm
x2, y2 = xm, ym
end

xm = __middle(x1, x2)

xm = Roots.__middle(x1,x2)
ym = f(xm)


end
return xm

end

# Negativity test used by `bisection` to decide which end of the bracket to move.
# For floating point values only the sign bit is inspected, so no comparison is
# performed. `signbit(-0.0)` is true, but (per the original note) `__middle` does
# not return -0.0, so on midpoints this is true exactly on [-Inf, 0.0).
@inline function isneg(x::AbstractFloat)
    return signbit(x)
end

# Fallback for any other value: compare the value returned by `_unitless` with zero.
@inline function isneg(x)
    return _unitless(x) < 0
end

# Stopping rule for the zero-tolerance (`Val(:exact)`) case.
#
# Returns `true` when the function value `ym` at the midpoint `m` is zero or NaN,
# or when `m` coincides with one of the bracket endpoints `x1`, `x2`; otherwise
# returns `false`.
# `atol` and `rtol` are accepted so both `has_converged` methods share a
# signature; they are not used here.
@inline function has_converged(::Val{:exact}, x1, x2, m, ym, atol, rtol)
    if iszero(ym) || isnan(ym)
        return true
    end
    return x1 == m || m == x2
end

# Stopping rule for the non-zero tolerance (`Val(:inexact)`) case.
#
# Returns `true` when the function value `ym` is zero or NaN. Otherwise reports
# whether the bracket width `|x1 - x2|` is at most
# `atol + max(|x1|, |x2|) * rtol`.
# The midpoint `m` is accepted so both `has_converged` methods share a
# signature; it is not used here.
@inline function has_converged(::Val{:inexact}, x1, x2, m, ym, atol, rtol)
    if iszero(ym) || isnan(ym)
        return true
    end
    width = abs(x1 - x2)
    allowed = atol + max(abs(x1), abs(x2)) * rtol
    return width <= allowed
end


"""
secant_method(f, xs; [atol=0.0, rtol=8eps(), maxevals=1000])
Expand All @@ -105,26 +122,26 @@ The `Order1` method for `find_zero` also implements the secant
method. This one will be faster, as there are fewer setup costs.
Examples:
```julia
Roots.secant_method(sin, (3,4))
Roots.secant_method(x -> x^5 -x - 1, 1.1)
```
Note:
This function will specialize on the function `f`, so that the initial
call can take more time than a call to the `Order1()` method, though
subsequent calls will be much faster. Using `FunctionWrappers.jl` can
ensure that the initial call is as fast as subsequent
ones.
"""
function secant_method(f, xs; atol=zero(float(real(first(xs)))), rtol=8eps(one(float(real(first(xs))))), maxevals=100)

if length(xs) == 1 # secant needs x0, x1; only x0 given
a = float(xs[1])

h = eps(one(real(a)))^(1/3)
da = h*oneunit(a) + abs(a)*h^2 # adjust for if eps(a) > h
b = a + da
Expand Down Expand Up @@ -155,17 +172,17 @@ function secant(f, a::T, b::T, atol=zero(T), rtol=8eps(T), maxevals=100) where {
abs(fm) <= adjustunit * max(uatol, abs(m) * rtol) && return m
if fm == fb
sign(fm) * sign(f(nextfloat(m))) <= 0 && return m
sign(fm) * sign(f(prevfloat(m))) <= 0 && return m
sign(fm) * sign(f(prevfloat(m))) <= 0 && return m
return nan
end

a,b,fa,fb = b,m,fb,fm

cnt += 1
end

return nan
end
end



Expand Down Expand Up @@ -196,7 +213,7 @@ secant method to convergence unless:
Convergence occurs when `f(m) == 0`, there is a sign change between
`m` and an adjacent floating point value, or `|f(m)| <= 2^3*eps(m)`.
A value of `NaN` is returned if the algorithm takes too many steps
before identifying a zero.
Expand All @@ -222,15 +239,15 @@ function dfree(f, xs)
fa, fb = f(a), f(b)
end


nan = (0*a)/(0*a) # try to preserve type
cnt, MAXCNT = 0, 5 * ceil(Int, -log(eps(one(a)))) # must be higher for BigFloat
MAXQUAD = 3

if abs(fa) > abs(fb)
a,fa,b,fb=b,fb,a,fa
end

# we keep a, b, fa, fb, gamma, fgamma
quad_ctr = 0
while !iszero(fb)
Expand All @@ -249,11 +266,11 @@ function dfree(f, xs)
gamma = b + sign(gamma-b) * 100 * abs(b-a) ## too big
end
fgamma = f(gamma)

# change sign
if sign(fgamma) * sign(fb) < 0
return bisection(f, gamma, b)
end
end

# decreasing
if abs(fgamma) < abs(fb)
Expand All @@ -271,7 +288,7 @@ function dfree(f, xs)
cnt < MAXCNT && continue
end


quad_ctr += 1
if (quad_ctr > MAXQUAD) || (cnt > MAXCNT) || iszero(gamma - b) || isnan(gamma)
bprev, bnext = prevfloat(b), nextfloat(b)
Expand All @@ -281,17 +298,16 @@ function dfree(f, xs)
for (u,fu) in ((b,fb), (bprev, fbprev), (bnext, fbnext))
abs(fu)/oneunit(fu) <= 2^3*eps(u/oneunit(u)) && return u
end
return nan # Failed.
return nan # Failed.
end

if abs(fgamma) < abs(fb)
b,fb, a,fa = gamma, fgamma, b, fb
else
a, fa = gamma, fgamma
end

end
b

end

0 comments on commit 72050cb

Please sign in to comment.