import numpy as np
from numba import njit
from collections import namedtuple
# Public solvers exported by this module.
__all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq']

# Integer status flags carried through every solver.
_ECONVERGED = 0
_ECONVERR = -1

# Default iteration limit and tolerances for the bracketing solvers
# (`bisect`, `brentq`).
_iter = 100
_xtol = 2e-12
_rtol = 4 * np.finfo(float).eps

# Named tuple built by `_results` and returned by every solver.
results = namedtuple(
    'results', ['root', 'function_calls', 'iterations', 'converged'])
@njit
def _results(r):
    r"""
    Pack a raw solver tuple ``(root, funcalls, iterations, flag)`` into a
    ``results`` namedtuple.

    The integer status ``flag`` becomes the boolean ``converged`` field,
    which is True exactly when ``flag == 0``.
    """
    root, n_calls, n_iters, flag = r
    converged = flag == 0
    return results(root, n_calls, n_iters, converged)
@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
           disp=True):
    """
    Find a zero from the Newton-Raphson method using the jitted version of
    Scipy's newton for scalars. This routine always takes Newton steps, so
    the analytic derivative `fprime` must be provided; use `newton_secant`
    when no derivative is available.

    Note that `func` and `fprime` must be jitted via Numba.
    They are recommended to be `njit` for performance.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of `func`, taking the same arguments f(x,a,b,c...).
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    tol : float, optional(default=1.48e-8)
        The allowable error of the zero value: iteration stops once two
        successive iterates differ by less than `tol` in absolute value.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge
        (this includes hitting a point where `fprime` is exactly zero).

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of calls made to `func` and `fprime`.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the routine did not converge.
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Newton-Raphson method
    for itr in range(maxiter):
        # first evaluate fval
        fval = func(p0, *args)
        funcalls += 1
        # If fval is 0, a root has been found, then terminate.
        # No step was taken in this pass, so undo it in the count that is
        # reported as ``itr + 1`` below.
        if fval == 0:
            status = _ECONVERGED
            p = p0
            itr -= 1
            break
        fder = fprime(p0, *args)
        funcalls += 1
        # derivative is zero: the Newton step is undefined, not converged
        if fder == 0:
            p = p0
            break
        newton_step = fval / fder
        # Newton step
        p = p0 - newton_step
        if abs(p - p0) < tol:
            status = _ECONVERGED
            break
        p0 = p

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))
@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
                  maxiter=50, disp=True):
    """
    Find a zero from Halley's method using the jitted version of
    Scipy's.

    `func`, `fprime`, `fprime2` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of `func`, taking the same arguments f(x,a,b,c...).
    fprime2 : callable and jitted
        The second order derivative of `func`, same signature.
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    tol : float, optional(default=1.48e-8)
        The allowable error of the zero value: iteration stops once two
        successive iterates differ by less than `tol` in absolute value.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge
        (this includes hitting a point where `fprime` is exactly zero).

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of calls made to `func`, `fprime`
                and `fprime2`.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the routine did not converge.
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Halley Method
    for itr in range(maxiter):
        # first evaluate fval
        fval = func(p0, *args)
        funcalls += 1
        # If fval is 0, a root has been found, then terminate.
        # No step was taken in this pass, so undo it in the count that is
        # reported as ``itr + 1`` below.
        if fval == 0:
            status = _ECONVERGED
            p = p0
            itr -= 1
            break
        fder = fprime(p0, *args)
        funcalls += 1
        # derivative is zero: the step is undefined, not converged
        if fder == 0:
            p = p0
            break
        newton_step = fval / fder
        # Halley's variant: divide the Newton step by (1 - adj).
        fder2 = fprime2(p0, *args)
        funcalls += 1
        adj = 0.5 * newton_step * fder2 / fder
        # Apply the correction only while the denominator stays close to 1.
        # For |adj| >= 1 it would divide by zero (adj == 1) or send the
        # iterate opposite to the Newton direction, so take a Newton step.
        if abs(adj) < 1:
            newton_step = newton_step / (1.0 - adj)
        p = p0 - newton_step
        if abs(p - p0) < tol:
            status = _ECONVERGED
            break
        p0 = p

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))
@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
                  disp=True):
    """
    Find a zero from the secant method using the jitted version of
    Scipy's secant method.

    Note that `func` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    tol : float, optional(default=1.48e-8)
        The allowable error of the zero value: iteration stops once two
        successive iterates differ by less than `tol` in absolute value.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of times the function was called.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the routine did not converge. This includes
        the case where two successive function values are equal but
        nonzero (e.g. a flat function), since no secant step is defined.
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Secant method: the second starting point is x0 nudged by a small
    # relative and absolute offset, pushed away from zero.
    if x0 >= 0:
        p1 = x0 * (1 + 1e-4) + 1e-4
    else:
        p1 = x0 * (1 + 1e-4) - 1e-4
    q0 = func(p0, *args)
    funcalls += 1
    q1 = func(p1, *args)
    funcalls += 1

    for itr in range(maxiter):
        if q1 == q0:
            # Zero secant slope: no further progress is possible. Only an
            # exact zero counts as convergence. Otherwise (e.g. a constant
            # function) report failure instead of claiming a root at the
            # midpoint.
            if q1 == 0:
                p = p1
                status = _ECONVERGED
            else:
                p = (p1 + p0) / 2.0
            break
        p = p1 - q1 * (p1 - p0) / (q1 - q0)
        if abs(p - p1) < tol:
            status = _ECONVERGED
            break
        p0 = p1
        q0 = q1
        p1 = p
        q1 = func(p1, *args)
        funcalls += 1

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))
@njit
def _bisect_interval(a, b, fa, fb):
    """
    Conditional checks for intervals in methods involving bisection.

    Validates that [a, b] brackets a sign change and detects a root lying
    exactly at one of the ends.

    Parameters
    ----------
    a, b : float
        Ends of the bracketing interval.
    fa, fb : float
        Function values at `a` and `b`.

    Returns
    -------
    root : float
        `a` or `b` if the function is exactly zero there (`b` takes
        precedence if both are zero), otherwise the placeholder 0.0.
    status : int
        _ECONVERGED if an endpoint root was found, else _ECONVERR.

    Raises
    ------
    ValueError
        If `fa` and `fb` are both nonzero and share the same sign.
    """
    # Compare signs directly instead of testing fa*fb > 0: that product
    # can underflow to 0.0 for tiny same-sign values (e.g. 1e-200 each)
    # and let an invalid bracket through.
    if (fa > 0 and fb > 0) or (fa < 0 and fb < 0):
        raise ValueError("f(a) and f(b) must have different signs")

    root = 0.0
    status = _ECONVERR

    # Root found at either end of [a,b]
    if fa == 0:
        root = a
        status = _ECONVERGED
    if fb == 0:
        root = b
        status = _ECONVERGED

    return root, status
@njit
def bisect(f, a, b, args=(), xtol=_xtol,
           rtol=_rtol, maxiter=_iter, disp=True):
    """
    Find root of a function within an interval adapted from Scipy's bisect.

    Basic bisection routine to find a zero of the function `f` between the
    arguments `a` and `b`. `f(a)` and `f(b)` cannot have the same signs.

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number. `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be positive.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero. If the
                routine did not converge (and `disp` is False), this is
                the current end of the bracket that started at `a`.
            function_calls - Number of times the function was called.
            iterations - Number of bisection steps performed (0 if a root
                lies exactly at `a` or `b`).
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `xtol` <= 0, `maxiter` < 1, or f(a) and f(b) have the same sign.
    RuntimeError
        If `disp` is True and the routine did not converge.
    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float
    xa = a * 1.0
    xb = b * 1.0
    fa = f(xa, *args)
    fb = f(xb, *args)
    funcalls = 2
    # Raises on a sign error and detects a root sitting exactly at a or b.
    root, status = _bisect_interval(xa, xb, fa, fb)

    itr = 0
    if status != _ECONVERGED:
        # Perform bisection. [xa, xa + dm] always brackets the root, and
        # f keeps the sign of fa at xa, so fa itself is never updated.
        dm = xb - xa
        for _ in range(maxiter):
            itr += 1
            dm *= 0.5
            xm = xa + dm
            fm = f(xm, *args)
            funcalls += 1
            # Advance the left end unless f changes sign between xa and xm.
            # Signs are compared directly because fm * fa can underflow to
            # (-)0.0 for tiny opposite-sign values.
            if (fm >= 0 and fa >= 0) or (fm <= 0 and fa <= 0):
                xa = xm
            if fm == 0 or abs(dm) < xtol + rtol * abs(xm):
                root = xm
                status = _ECONVERGED
                break
        if status != _ECONVERGED:
            # Report the current bracket end rather than the 0.0
            # placeholder from _bisect_interval.
            root = xa

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))
@njit
def brentq(f, a, b, args=(), xtol=_xtol,
           rtol=_rtol, maxiter=_iter, disp=True):
    """
    Find a root of a function in a bracketing interval using Brent's method
    adapted from Scipy's brentq.

    Uses the classic Brent's method to find a zero of the function `f` on
    the sign changing interval [a , b].

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number. `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be positive.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero. If the
                routine did not converge (and `disp` is False), this is
                the last iterate.
            function_calls - Number of times the function was called.
            iterations - Number of iterations performed (0 if a root lies
                exactly at `a` or `b`).
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `xtol` <= 0, `maxiter` < 1, or f(a) and f(b) have the same sign.
    RuntimeError
        If `disp` is True and the routine did not converge.
    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float
    xpre = a * 1.0
    xcur = b * 1.0
    fpre = f(xpre, *args)
    fcur = f(xcur, *args)
    funcalls = 2
    # Raises on a sign error and detects a root sitting exactly at a or b.
    root, status = _bisect_interval(xpre, xcur, fpre, fcur)

    itr = 0
    if status != _ECONVERGED:
        # Contrapoint (xblk, fblk) and step history (spre, scur); set here
        # so they are always defined before first use.
        xblk = 0.0
        fblk = 0.0
        spre = 0.0
        scur = 0.0

        # Perform Brent's method
        for _ in range(maxiter):
            itr += 1
            # When f(xpre) and f(xcur) have strictly opposite signs, xpre
            # becomes the contrapoint and the step history resets. Signs are
            # compared directly because fpre * fcur can underflow to 0.0.
            if (fpre < 0 and fcur > 0) or (fpre > 0 and fcur < 0):
                xblk = xpre
                fblk = fpre
                spre = scur = xcur - xpre
            # Keep xcur as the point with the smaller |f|.
            if abs(fblk) < abs(fcur):
                xpre = xcur
                xcur = xblk
                xblk = xpre

                fpre = fcur
                fcur = fblk
                fblk = fpre

            delta = (xtol + rtol * abs(xcur)) / 2
            sbis = (xblk - xcur) / 2

            # Root found
            if fcur == 0 or abs(sbis) < delta:
                status = _ECONVERGED
                root = xcur
                break

            if abs(spre) > delta and abs(fcur) < abs(fpre):
                if xpre == xblk:
                    # interpolate (secant step through xpre and xcur)
                    stry = -fcur * (xcur - xpre) / (fcur - fpre)
                else:
                    # extrapolate
                    dpre = (fpre - fcur) / (xpre - xcur)
                    dblk = (fblk - fcur) / (xblk - xcur)
                    stry = -fcur * (fblk * dblk - fpre * dpre) / \
                        (dblk * dpre * (fblk - fpre))

                if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta):
                    # good short step
                    spre = scur
                    scur = stry
                else:
                    # bisect
                    spre = sbis
                    scur = sbis
            else:
                # bisect
                spre = sbis
                scur = sbis

            xpre = xcur
            fpre = fcur
            # Always move by at least delta toward xblk.
            if abs(scur) > delta:
                xcur += scur
            else:
                xcur += (delta if sbis > 0 else -delta)

            fcur = f(xcur, *args)
            funcalls += 1

        if status != _ECONVERGED:
            # Report the last iterate rather than the 0.0 placeholder from
            # _bisect_interval.
            root = xcur

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))