Source code for quantecon.optimize.root_finding

import numpy as np
from numba import njit
from collections import namedtuple

# Public root-finding routines exported by this module.
__all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq']

# Integer status flags used internally by the solvers; `_results` turns the
# flag into the boolean `converged` field by comparing it with 0.
_ECONVERGED = 0
_ECONVERR = -1

# Default maximum iterations and absolute/relative tolerances used as the
# keyword defaults of `bisect` and `brentq`.
_iter = 100
_xtol = 2e-12
_rtol = 4*np.finfo(float).eps

# Record type handed back to callers by every solver in this module.
results = namedtuple('results', 'root function_calls iterations converged')


@njit
def _results(r):
    r"""
    Build a `results` namedtuple from a raw solver tuple.

    Parameters
    ----------
    r : tuple
        A 4-tuple ``(root, funcalls, iterations, flag)`` where `flag` is the
        integer status code produced by a solver.

    Returns
    -------
    results : namedtuple
        ``(root, function_calls, iterations, converged)``; `converged` is
        True exactly when `flag` equals 0.
    """
    root, n_calls, n_iter, status = r
    converged = status == 0
    return results(root, n_calls, n_iter, converged)


@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50, disp=True):
    """
    Locate a zero of a scalar function by Newton-Raphson iteration.

    This is a jitted adaptation of SciPy's scalar `newton`. No fallback
    (such as the secant method) is provided, so `fprime` is mandatory.
    Both `func` and `fprime` must be compiled with Numba; `njit` is
    recommended for performance.

    Parameters
    ----------
    func : callable and jitted
        Function whose zero is sought, of the form f(x, a, b, c...), where
        a, b, c... are extra arguments supplied through `args`.
    x0 : float
        Starting guess, ideally close to the actual zero.
    fprime : callable and jitted
        Derivative of `func`, called with the same arguments.
    args : tuple, optional(default=())
        Extra arguments forwarded to `func` and `fprime`.
    tol : float, optional(default=1.48e-8)
        Iteration stops once two successive iterates differ by less
        than `tol`.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError when the algorithm did not converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of calls made to `func` and `fprime`.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the iteration did not converge (this
        includes hitting a zero derivative).
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiplying by 1.0 (instead of calling float) promotes integers to
    # floating point while leaving a complex starting value usable.
    x_cur = 1.0 * x0
    funcalls = 0
    n_iter = 0
    status = _ECONVERR

    for _ in range(maxiter):
        f_val = func(x_cur, *args)
        funcalls += 1
        if f_val == 0:
            # The current iterate is an exact root; this pass is not
            # counted as an iteration.
            x_new = x_cur
            status = _ECONVERGED
            break

        n_iter += 1
        f_der = fprime(x_cur, *args)
        funcalls += 1
        if f_der == 0:
            # The Newton step is undefined; stop without convergence.
            x_new = x_cur
            break

        x_new = x_cur - f_val / f_der
        if abs(x_new - x_cur) < tol:
            status = _ECONVERGED
            break
        x_cur = x_new

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((x_new, funcalls, n_iter, status))
@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
                  maxiter=50, disp=True):
    """
    Find a zero from Halley's method using the jitted version of
    Scipy's.

    `func`, `fprime`, `fprime2` must be jitted via Numba.

    Parameters
    ----------
    func : callable and jitted
        The function whose zero is wanted. It must be a function of a
        single variable of the form f(x,a,b,c...), where a,b,c... are extra
        arguments that can be passed in the `args` parameter.
    x0 : float
        An initial estimate of the zero that should be somewhere near the
        actual zero.
    fprime : callable and jitted
        The derivative of the function (when available and convenient).
    fprime2 : callable and jitted
        The second order derivative of the function
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    tol : float, optional(default=1.48e-8)
        The allowable error of the zero value.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of evaluations of `func`, `fprime`
                             and `fprime2` combined.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the iteration did not converge (this
        includes hitting a zero first derivative).
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float (don't use float(x0); this works also for complex x0)
    p0 = 1.0 * x0
    funcalls = 0
    status = _ECONVERR

    # Halley Method
    for itr in range(maxiter):
        # first evaluate fval
        fval = func(p0, *args)
        funcalls += 1
        # If fval is 0, a root has been found, then terminate
        if fval == 0:
            status = _ECONVERGED
            p = p0
            # this pass did no update, so it is not counted as an iteration
            itr -= 1
            break
        fder = fprime(p0, *args)
        funcalls += 1
        # derivative is zero, not converged
        if fder == 0:
            p = p0
            break
        newton_step = fval / fder
        # Halley's variant
        fder2 = fprime2(p0, *args)
        # The second-derivative evaluation is counted like the other two
        # callables so that function_calls reflects all work done.
        funcalls += 1
        p = p0 - newton_step / (1.0 - 0.5 * newton_step * fder2 / fder)
        if abs(p - p0) < tol:
            status = _ECONVERGED
            break
        p0 = p

    if disp and status == _ECONVERR:
        msg = "Failed to converge"
        raise RuntimeError(msg)

    return _results((p, funcalls, itr + 1, status))
@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50, disp=True):
    """
    Locate a zero of a scalar function with the secant method.

    This is a jitted adaptation of SciPy's secant iteration. `func` must
    be compiled with Numba.

    Parameters
    ----------
    func : callable and jitted
        Function whose zero is sought, of the form f(x, a, b, c...), where
        a, b, c... are extra arguments supplied through `args`.
    x0 : float
        Starting guess, ideally close to the actual zero. The second
        starting point is derived from it by a small perturbation.
    args : tuple, optional(default=())
        Extra arguments forwarded to `func`.
    tol : float, optional(default=1.48e-8)
        Iteration stops once two successive iterates differ by less
        than `tol`.
    maxiter : int, optional(default=50)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError when the algorithm did not converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero.
            function_calls - Number of times the function was called.
            iterations - Number of iterations needed to find the root.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `tol` <= 0 or `maxiter` < 1.
    RuntimeError
        If `disp` is True and the iteration did not converge.
    """
    if tol <= 0:
        raise ValueError("tol is too small <= 0")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Multiplying by 1.0 promotes an integer guess to floating point.
    x_prev = 1.0 * x0

    # Second starting point: nudge x0 away from zero, relatively and
    # absolutely, in the direction of its sign.
    shift = 1e-4 if x0 >= 0 else -1e-4
    x_last = x0 * (1 + 1e-4) + shift

    f_prev = func(x_prev, *args)
    f_last = func(x_last, *args)
    funcalls = 2

    status = _ECONVERR
    n_iter = 0
    for _ in range(maxiter):
        n_iter += 1
        if f_last == f_prev:
            # Flat secant: no step can be formed, so settle on the midpoint
            # of the two points and report convergence.
            x_new = (x_last + x_prev) / 2.0
            status = _ECONVERGED
            break

        x_new = x_last - f_last * (x_last - x_prev) / (f_last - f_prev)
        if np.abs(x_new - x_last) < tol:
            status = _ECONVERGED
            break

        # Slide the two-point window forward and evaluate the new point.
        x_prev, f_prev = x_last, f_last
        x_last = x_new
        f_last = func(x_last, *args)
        funcalls += 1

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((x_new, funcalls, n_iter, status))
@njit
def _bisect_interval(a, b, fa, fb):
    """
    Validate a bracketing interval and detect a root at an endpoint.

    Parameters
    ----------
    a, b : number
        Endpoints of the interval.
    fa, fb : number
        Function values at `a` and `b`.

    Returns
    -------
    root : number
        `b` if ``fb == 0``, otherwise `a` if ``fa == 0``, otherwise the
        placeholder 0.0.
    status : int
        `_ECONVERGED` when an endpoint is a root, `_ECONVERR` otherwise.

    Raises
    ------
    ValueError
        If `fa` and `fb` are both nonzero and share the same sign.
    """
    sign_product = fa * fb
    if sign_product > 0:
        raise ValueError("f(a) and f(b) must have different signs")

    # When both endpoints are roots, b takes precedence, so it is tested
    # first.
    if fb == 0:
        root, status = b, _ECONVERGED
    elif fa == 0:
        root, status = a, _ECONVERGED
    else:
        root, status = 0.0, _ECONVERR

    return root, status
@njit
def bisect(f, a, b, args=(), xtol=_xtol, rtol=_rtol, maxiter=_iter,
           disp=True):
    """
    Find root of a function within an interval adapted from Scipy's bisect.

    Basic bisection routine to find a zero of the function `f` between the
    arguments `a` and `b`. `f(a)` and `f(b)` cannot have the same signs.

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number.  `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be positive.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero. If the
                   routine did not converge, the lower end of the last
                   bracket examined.
            function_calls - Number of times the function was called.
            iterations - Number of bisection steps performed.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `xtol` <= 0, `maxiter` < 1, or `f(a)` and `f(b)` have the
        same sign.
    RuntimeError
        If `disp` is True and the iteration did not converge.
    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float
    xa = a * 1.0
    xb = b * 1.0

    fa = f(xa, *args)
    fb = f(xb, *args)
    funcalls = 2

    # Check for sign error and early termination (root at an endpoint)
    root, status = _bisect_interval(xa, xb, fa, fb)

    itr = 0
    if status != _ECONVERGED:
        # Perform bisection; dm is the signed width of the current bracket,
        # which always starts at xa.
        dm = xb - xa
        for _ in range(maxiter):
            itr += 1
            dm *= 0.5
            xm = xa + dm
            fm = f(xm, *args)
            funcalls += 1

            # Same sign as f(a): the root lies in the upper half.
            if fm * fa >= 0:
                xa = xm

            if fm == 0 or abs(dm) < xtol + rtol * abs(xm):
                root = xm
                status = _ECONVERGED
                break

        if status == _ECONVERR:
            # Iteration budget exhausted: report the current bracket end
            # rather than the placeholder returned by _bisect_interval.
            root = xa

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))
@njit
def brentq(f, a, b, args=(), xtol=_xtol, rtol=_rtol, maxiter=_iter,
           disp=True):
    """
    Find a root of a function in a bracketing interval using Brent's method
    adapted from Scipy's brentq.

    Uses the classic Brent's method to find a zero of the function `f` on
    the sign changing interval [a , b].

    `f` must be jitted via numba.

    Parameters
    ----------
    f : jitted and callable
        Python function returning a number.  `f` must be continuous.
    a : number
        One end of the bracketing interval [a,b].
    b : number
        The other end of the bracketing interval [a,b].
    args : tuple, optional(default=())
        Extra arguments to be used in the function call.
    xtol : number, optional(default=2e-12)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
        parameter must be positive.
    rtol : number, optional(default=4*np.finfo(float).eps)
        The computed root ``x0`` will satisfy ``np.allclose(x, x0,
        atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
    maxiter : number, optional(default=100)
        Maximum number of iterations.
    disp : bool, optional(default=True)
        If True, raise a RuntimeError if the algorithm didn't converge.

    Returns
    -------
    results : namedtuple
        A namedtuple containing the following items:
        ::

            root - Estimated location where function is zero. If the
                   routine did not converge, the most recent iterate.
            function_calls - Number of times the function was called.
            iterations - Number of iterations performed.
            converged - True if the routine converged

    Raises
    ------
    ValueError
        If `xtol` <= 0, `maxiter` < 1, or `f(a)` and `f(b)` have the
        same sign.
    RuntimeError
        If `disp` is True and the iteration did not converge.
    """
    if xtol <= 0:
        raise ValueError("xtol is too small (<= 0)")
    if maxiter < 1:
        raise ValueError("maxiter must be greater than 0")

    # Convert to float
    xpre = a * 1.0
    xcur = b * 1.0

    fpre = f(xpre, *args)
    fcur = f(xcur, *args)
    funcalls = 2

    # Check for sign error and early termination (root at an endpoint)
    root, status = _bisect_interval(xpre, xcur, fpre, fcur)

    itr = 0
    if status != _ECONVERGED:
        # Perform Brent's method
        for _ in range(maxiter):
            itr += 1

            # The last two iterates bracket the root: restart the bracket
            # from the previous point and reset the step history.
            if fpre * fcur < 0:
                xblk = xpre
                fblk = fpre
                spre = scur = xcur - xpre

            # Keep xcur as the point with the smaller residual.
            if abs(fblk) < abs(fcur):
                xpre = xcur
                xcur = xblk
                xblk = xpre

                fpre = fcur
                fcur = fblk
                fblk = fpre

            delta = (xtol + rtol * abs(xcur)) / 2
            sbis = (xblk - xcur) / 2

            # Root found
            if fcur == 0 or abs(sbis) < delta:
                status = _ECONVERGED
                root = xcur
                break

            if abs(spre) > delta and abs(fcur) < abs(fpre):
                if xpre == xblk:
                    # interpolate
                    stry = -fcur * (xcur - xpre) / (fcur - fpre)
                else:
                    # extrapolate
                    dpre = (fpre - fcur) / (xpre - xcur)
                    dblk = (fblk - fcur) / (xblk - xcur)
                    stry = (-fcur * (fblk * dblk - fpre * dpre) /
                            (dblk * dpre * (fblk - fpre)))

                if 2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta):
                    # good short step
                    spre = scur
                    scur = stry
                else:
                    # bisect
                    spre = sbis
                    scur = sbis
            else:
                # bisect
                spre = sbis
                scur = sbis

            xpre = xcur
            fpre = fcur
            # Never move by less than delta, so the iterate keeps changing.
            if abs(scur) > delta:
                xcur += scur
            else:
                xcur += (delta if sbis > 0 else -delta)

            fcur = f(xcur, *args)
            funcalls += 1

        if status == _ECONVERR:
            # Iteration budget exhausted: report the latest iterate rather
            # than the placeholder returned by _bisect_interval.
            root = xcur

    if disp and status == _ECONVERR:
        raise RuntimeError("Failed to converge")

    return _results((root, funcalls, itr, status))