How to pass array pointer to Numba function?

Submitted by 余生颓废 on 2020-05-16 22:05:16

Question


I'd like to create a Numba-compiled function that takes a pointer or the memory address of an array as an argument and does calculations on it, e.g., modifies the underlying data.

The following pure-Python version illustrates the idea:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array


def modify_data(addr):
    """ a function taking the memory address of an array to modify it """
    ptr = ctypes.c_void_p(addr)
    data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
    data += 2

addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])

As you can see in the example, the array arr got modified without passing it to the function explicitly. In my use case, the shape and dtype of the array are known and will remain unchanged at all times, which should simplify the interface.
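For comparison, the same address-to-array step can be sketched without Numba, using only ctypes and NumPy. The helper name view_from_address is hypothetical and not part of the original question; the sketch assumes a C-contiguous float64 array, as in the example above.

```python
import ctypes
import numpy as np

arr = np.arange(5, dtype=np.float64)  # same data as in the question
addr = arr.ctypes.data                # memory address as a plain Python int


def view_from_address(addr, shape):
    """Hypothetical helper: wrap the memory at `addr` as a float64 array view."""
    # Interpret the integer address as a pointer to C doubles.
    ptr = ctypes.cast(addr, ctypes.POINTER(ctypes.c_double))
    # Build an array that shares memory with the pointed-to buffer (no copy).
    return np.ctypeslib.as_array(ptr, shape=shape)


view = view_from_address(addr, arr.shape)
view += 2   # modifies arr in place through the shared memory
print(arr)  # [2. 3. 4. 5. 6.]
```

The view does not own its memory, so arr has to stay alive for as long as the view is in use.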

1. Attempt: Naive jitting

I then tried to compile the modify_data function with Numba, but failed. My first attempt was:

shape = arr.shape
dtype = arr.dtype

@nb.njit
def modify_data_nb(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2


ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr)   # <<< error

This failed with "cannot determine Numba type of <class 'ctypes.c_void_p'>", i.e., Numba does not know how to interpret the pointer.

2. Attempt: Explicit types

Next, I tried specifying explicit types:

arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape

@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(ptr, shape)
    data += 2

This did not help either. The function compiles without errors, but I do not know how to call modify_data_nb. I tried the following options:

modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64

ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

Is there a way to obtain the correct pointer format from arr so that I can pass it to the Numba-compiled modify_data_nb function? Alternatively, is there another way of passing the memory location to the function?

3. Attempt: Using scipy.LowLevelCallable

I made some progress by using scipy.LowLevelCallable and its magic:

from scipy import LowLevelCallable  # in addition to the imports above

arr = np.arange(3).astype(np.double)
print(arr)
# >>> [0. 1. 2.]

# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype

@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2

modify_data_llc = LowLevelCallable(modify_data.ctypes).function

# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# call the function only with the pointer
modify_data_llc(ptr)

# check whether array got modified
print(arr)
# >>> [2. 3. 4.]

I can now call a function to access the array, but this function is no longer a Numba function. In particular, it cannot be used in other Numba functions.


Answer 1:


Thanks to the great @stuartarchibald, I now have a working solution:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array
print(arr)

@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

addr = arr.ctypes.data

@nb.njit
def modify_data():
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
    data += 2

modify_data()
print(arr)

The key is the new address_as_void_pointer function that turns a memory address (given as an int) into a pointer that is usable by numba.carray.



Source: https://stackoverflow.com/questions/61509903/how-to-pass-array-pointer-to-numba-function
