Question
I have been using SymPy to expand the terms of a complex partial differential equation and would like to use the `collect` function to gather terms. However, it seems to have a problem with second- (or higher-) order derivatives where the variables of differentiation differ.
In the code example below, `collect(expr6, ...)` works, but `collect(expr7, ...)` does not; it fails with the error message "NotImplementedError: Improve MV Derivative support in collect". The error is clearly related to the `psi.diff(x,y)` difference between the two cases. Is it obvious to anyone what I need to do to make `collect(expr7, ...)` work?
cheers
Richard
Example:

```python
from sympy import *

x, y, z, t, U = symbols("x y z t U")
psi = Function("psi")(x, y, z, t)
expr6 = 2*psi.diff(x, x) + 3*U*psi.diff(x) + 5*psi.diff(y)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)
collect(expr6, psi.diff(x), evaluate=False, exact=False)  # works
# collect(expr7, psi.diff(x), evaluate=False, exact=False)
# throws an error: NotImplementedError: Improve MV Derivative support in collect
```
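To see what separates the two cases, it can help to compare how SymPy stores each derivative. This is a minimal sketch, assuming the same symbols as in the example: `variables` lists each differentiation variable with repetition, and `variable_count` pairs each distinct variable with its order. Judging by the error message, `collect` appears to handle a derivative taken repeatedly with respect to one variable (`psi.diff(x, x)`) but not one taken with respect to several distinct variables (`psi.diff(x, y)`).

```python
from sympy import Function, symbols

x, y, z, t = symbols("x y z t")
psi = Function("psi")(x, y, z, t)

d_xx = psi.diff(x, x)  # one variable, repeated: the expr6 case that collect handles
d_xy = psi.diff(x, y)  # two distinct variables: the expr7 case that raises the error

# variables repeats each variable by its order; variable_count groups them
print(d_xx.variables, d_xx.variable_count)
print(d_xy.variables, d_xy.variable_count)
```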
Answer 1:
I've bumped into this issue, and my workaround is to substitute the expressions of interest (here, the derivatives) with simple dummy symbols first, collect with respect to those symbols, and then substitute the original expressions back. There might be some corner cases, but it seems to work for me.
```python
from sympy import symarray, collect  # note: symarray requires NumPy


def mycollect(expr, var_list, evaluate=True, **kwargs):
    """Act like collect, but substitute the given expressions with dummy
    symbols first so that it also works with partial derivatives.
    Matrix expressions are also supported.
    """
    if not hasattr(var_list, '__len__'):
        var_list = [var_list]
    # Mapping Var -> Dummy, and Dummy -> Var
    Dummies = symarray('DUM', len(var_list))
    Var2Dummy = [(var, Dummies[i]) for i, var in enumerate(var_list)]
    Dummy2Var = [(b, a) for a, b in Var2Dummy]
    # Replace var with dummies and apply collect
    expr = expr.expand().doit()
    expr = expr.subs(Var2Dummy)
    if hasattr(expr, '__len__'):
        # Matrix: collect each entry separately
        expr = expr.applyfunc(lambda ij: collect(ij, Dummies, **kwargs))
    else:
        expr = collect(expr, Dummies, evaluate=evaluate, **kwargs)
    # Substitute back
    if evaluate:
        return expr.subs(Dummy2Var)
    # evaluate=False: collect returned a dict, so map keys and values back
    d = {}
    for k, v in expr.items():
        k = k.subs(Dummy2Var)
        v = v.subs(Dummy2Var)
        d[k] = v
    return d
```
For your example:

```python
mycollect(expr6, psi.diff(x), evaluate=False)
mycollect(expr7, psi.diff(x), evaluate=False)
```

returns:

```
{Derivative(psi(x, y, z, t), (x, 2)): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}
{Derivative(psi(x, y, z, t), x, y): 2, Derivative(psi(x, y, z, t), x): 3*U, 1: 5*Derivative(psi(x, y, z, t), y)}
```
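The same dummy-substitution idea can also be sketched with every derivative treated as an opaque symbol. The helper `collect_derivatives` below is hypothetical (it is not part of SymPy and not the answer's code). It assumes you want exact matching: each `Derivative` in the expression is swapped for a fresh `Dummy` with `xreplace`, which replaces only exact structural matches, whereas `subs` may apply derivative-aware rewriting depending on the SymPy version. Because no `Derivative` reaches `collect`, the mixed partial never has to be parsed.

```python
from sympy import Derivative, Dummy, Function, collect, symbols


def collect_derivatives(expr, targets, evaluate=True):
    """Hypothetical helper: collect `expr` w.r.t. `targets`, treating every
    Derivative as an opaque symbol so collect() never parses a mixed partial.
    Unlike collect(..., exact=False), it does not group a higher derivative
    under a lower-order target.
    """
    if not isinstance(targets, (list, tuple)):
        targets = [targets]
    # One fresh Dummy per distinct Derivative in the expression or the targets
    derivs = expr.atoms(Derivative)
    derivs |= {tgt for tgt in targets if isinstance(tgt, Derivative)}
    to_dummy = {d: Dummy() for d in derivs}
    to_deriv = {dum: d for d, dum in to_dummy.items()}

    # xreplace swaps exact structural matches only
    plain = expr.xreplace(to_dummy)
    plain_targets = [tgt.xreplace(to_dummy) for tgt in targets]
    result = collect(plain, plain_targets, evaluate=evaluate)

    if evaluate:
        return result.xreplace(to_deriv)
    # evaluate=False: map both keys and values of the dict back
    return {k.xreplace(to_deriv): v.xreplace(to_deriv) for k, v in result.items()}


x, y, z, t, U = symbols("x y z t U")
psi = Function("psi")(x, y, z, t)
expr7 = 2*psi.diff(x, y) + 3*U*psi.diff(x) + 5*psi.diff(y)

print(collect_derivatives(expr7, psi.diff(x), evaluate=False))
print(collect_derivatives(expr7, [psi.diff(x, y), psi.diff(x), psi.diff(y)],
                          evaluate=False))
```

With only `psi.diff(x)` as the target, the mixed term `2*psi.diff(x, y)` ends up in the `1` entry together with `5*psi.diff(y)`; pass the mixed derivative as a target as well if you want it under its own key.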
Source: https://stackoverflow.com/questions/58700443/how-do-i-get-sympy-to-collect-partial-derivatives