Convert sympy expressions to function of numpy arrays

失恋的感觉 2020-12-16 15:48

I have a system of ODEs written in sympy:

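from sympy import symbols
from sympy.parsing.sympy_parser import parse_expr

xs = symbols('x1 x2')
ks = symbols('k1 k2')
strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
syms = [parse_expr(item) for item in strs]
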
2 Answers
  • 2020-12-16 16:12

    I wrote a module named JiTCODE, which is tailored to problems such as yours. It takes symbolic expressions, converts them to C code, wraps a Python extension around it, compiles this, and loads it for use with scipy.integrate.ode or scipy.integrate.solve_ivp.
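To give a rough idea of the first step (turning symbolic expressions into C code), SymPy's own `ccode` printer can render the right-hand sides as C expressions. This is a minimal sketch using plain SymPy and is not JiTCODE's internal code generator. The `dydt[i] = ...;` layout is a hypothetical illustration.

```python
from sympy import ccode
from sympy.parsing.sympy_parser import parse_expr

# Same right-hand sides as in the question.
strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
syms = [parse_expr(item) for item in strs]

# ccode prints a SymPy expression as a C expression string.
# The "dydt[i] = ...;" assignment format is hypothetical, for illustration only.
c_lines = [f"dydt[{i}] = {ccode(expr)};" for i, expr in enumerate(syms)]
for line in c_lines:
    print(line)
```

C has no `**` operator, so the printer emits `pow(...)` (or an equivalent product) for the squares.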

    Your example would look like this:

    from jitcode import y, jitcode
    from sympy.parsing.sympy_parser import parse_expr
    from sympy import symbols
    
    xs = symbols('x1 x2')
    ks = symbols('k1 k2')
    strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
    syms = [parse_expr(item) for item in strs]
    
    substitutions = {x_i:y(i) for i,x_i in enumerate(xs)}
    f = [sym.subs(substitutions) for sym in syms]
    
    ODE = jitcode(f,control_pars=ks)
    

    You can then use ODE pretty much like an instance of scipy.integrate.ode.

    While you would not need this for your application, you can also extract and use the compiled function:

    ODE.compile_C()
    import numpy as np
    x = np.array([3.5, 1.5])
    k = np.array([4, 2])
    print(ODE.f(0.0,x,*k))
    

    Note that in contrast to your specifications, k is not passed as a NumPy array. For most ODE applications, this should not be relevant, because it is more efficient to hardcode the control parameters.
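The idea of hardcoding the control parameters can also be sketched in plain SymPy, without JiTCODE: substitute fixed numbers for `k1` and `k2` before generating the function, so the result depends only on the state. This is a sketch using `subs` and `lambdify`, with the values 4 and 2 borrowed from the `k` array above. It is not how JiTCODE handles `control_pars` internally.

```python
import numpy as np
from sympy import symbols, lambdify
from sympy.parsing.sympy_parser import parse_expr

xs = symbols('x1 x2')
ks = symbols('k1 k2')
strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
syms = [parse_expr(item) for item in strs]

# Fix the control parameters before generating code (illustrative values).
k_values = {ks[0]: 4, ks[1]: 2}
fixed = [expr.subs(k_values) for expr in syms]

# The generated function takes only the state variables x1, x2.
f_fixed = lambdify(xs, fixed, 'numpy')

x = np.array([3.5, 1.5])
print(f_fixed(*x))
```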

    Finally, note that for this small example, you may not get the best performance due to the overheads of scipy.integrate.ode or scipy.integrate.solve_ivp (also see SciPy Issue #8257 or this answer of mine). For large differential equations (as you have), this overhead becomes irrelevant.

  • 2020-12-16 16:21

    You can use the sympy function lambdify. For example,

    from sympy import symbols, lambdify
    from sympy.parsing.sympy_parser import parse_expr
    import numpy as np
    
    xs = symbols('x1 x2')
    ks = symbols('k1 k2')
    strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
    syms = [parse_expr(item) for item in strs]
    
    # Convert each expression in syms to a function with signature f(x1, x2, k1, k2):
    funcs = [lambdify(xs + ks, f) for f in syms]
    
    
    # This is not exactly the same as the `my_odes` in the question.
    # `t` is included so this can be used with `scipy.integrate.odeint`.
    # The value returned by `sym.subs` is wrapped in a call to `float`
    # to ensure that the function returns python floats and not sympy Floats.
    def my_odes(x, t, k):
        all_dict = dict(zip(xs, x))
        all_dict.update(dict(zip(ks, k)))
        return np.array([float(sym.subs(all_dict)) for sym in syms])
    
    def lambdified_odes(x, t, k):
        x1, x2 = x
        k1, k2 = k
        xdot = [f(x1, x2, k1, k2) for f in funcs]
        return xdot
    
    
    if __name__ == "__main__":
        from scipy.integrate import odeint
    
        k1 = 0.5
        k2 = 1.0
        init = [1.0, 0.0]
        t = np.linspace(0, 1, 6)
        sola = odeint(lambdified_odes, init, t, args=((k1, k2),))
        solb = odeint(my_odes, init, t, args=((k1, k2),))
        print(np.allclose(sola, solb))
    

    True is printed when the script is run.
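Since the question asks for a function of NumPy arrays, the separate per-expression functions can also be replaced by a single lambdified function that takes `x` and `k` as arrays. This is a minimal sketch that relies on `lambdify` accepting nested argument tuples, which its docstring documents. The name `odes_array` is hypothetical.

```python
import numpy as np
from sympy import symbols, lambdify
from sympy.parsing.sympy_parser import parse_expr

xs = symbols('x1 x2')
ks = symbols('k1 k2')
strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
syms = [parse_expr(item) for item in strs]

# Nested argument tuples make the generated function unpack its two
# array-like arguments, giving the signature rhs(x, k).
rhs = lambdify((xs, ks), syms, 'numpy')

def odes_array(x, t, k):
    # One call evaluates the whole system; convert the returned list to an array.
    return np.asarray(rhs(x, k))

x = np.array([3.5, 1.5])
k = np.array([4.0, 2.0])
print(odes_array(x, 0.0, k))
```

Evaluating the whole system in one call, rather than one call per equation, also avoids some Python function-call overhead per time step.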

    The lambdified version is much faster (note the change in units: 239 ms versus 610 µs, roughly a 400-fold speedup):

    In [79]: t = np.linspace(0, 10, 1001)
    
    In [80]: %timeit sol = odeint(my_odes, init, t, args=((k1, k2),))
    1 loops, best of 3: 239 ms per loop
    
    In [81]: %timeit sol = odeint(lambdified_odes, init, t, args=((k1, k2),))
    1000 loops, best of 3: 610 µs per loop
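The functions above use the `f(x, t, ...)` argument order that `scipy.integrate.odeint` expects. `scipy.integrate.solve_ivp`, the newer interface, calls `fun(t, y, ...)` with time first. The following sketch adapts one lambdified right-hand side to both. It assumes SciPy 1.4 or later, where `solve_ivp` accepts `args`. The tight `rtol`/`atol` values are chosen only so the two results can be compared.

```python
import numpy as np
from scipy.integrate import odeint, solve_ivp
from sympy import symbols, lambdify
from sympy.parsing.sympy_parser import parse_expr

xs = symbols('x1 x2')
ks = symbols('k1 k2')
strs = ['-k1 * x1**2 + k2 * x2', 'k1 * x1**2 - k2 * x2']
syms = [parse_expr(item) for item in strs]
rhs = lambdify((xs, ks), syms, 'numpy')  # rhs(x, k)

k = (0.5, 1.0)
init = [1.0, 0.0]
t = np.linspace(0, 1, 6)

# odeint calls f(x, t, *args); solve_ivp calls fun(t, x, *args).
sol_odeint = odeint(lambda x, t, k: rhs(x, k), init, t, args=(k,))
sol_ivp = solve_ivp(lambda t, x, k: rhs(x, k), (t[0], t[-1]), init,
                    t_eval=t, args=(k,), rtol=1e-10, atol=1e-12)

# odeint returns shape (len(t), 2); solve_ivp stores states in .y with shape (2, len(t)).
print(np.allclose(sol_odeint, sol_ivp.y.T, atol=1e-6))
```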
    