Speed up a function that takes a function as an argument with Numba

允我心安 posted on 2020-08-21 19:48:28

Question


I am trying to use Numba to speed up a function that takes another function as an argument. A minimal example would be the following:

import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))

This, however, doesn't work, because Numba apparently doesn't know what to do with the function argument. The traceback is quite long:

Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>

Is there a way to fix this?


Answer 1:


It depends on whether the func you pass to call_func can be compiled in nopython mode.

If it can't be compiled in nopython mode, then it's impossible, because Numba doesn't support calls to Python objects inside a nopython function (that's why the mode is called nopython).

However, if it can be compiled in nopython mode, you can use a closure:

import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))

That approach has an obvious downside: it needs to compile func and inner every time you call call_func. That means it's only viable if the speedup from compiling the function is bigger than the compilation cost. You can mitigate that overhead if you call call_func with the same function several times:

import numba as nb

def f(x):
    return x*x

def call_func(func):  # only take func
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner  # return the closure

if __name__ == '__main__':
    call_func_with_f = call_func(f)   # compile once
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
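
If you would rather keep the original call_func(func, x) signature, one option is to memoize the compiled closure per function, so the compilation cost is paid only once for each distinct func. The following is only a sketch: make_caller is a hypothetical helper name (not part of Numba), and the ImportError fallback is just an assumption so that the snippet still runs where Numba is not installed.

```python
from functools import lru_cache

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without Numba.
    def njit(func):
        return func

def f(x):
    return x * x

@lru_cache(maxsize=None)
def make_caller(func):
    """Hypothetical helper: build the compiled closure once per distinct func."""
    jitted = njit(func)  # compile func in nopython mode

    @njit
    def inner(x):
        return jitted(x)

    return inner

def call_func(func, x):
    # Repeated calls with the same func reuse the cached closure.
    return make_caller(func)(x)

result = call_func(f, 5)
same_wrapper = make_caller(f) is make_caller(f)
```

The cache is keyed on the function object itself, so passing a freshly created lambda on every call would still trigger a new compilation each time.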

Just a general note: I wouldn't create Numba functions that take a function argument. If you can't hardcode the function, Numba can't produce really fast code, and once you also include the compilation cost of the closures, it's mostly just not worth it.
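
To illustrate what "hardcoding" means here, a minimal sketch: the jitted caller refers to f by name instead of receiving it as a parameter. The name call_f is made up for this example, and the ImportError fallback is only an assumption so the snippet runs without Numba as well.

```python
try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without Numba.
    def njit(func):
        return func

@njit
def f(x):
    return x * x

@njit
def call_f(x):
    # f is referenced directly, so it is known when call_f is compiled.
    return f(x)

result = call_f(5)
```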




Answer 2:


As the error message suggests, Numba cannot deal with values of type function. You can check the documentation to see which types Numba can work with. The reason is that Numba cannot, in general, optimize (JIT-compile) arbitrary functions in nopython mode; they are treated essentially as a black box (in fact, the passed function could even be a native one!).

The usual approach would be to ask Numba to optimize the called function instead. If you cannot add the decorator to the function (e.g. because it is not part of your source code), you can still apply it manually:

import numba as nb

def f(x):
    return x*x

if __name__ == '__main__':
    f_opt = nb.jit(nopython=True)(f)
    print(f_opt(5))

Obviously, it will still fail if f cannot be compiled by Numba either, but in that case there's not much you can do anyway.
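
As a possible follow-up: the Numba documentation on supported Python features describes passing an already jitted function as an argument to another jitted function. Assuming your Numba version supports that, the question's call_func may work once f itself is compiled. This is a sketch rather than a guaranteed recipe, and the ImportError fallback is only there so the snippet runs without Numba.

```python
try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without Numba.
    def njit(func):
        return func

def f(x):
    return x * x

f_opt = njit(f)  # compile f manually, as above

@njit
def call_func(func, x):
    # func is a jitted function here, not a plain Python function.
    return func(x)

result = call_func(f_opt, 5)
```

Note that call_func would typically be specialized separately for each distinct jitted function passed in, so the compilation cost still applies per function.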



Source: https://stackoverflow.com/questions/45976662/speed-up-function-that-takes-a-function-as-argument-with-numba
