How do I use numba on a member function of a class?

Submitted by 牧云@^-^@ on 2020-02-12 09:01:09

Question


I'm using Numba 0.30.1, the current stable release.

I can do this:

import numba as nb

@nb.jit("void(f8[:])", nopython=True)
def complicated(x):
    for a in x:
        b = a**2. + a**3.

With this as a test case, the speedup is enormous. But I don't know how to proceed when the function I need to speed up is a method inside a class:

import numba as nb

class myClass(object):
    def __init__(self):
        self.k = 1

    #@nb.jit(???, nopython=True)
    def complicated(self, x):
        for a in x:
            b = a**2. + a**3. + self.k

What Numba type do I use for the self object? The function has to live inside the class because it needs to access a member variable.


Answer 1:


I was in a very similar situation and I found a way to use a Numba-JITed function inside of a class.

The trick is to use a static method: static methods are not called with the object instance prepended to the argument list, so Numba never has to type self. The downside of not having access to self is that the method cannot read instance attributes such as self.k, so you have to pass them in as arguments from a calling method that does have access to self. In my case I did not need to define a separate wrapper method; I just split the method I wanted to JIT-compile into two methods.

In the case of your example, the solution would be:

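import numpy as np
from numba import jit

class MyClass:
    def __init__(self):
        self.k = 1

    def calculation(self):
        # Read the attribute here, where self is available...
        k = self.k
        # ...and pass it to the compiled static method as a plain argument.
        return self.complicated(np.array([1., 2., 3.]), k)

    @staticmethod
    @jit(nopython=True)
    def complicated(x, k):
        for a in x:
            b = a**2. + a**3. + k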
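The loop above only assigns to b, so calling it has no visible effect. As a checkable sketch of the same static-method pattern, the version below accumulates into a hypothetical total and returns it. The k constructor parameter, the NumPy conversion and the no-Numba fallback decorator are additions for illustration (the fallback only exists so the sketch still runs where Numba is not installed).

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: no-op decorator so the sketch runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f


class MyClass:
    def __init__(self, k=1.0):
        self.k = k

    def calculation(self, x):
        # Read the attribute in ordinary Python, then hand it to the jitted code.
        return self.complicated(np.asarray(x, dtype=np.float64), self.k)

    @staticmethod
    @njit
    def complicated(x, k):
        # No self here: everything the kernel needs arrives as an argument.
        total = 0.0
        for a in x:
            total += a**2.0 + a**3.0 + k
        return total


obj = MyClass(k=1.0)
print(obj.calculation([1.0, 2.0, 3.0]))  # 3.0 + 13.0 + 37.0 = 53.0
```

Note that self.complicated(...) works because a staticmethod is looked up on the instance without binding self, which is exactly what keeps the object out of the compiled signature.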



Answer 2:


You have several options:

Use a jitclass (http://numba.pydata.org/numba-doc/0.30.1/user/jitclass.html) to "numba-ize" the whole thing.
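For the jitclass route, a minimal sketch might look like the following. It assumes a recent Numba, where jitclass lives in numba.experimental; older releases such as 0.30.1 exposed it as numba.jitclass, which the import tries next. Every jitclass attribute needs a declared Numba type, so k is given float64 in the spec. The method body, which accumulates and returns a hypothetical total, is adapted from the question's loop, and the plain-class fallback only exists so the sketch runs without Numba.

```python
import numpy as np

try:
    from numba import float64
    try:
        from numba.experimental import jitclass  # newer Numba releases
    except ImportError:
        from numba import jitclass  # older releases such as 0.30.1
except ImportError:  # no Numba: fall back to an ordinary Python class
    float64 = None

    def jitclass(spec):
        return lambda cls: cls

spec = [("k", float64)]  # each attribute of a jitclass needs a Numba type


@jitclass(spec)
class MyClass:
    def __init__(self, k):
        self.k = k

    def complicated(self, x):
        # Methods of a jitclass are compiled and can read self.k directly.
        total = 0.0
        for a in x:
            total += a**2.0 + a**3.0 + self.k
        return total


obj = MyClass(1.0)
print(obj.complicated(np.array([1.0, 2.0, 3.0])))  # 53.0
```

The trade-off is that the whole class must be typeable by Numba, whereas the wrapper approach keeps the class as plain Python.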

Or make the member function a wrapper and pass the member variables through:

import numba as nb

@nb.jit
def _complicated(x, k):
    for a in x:
        b = a**2. + a**3. + k

class myClass(object):
    def __init__(self):
        self.k = 1

    # Thin wrapper: pull the attribute off self and forward it.
    def complicated(self, x):
        _complicated(x, self.k)


Source: https://stackoverflow.com/questions/41769100/how-do-i-use-numba-on-a-member-function-of-a-class
