Optimizing dict of set of tuple of ints with Numba?

Anonymous (unverified), submitted 2019-12-03 03:10:03

Question:

I am learning how to use Numba (while I am already fairly familiar with Cython). How should I go about speeding up this code? Notice the function returns a dict of sets of two-tuples of ints. I am using IPython notebook. I would prefer Numba over Cython.

    @autojit
    def generateadj(width,height):
        adj = {}
        for y in range(height):
            for x in range(width):
                s = set()
                if x>0:
                    s.add((x-1,y))
                if x<width-1:
                    s.add((x+1,y))
                if y>0:
                    s.add((x,y-1))
                if y<height-1:
                    s.add((x,y+1))
                adj[x,y] = s
        return adj

I managed to write this in Cython, but I had to give up the original data structure, which I do not like. I read somewhere in the Numba documentation that it can work with basic things like lists, tuples, etc.

    %%cython
    import numpy as np

    def generateadj(int width, int height):
        cdef int[:,:,:,:] adj = np.zeros((width,height,4,2), np.int32)
        cdef int count

        for y in range(height):
            for x in range(width):
                count = 0
                if x>0:
                    adj[x,y,count,0] = x-1
                    adj[x,y,count,1] = y
                    count += 1
                if x<width-1:
                    adj[x,y,count,0] = x+1
                    adj[x,y,count,1] = y
                    count += 1
                if y>0:
                    adj[x,y,count,0] = x
                    adj[x,y,count,1] = y-1
                    count += 1
                if y<height-1:
                    adj[x,y,count,0] = x
                    adj[x,y,count,1] = y+1
                    count += 1
                for i in range(count,4):
                    adj[x,y,i] = adj[x,y,0]
        return adj

Answer 1:

While numba supports such Python data structures as dicts and sets, it does so in object mode. From the numba glossary, object mode is defined as:

A Numba compilation mode that generates code that handles all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting.

So when writing numba code, you need to stick to data types that numba can compile natively, such as NumPy arrays. Here's some code that does just that:

    from numba import jit

    @jit
    def gen_adj_loop(width, height, adj):
        # Fill adj in place; i is the index of the next free row.
        i = 0
        for x in range(width):
            for y in range(height):
                if x > 0:
                    adj[i,0] = x
                    adj[i,1] = y
                    adj[i,2] = x - 1
                    adj[i,3] = y
                    i += 1

                if x < width - 1:
                    adj[i,0] = x
                    adj[i,1] = y
                    adj[i,2] = x + 1
                    adj[i,3] = y
                    i += 1

                if y > 0:
                    adj[i,0] = x
                    adj[i,1] = y
                    adj[i,2] = x
                    adj[i,3] = y - 1
                    i += 1

                if y < height - 1:
                    adj[i,0] = x
                    adj[i,1] = y
                    adj[i,2] = x
                    adj[i,3] = y + 1
                    i += 1
        return

This takes an array adj. Each row has the form x y adj_x adj_y. So for the pixel at (3,4), we'd have the four rows:

    3 4 2 4
    3 4 4 4
    3 4 3 3
    3 4 3 5

We can wrap the above function in another:

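    import numpy as np

    @jit
    def gen_adj(width, height):
        # each pixel has four neighbors, but some of these neighbors are
        # off the grid (2*width + 2*height of them, to be exact)
        n_entries = width*height*4 - 2*width - 2*height
        # explicit 64-bit integer dtype for the (x, y, adj_x, adj_y) rows
        adj = np.zeros((n_entries, 4), dtype=np.int64)
        gen_adj_loop(width, height, adj)
        # return the filled array so the caller can use it
        return adj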

This function is very fast, but incomplete. We must convert adj to a dictionary of the form in your question. The problem is that this is a very slow process. We must iterate over the adj array and add each entry to a Python dictionary. This cannot be jitted by numba.
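For reference, the conversion step described above might look like the sketch below. It is plain Python, not part of the original answer, and the names rows_to_adjacency and reference_rows are hypothetical. reference_rows is only a pure-Python stand-in for the filled adj array, so the example runs without numba; in practice you would pass the array filled by gen_adj_loop, since any iterable of (x, y, adj_x, adj_y) rows should work.

```python
from collections import defaultdict


def rows_to_adjacency(rows):
    """Group (x, y, adj_x, adj_y) rows into {(x, y): {(adj_x, adj_y), ...}}."""
    adj = defaultdict(set)
    for x, y, ax, ay in rows:
        # int() turns NumPy integer scalars into plain Python ints
        adj[int(x), int(y)].add((int(ax), int(ay)))
    return dict(adj)


def reference_rows(width, height):
    """Hypothetical stand-in for the filled adj array (assumption, no numba)."""
    rows = []
    for x in range(width):
        for y in range(height):
            # same neighbor order as gen_adj_loop: left, right, up, down
            for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
                if 0 <= nx < width and 0 <= ny < height:
                    rows.append((x, y, nx, ny))
    return rows


rows = reference_rows(5, 6)
adjacency = rows_to_adjacency(rows)
```

Every step of this loop creates Python tuples and sets, which is where the time goes. Note one difference from the code in the question: a pixel with no neighbors at all (only possible on a 1x1 grid) never appears in a row, so it would be missing from the resulting dictionary instead of mapping to an empty set.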

So the bottom line is this: the requirement that the result be a dictionary of sets of tuples really constrains how much you can optimize this code.


