Numba jit and deferred types

℡╲_俬逩灬. submitted on 2020-08-25 03:53:22

Question


I'm passing numba the following signature for my function:

@numba.jit(numba.types.UniTuple(numba.float64[:, :], 2)(
    numba.float64[:, :], numba.float64[:, :], numba.float64[:, :],
    earth_model_type))

where earth_model_type is defined as

earth_model_type = numba.deferred_type()
earth_model_type.define(em.EarthModel.class_type.instance_type)

and it compiles fine, but when I try to call the function I get

*** TypeError: No matching definition for argument type(s) array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F), instance.jitclass.EarthModel#7fd9c48dd668

The argument types in the error message look essentially the same as the types in my signature above. On the other hand, if I don't specify a signature and just use @numba.jit(nopython=True), it works fine, and the signature of the function compiled by numba is:

ipdb> numbed_cowell_propagator_propagate.signatures

[(array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F), instance.jitclass.EarthModel#7f81bbc0e780)]

EDIT

If I enforce C-order arrays as described in the FAQ, I still get an error:

TypeError: No matching definition for argument type(s) array(float64, 2d, C), array(float64, 2d, C), array(float64, 2d, C), instance.jitclass.EarthModel#7f6edd8d57b8

I'm pretty sure the problem is related to the deferred type: if, instead of passing the jitclass instance, I pass the attributes I need from it (four numba.float64 values), it works fine.

What am I doing wrong when I specify the signature?

Cheers.


Answer 1:


Without seeing your full code, I'm not sure why you need a deferred type at all. Deferred types are typically used for jitclasses that have an instance variable of the same type as the class itself, such as a linked list or some other tree of nodes, so the type of that member has to be deferred until after the compiler has processed the class (see the source). The following minimal example works without one (and I can reproduce your error if I use a deferred type):

import numpy as np
import numba as nb
from numba.experimental import jitclass  # numba >= 0.49; older releases used nb.jitclass

spec = [('x', nb.float64)]

@jitclass(spec)
class EarthModel:
    def __init__(self, x):
        self.x = x

earth_model_type = EarthModel.class_type.instance_type

@nb.jit(nb.float64(nb.float64[:, :], nb.float64[:, :], nb.float64[:, :], earth_model_type))
def test(x, y, z, em):
    return em.x

and then running it:

em = EarthModel(9.9)
x = np.random.normal(size=(3,3))
y = np.random.normal(size=(3,3))
z = np.random.normal(size=(3,3))

res = test(x, y, z, em)
print(res)  # 9.9


Source: https://stackoverflow.com/questions/57640039/numba-jit-and-deferred-types
