Question
I am generating C code with sympy, using the Common Subexpression Elimination (CSE) routine and the ccode printer.
However, I would like powers to be written as (x*x) instead of pow(x,2).
Is there any way to do this?
Example:
import sympy as sp
a = sp.MatrixSymbol('a', 3, 3)
b = sp.Matrix(a) * sp.Matrix(a)
res = sp.cse(b)
lines = []
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))
for i, result in enumerate(res[1]):
    lines.append(sp.ccode(result, "result_%i" % i))
print("\n".join(lines))
This outputs:
x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;
Best Regards
Answer 1:
There is a function called create_expand_pow_optimization that creates a wrapper to optimise your expressions in this respect. It takes as its argument the highest power that it will replace by explicit multiplication.
The replaced powers are wrapped in an UnevaluatedExpr, which protects them against automatic simplifications that would otherwise revert this change.
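A minimal sketch of what the wrapper does to a single expression, assuming a sympy version that provides sympy.codegen.rewriting (as the import in this answer does); the variable names are only illustrative:

```python
import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

x = sp.Symbol('x')
# Replace integer powers up to (and including) 3 by explicit products
expand_opt = create_expand_pow_optimization(3)

plain = sp.ccode(x**3)               # stock C99 printer output
wrapped = expand_opt(x**3)           # the Pow is replaced by an UnevaluatedExpr product
expanded = sp.ccode(wrapped)         # printed as an explicit product
too_high = sp.ccode(expand_opt(x**4))  # exponent above the limit: left as pow()

print(plain)     # pow(x, 3)
print(expanded)  # x*x*x
print(too_high)  # pow(x, 4)
```

The limit is inclusive, so with 3 the cube is expanded while the fourth power is left alone.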
import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization
expand_opt = create_expand_pow_optimization(3)
a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
res = sp.cse(a @ a)
for i, result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result), "result_%i" % i))
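This loop prints only the reduced results, whereas the question's code also emitted the CSE replacements. A sketch that sends both parts of the cse output through the same wrapper (names such as replacements and reduced are just for illustration):

```python
import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

expand_opt = create_expand_pow_optimization(3)
a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
replacements, reduced = sp.cse(a @ a)

lines = []
# Common subexpressions first, one assignment each
for sym, sub_expr in replacements:
    lines.append(sp.ccode(expand_opt(sub_expr), sym))
# Then the reduced result matrix, one assignment per element
for i, result in enumerate(reduced):
    lines.append(sp.ccode(expand_opt(result), "result_%i" % i))

code = "\n".join(lines)
print(code)
```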
Finally, be aware that at sufficiently high optimisation levels your compiler will take care of this itself (and is probably better at it).
Answer 2:
You can subclass the code printer and change only the one function you want to behave differently. You'll need to look at the original sympy source to find the correct function names and default implementation, so that you don't introduce errors. With a bit of care, the needed brackets can be generated automatically, exactly when and where they are needed.
Here is a minimal example:
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x
class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base, PREC))
        else:
            return super()._print_Pow(expr)
default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint
expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))
Output:
Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]
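To plug this printer into the CSE workflow from the question, its doprint method can be called with an assign_to target, which is what sp.ccode(expr, assign_to) does internally with the stock printer. A self-contained sketch (the class is repeated so the block runs on its own; the loop variable names are illustrative):

```python
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence


class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        # Squares become (base * base); everything else uses the stock C99 output
        if expr.exp == 2:
            return '({0} * {0})'.format(self.parenthesize(expr.base, precedence(expr)))
        return super()._print_Pow(expr)


printer = CustomCodePrinter()
a = sp.Matrix(sp.MatrixSymbol('a', 3, 3))
replacements, reduced = sp.cse(a * a)

lines = []
for sym, sub_expr in replacements:
    # doprint(expr, assign_to) emits an assignment, like sp.ccode(expr, sym)
    lines.append(printer.doprint(sub_expr, sym))
for i, result in enumerate(reduced):
    lines.append(printer.doprint(result, "result_%i" % i))

code = "\n".join(lines)
print(code)
```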
PS: To support a wider range of exponents, you could use e.g.:
class CustomCodePrinter(C99CodePrinter):
    def _print_Pow(self, expr):
        PREC = precedence(expr)
        if expr.exp in range(2, 7):
            # The outer brackets keep e.g. y/x**3 correct: "y/(x*x*x)" rather than "y/x*x*x"
            return '(' + '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp)) + ')'
        elif expr.exp in range(-6, 0):
            return '1.0/(' + '*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp)) + ')'
        else:
            return super()._print_Pow(expr)
Answer 3:
I think I will go with the user_functions approach. As suggested in the comment above, I will be using the user_functions functionality of sp.ccode.
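Given an expression such as a**3:
import sympy as sp
a = sp.Symbol('a')
# Expand positive integer exponents only (a negative one would otherwise give an empty string),
# and wrap the product in brackets so it stays correct in a denominator
sp.ccode(a**3, user_functions={'Pow': [(lambda x, y: y.is_Integer and y.is_positive, lambda x, y: '(' + '*'.join(['(' + x + ')'] * int(y)) + ')'), (lambda x, y: True, 'pow')]})
outputs:
'((a)*(a)*(a))'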
In the future, I will try to improve the function so that it only adds parentheses where they are needed.
Any improvements are welcome!
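One possible direction for the parentheses, sketched under the assumption (based on reading sympy's CodePrinter._print_Function) that the condition functions are called with the sympy arguments while the formatting functions receive the already-printed strings: let the conditions distinguish an atomic base such as a Symbol from a compound one, and only bracket the individual factors in the compound case. The table name pow_user_functions and the helper _is_positive_int are hypothetical:

```python
import sympy as sp

x, y = sp.symbols('x y')


def _is_positive_int(exp):
    # Conditions receive sympy objects, so is_Integer / is_positive are available
    return bool(exp.is_Integer and exp.is_positive)


# Conditions are tried in order; the first one that returns True wins
pow_user_functions = {'Pow': [
    # Atomic base (e.g. a Symbol): no brackets around each factor
    (lambda b, e: _is_positive_int(e) and b.is_Atom,
     lambda b, e: '(' + '*'.join([b] * int(e)) + ')'),
    # Compound base (e.g. x + 1): bracket each factor
    (lambda b, e: _is_positive_int(e),
     lambda b, e: '(' + '*'.join(['(' + b + ')'] * int(e)) + ')'),
    # Negative, fractional or symbolic exponents: fall back to pow()
    (lambda b, e: True, 'pow'),
]}

cube = sp.ccode(x**3, user_functions=pow_user_functions)
square_of_sum = sp.ccode((x + 1)**2, user_functions=pow_user_functions)
in_denominator = sp.ccode(y / x**3, user_functions=pow_user_functions)
cube_root = sp.ccode(x**sp.Rational(1, 3), user_functions=pow_user_functions)
print(cube, square_of_sum, in_denominator, cube_root)
```

The outer brackets are kept deliberately: the Mul printer does not bracket a power (its precedence is higher than multiplication's), so without them y/x**3 would come out as y/x*x*x.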
Source: https://stackoverflow.com/questions/65534432/generate-c-code-with-sympy-replace-powx-2-by-xx