# Source code for neutralocean.fzero

"""
Functions for finding the zero of a univariate function.
"""

import numpy as np
import numba as nb


# Machine epsilon of a double-precision float (2**-52): the gap between 1.0
# and the next representable float64. `brent` uses it to scale its
# convergence tolerance with the magnitude of the current estimate.
eps = np.finfo(float).eps


@nb.njit
def brent_guess(f, x, A, B, t, args=()):
    """
    Find a zero of a function within a given range, starting from a guess

    The search happens in two stages: `guess_to_bounds` first expands
    outwards from `x` to find a small sub-range of `[A, B]` across which `f`
    changes sign, and `brent` then locates the zero inside that sub-range.

    Parameters
    ----------
    f : function
        Continuous function of a single variable.
    x : float
        initial guess for a root
    A, B : float
        Range within which to search, satisfying `A < B`
    t : float
        Tolerance for convergence.
    args : tuple
        Additional arguments, beyond the optimization argument, to be passed
        to `f`. Pass `()` when `f` is univariate.

    Returns
    -------
    float
        Value of `x` where `f(x) ~ 0`, or `nan` when no sign change of `f`
        could be bracketed (in that case `guess_to_bounds` yields a nan range,
        which `brent` rejects).
    """
    lo, hi = guess_to_bounds(f, x, A, B, args)
    root = brent(f, lo, hi, t, args)
    return root
@nb.njit
def brent(f, a, b, t, args=()):
    """
    Find a zero of a univariate function within a given range

    This is a bracketed root-finding method, so `f(a)` and `f(b)` must
    differ in sign. If they do, a root is guaranteed to be found.

    Parameters
    ----------
    f : function
        Continuous function of a single variable.
    a, b : float
        Range within which to search, satisfying `a < b` and ideally
        `f(a) * f(b) <= 0`
    t : float
        Tolerance for convergence.
    args : tuple
        Additional arguments, beyond the optimization argument, to be passed
        to `f`. Pass `()` when `f` is univariate.

    Returns
    -------
    float
        Value of `x` where `f(x) ~ 0`.
        Returns `nan` if `a` or `b` is nan, if `a > b`, or if `f(a)` and
        `f(b)` are both strictly positive or both strictly negative.

    Notes
    -----
    `f` should be a `@numba.njit`'ed function (when this function is
    `njit`'ed).

    The sign-change check compares the signs of `f(a)` and `f(b)` directly
    rather than testing `f(a) * f(b) > 0`: when both values are tiny, that
    product underflows to 0 and a range without a sign change would be
    wrongly accepted.
    """
    # Protection against bad input search range
    if np.isnan(a) or np.isnan(b) or a > b:
        return np.nan

    fa = f(a, *args)
    fb = f(b, *args)

    # Protection against input range that doesn't have a sign change.
    # An exact zero at an endpoint deliberately passes this check: the loop
    # below then returns that endpoint through its `fb == 0.0` exit.
    if (fa > 0.0 and fb > 0.0) or (fa < 0.0 and fb < 0.0):
        return np.nan

    c = a
    fc = fa
    e = b - a
    d = e

    while True:
        # Swap so that b is the better estimate: |f(b)| <= |f(c)|.
        if abs(fc) < abs(fb):
            a = b
            b = c
            c = a
            fa = fb
            fb = fc
            fc = fa

        tol = 2.0 * eps * abs(b) + t
        m = 0.5 * (c - b)

        if abs(m) <= tol or fb == 0.0:
            break

        if abs(e) < tol or abs(fa) <= abs(fb):
            # Bisection step (d = m moves b to the midpoint of [b, c])
            e = m
            d = e
        else:
            s = fb / fa
            if a == c:
                # Linear (secant) interpolation
                p = 2.0 * m * s
                q = 1.0 - s
            else:
                # Inverse quadratic interpolation
                q = fa / fc
                r = fb / fc
                p = s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0))
                q = (q - 1.0) * (r - 1.0) * (s - 1.0)

            if 0.0 < p:
                q = -q
            else:
                p = -p

            s = e
            e = d

            if 2.0 * p < 3.0 * m * q - abs(tol * q) and p < abs(0.5 * s * q):
                # Accept the interpolation step
                d = p / q
            else:
                # Reject the interpolation step; fall back to bisection
                e = m
                d = e

        a = b
        fa = fb

        # Step b by d, but always by at least tol, in the direction of c
        if tol < abs(d):
            b += d
        elif 0.0 < m:
            b += tol
        else:
            b -= tol

        fb = f(b, *args)

        # If f(b) and f(c) now share a sign, reset c to the previous b so
        # that [b, c] keeps bracketing the sign change.
        if (0.0 < fb and 0.0 < fc) or (fb <= 0.0 and fc <= 0.0):
            c = a
            fc = fa
            e = b - a
            d = e

    return b
@nb.njit
def guess_to_bounds(f, x, A, B, args=()):
    """
    Search for a range containing a sign change, expanding geometrically
    outwards from the initial guess.

    This is used as a first step in zero-finding, providing a small search
    range for the Brent algorithm.

    Parameters
    ----------
    f : function
        Continuous function of a single variable
    x : float
        Central point for starting the search. It is clamped to `[A, B]`.
    A, B : float
        Lower and upper bounds, containing `x`, within which to search for a
        zero.
    args : tuple
        Additional arguments beyond the optimization argument.
        Pass `()` when `f` is univariate.

    Returns
    -------
    a, b : float
        Lower and upper bounds within which `f(x)` changes sign.
        Special cases: `(A, A)` if `f(A) == 0`, else `(B, B)` if `f(B) == 0`;
        `(nan, nan)` if no sign change is found, or if any of `x`, `A`, `B`
        is nan.

    Notes
    -----
    The sign of `f` is tested as `f(.) > 0`, so at interior points an exact
    zero is grouped with the negative values. `f` is evaluated at most once
    at each of `A` and `B`.
    """
    nan = np.nan

    # Guard against nan inputs: every comparison with nan is False, so the
    # expansion loop below could otherwise never reach an exit condition, or
    # could compare against a sign flag that was never computed.
    if np.isnan(x) or np.isnan(A) or np.isnan(B):
        return (nan, nan)

    # Check value of f at bounds
    fa = f(A, *args)
    if fa == 0.0:
        return (A, A)
    fb = f(B, *args)
    if fb == 0.0:
        return (B, B)

    # Signs of f at the bounds, reused whenever a or b lands exactly on a
    # bound instead of evaluating f there a second time.
    pos_at_A = fa > 0.0
    pos_at_B = fb > 0.0

    x = min(max(x, A), B)

    # initial distance to expand outward from x, in positive and negative
    # directions
    dxp = (B - x) / 50
    dxm = (x - A) / 50

    # Set a = x, except when x is so close to A that machine roundoff makes
    # dxm identically 0, which would lead to an infinite loop below. In this
    # case, set a = A.
    if dxm == 0:
        a = A
    else:
        a = x

    # Similarly, set b = x, except for machine precision problems.
    if dxp == 0:
        b = B
    else:
        b = x

    if a > A:
        # fapos is set on the first pass of the loop, which moves a.
        fbpos = pos_at_B if b == B else (f(b, *args) > 0.0)
    else:  # a == A
        fapos = pos_at_A
        if b == B:
            # So dxm == 0 and dxp == 0. So A very nearly equals B, but could
            # have A != B due to machine precision problems
            fbpos = pos_at_B
        # else b < B: fbpos is set on the first pass of the loop, which
        # moves b.

    while True:
        if a > A:
            # Move a left, by a step growing by a factor sqrt(2) each time,
            # and test for a sign change
            dxm *= 1.414213562373095
            a = max(x - dxm, A)
            fapos = pos_at_A if a == A else (f(a, *args) > 0.0)
            if fapos != fbpos:  # fa and fb have different signs
                return (a, b)
        elif b == B:
            # also a == A, so cannot expand anymore
            if fapos != fbpos:  # one last test for sign change
                return (a, b)
            return (nan, nan)  # no sign change found

        if b < B:
            # Move b right, and test for a sign change
            dxp *= 1.414213562373095
            b = min(x + dxp, B)
            fbpos = pos_at_B if b == B else (f(b, *args) > 0.0)
            if fapos != fbpos:  # fa and fb have different signs
                return (a, b)
        elif a == A:
            # also b == B, so cannot expand anymore
            if fapos != fbpos:  # one last test for sign change
                return (a, b)
            return (nan, nan)  # no sign change found