Question
I am trying to use numba to speed up a function that takes another function as an argument. A minimal example would be the following:
```python
import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))
```
This, however, doesn't work, as apparently numba doesn't know what to do with that function argument. The traceback is quite long:
```
Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>
```
Is there a way to fix this?
Answer 1:
It depends on whether the func you pass to call_func can be compiled in nopython mode.
If it can't be compiled in nopython mode, then it's impossible, because numba doesn't support Python calls inside a nopython function (that's why it's called nopython).
However, if it can be compiled in nopython mode, you can use a closure:
```python
import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)  # compile func in nopython mode!

    @nb.njit
    def inner(x):
        return func(x)

    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))
```
That approach has some obvious downsides, because it needs to compile func and inner every time you call call_func. That means it's only viable if the speedup from compiling the function is bigger than the compilation cost. You can mitigate that overhead if you call call_func with the same function several times:
```python
import numba as nb

def f(x):
    return x*x

def call_func(func):  # only take func
    func = nb.njit(func)  # compile func in nopython mode!

    @nb.njit
    def inner(x):
        return func(x)

    return inner  # return the closure

if __name__ == '__main__':
    call_func_with_f = call_func(f)  # compile once
    print(call_func_with_f(5))  # call the compiled version
    print(call_func_with_f(5))  # call the compiled version
    print(call_func_with_f(5))  # call the compiled version
    print(call_func_with_f(5))  # call the compiled version
    print(call_func_with_f(5))  # call the compiled version
```
Just a general note: I wouldn't create numba functions that take a function argument. If you can't hardcode the function, numba can't produce really fast code, and once you also include the compilation cost of the closures, it's mostly just not worth it.
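The "is it worth it" question can be made concrete. With a one-off compilation time C, a per-call time t_py in plain Python and t_jit once compiled, compiling pays off after n calls when C + n * t_jit < n * t_py, i.e. n > C / (t_py - t_jit). A small sketch of that arithmetic (break_even_calls is a hypothetical name, and the timings in the usage line are made-up placeholders; measure your own):

```python
import math


def break_even_calls(compile_s, t_python_s, t_jit_s):
    """Smallest n with compile_s + n * t_jit_s < n * t_python_s.

    Returns None if the compiled version is not faster per call,
    in which case compilation never pays off.
    """
    saving_per_call = t_python_s - t_jit_s
    if saving_per_call <= 0:
        return None
    # n must strictly exceed compile_s / saving_per_call
    return math.floor(compile_s / saving_per_call) + 1


# Hypothetical timings: 0.5 s to compile, 1 ms per Python call, 10 us per JIT call
print(break_even_calls(0.5, 1e-3, 1e-5))  # 506
```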
Answer 2:
As the error message suggests, Numba cannot deal with values of type function. You can check in the documentation which types Numba can work with. The reason is that Numba cannot, in general, optimize (JIT-compile) arbitrary functions in nopython mode; they are basically treated as a black box (in fact, the passed function could even be a native one!).
The usual approach is to ask Numba to optimize the called function instead. If you cannot add the decorator to the function (e.g. because it is not part of your source code), you can still apply it manually:
```python
import numba as nb

def f(x):
    return x*x

if __name__ == '__main__':
    f_opt = nb.jit(nopython=True)(f)
    print(f_opt(5))
```
Obviously it will still fail if f cannot be compiled by Numba either, but in that case there's not much you can do anyway.
Source: https://stackoverflow.com/questions/45976662/speed-up-function-that-takes-a-function-as-argument-with-numba