How to find all zeros of a function using numpy (and scipy)?

前端 未结 4 645
忘掉有多难
忘掉有多难 2020-12-16 04:35

Suppose I have a function f(x) defined between a and b. This function can have many zeros, but also many asymptotes. I need to retrieve all the zeros of this function.

4 answers
  •  时光说笑
    2020-12-16 04:58

    I found it's relatively easy to implement your own root finder using `scipy.optimize.fsolve`.

    • Idea: find the zeros in the interval (start, stop) by calling `fsolve` repeatedly, moving the starting point x0 forward by `step` each time. Use a relatively small step size to find all the roots.

    • Can only search for zeroes in one dimension (other dimensions must be fixed). If you have other needs, I would recommend using sympy for calculating the analytical solution.

    • Note: it may not always find all the zeros, but in my experience it gives reasonably good results. I also put the code in a gist, which I will update if needed.

    import numpy as np
    import scipy
    from scipy.optimize import fsolve
    from matplotlib import pyplot as plt
    
    # RootFinder and f are defined further down the page; run this snippet
    # after those definitions.
    start, stop = 1, 20
    args = (90, 5)  # extra positional arguments forwarded to f (V, ell)

    # Search [start, stop], launching fsolve from a grid with spacing 0.01.
    r = RootFinder(start, stop, 0.01)
    roots = r.find(f, *args)
    print("Roots: ", roots)

    # Draw f over the same interval and mark every root found in red.
    fig, ax = plt.subplots()
    u = np.linspace(start, stop, num=600)
    ax.plot(u, f(u, *args))
    ax.scatter(roots, f(np.array(roots), *args), color="r", s=10)
    ax.grid(color="grey", ls="--", lw=0.5)
    plt.show()
    

    Example output:

    Roots:  [ 2.84599497  8.82720551 12.38857782 15.74736542 19.02545276]
    

    Zoomed-in view of the plot (image not included here):

    RootFinder definition

    import numpy as np
    import scipy
    from scipy.optimize import fsolve
    from matplotlib import pyplot as plt
    
    
    class RootFinder:
        """Find zeros of a scalar function on [start, stop] by repeated fsolve.

        ``fsolve`` is launched from every point of the grid
        ``start, start + step, ...`` (up to ``stop``), i.e. one fsolve call per
        grid point.  Every solution that fsolve reports as converged and that
        lies inside [start, stop] is kept, with near-duplicates merged.  Only
        one-dimensional problems are supported.  The residual at a reported
        solution is not re-checked, and roots can still be missed if ``step``
        is too coarse.

        Parameters
        ----------
        start, stop : float
            Search interval (both ends inclusive).
        step : float
            Spacing of the starting points; smaller finds more roots but is
            slower.
        root_dtype : str or numpy dtype
            dtype used for the (initially empty) root array.
        xtol : float
            Relative tolerance forwarded to ``fsolve``; also used to decide
            whether two solutions are the same root.
        """

        def __init__(self, start, stop, step=0.01, root_dtype="float64", xtol=1e-9):
            self.start = start
            self.stop = stop
            self.step = step
            self.xtol = xtol
            self.root_dtype = root_dtype
            self.roots = np.array([], dtype=root_dtype)

        def add_to_roots(self, x):
            """Record ``x`` unless it lies outside [start, stop] or is already known."""
            if (x < self.start) or (x > self.stop):
                return  # outside range
            # fsolve's xtol is a *relative* tolerance, so scale the duplicate
            # test by |x| (never going below xtol itself for roots near zero).
            tol = self.xtol * max(1.0, abs(x))
            if np.any(np.abs(self.roots - x) <= tol):
                return  # root already found
            self.roots = np.append(self.roots, x)

        def find(self, f, *args):
            """Return the sorted zeros of ``f(x, *args)`` found in [start, stop].

            Each call starts from an empty root list, so roots from earlier
            calls (e.g. with different ``args``) are not mixed into the result.
            """
            self.roots = np.array([], dtype=self.root_dtype)
            # Use every grid point as a starting point.  Skipping ahead to the
            # last root found is unsafe: from a single starting point fsolve
            # can jump over several roots (or leave the interval entirely), and
            # the skipped starting points may be the only way to reach them.
            for x0 in np.arange(self.start, self.stop + self.step, self.step):
                x = self.find_root(f, x0, *args)
                if x is not None:
                    self.add_to_roots(x)
            self.roots = np.sort(self.roots)
            return self.roots

        def find_root(self, f, x0, *args):
            """Run fsolve from ``x0``; return the solution, or None if it did not converge."""
            x, _, ier, _ = fsolve(f, x0=x0, args=args, full_output=True, xtol=self.xtol)
            if ier == 1:  # fsolve signals successful convergence with ier == 1
                return x[0]
            return None
    
    

    Test function

    The `scipy.special.jnjn` function does not exist anymore, so I created a similar test function for this case.

    def f(u, V=90, ell=5):
        """Test function with several zeros (and poles where a denominator vanishes).

        Evaluates ``J_ell(u) / (u * Y_{ell-1}(u)) + K_ell(w) / (w * K_{ell-1}(w))``
        with ``w = sqrt(V**2 - u**2)``, elementwise for scalar or array ``u``.

        Parameters
        ----------
        u : float or array_like
            Evaluation point(s).  For |u| > V, ``w`` is NaN (numpy sqrt of a
            negative number); u = 0 divides by zero.
        V : float
            Parameter entering ``w``.
        ell : int
            Bessel order; the ``ell - 1`` terms must also be valid orders.

        Returns
        -------
        float or numpy.ndarray
            Function value(s), same shape as ``u``.
        """
        # Import the submodule explicitly: in older SciPy releases a bare
        # ``import scipy`` does not make ``scipy.special`` available.
        import scipy.special

        w = np.sqrt(V ** 2 - u ** 2)

        jl = scipy.special.jn(ell, u)
        # NOTE(review): the name suggests J_{ell-1}, but this evaluates Y_{ell-1}
        # (yn).  The example output on this page presumably comes from this
        # version, so it is kept as-is; confirm which was intended.
        jl1 = scipy.special.yn(ell - 1, u)
        kl = scipy.special.kn(ell, w)
        kl1 = scipy.special.kn(ell - 1, w)

        return jl / (u * jl1) + kl / (w * kl1)
    

提交回复
热议问题