Equivalent for np.add.at in tensorflow

一整个雨季 2021-01-21 07:10

How do I convert an np.add.at statement to TensorFlow?

np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))

2 Answers
  • 2021-01-21 07:54

    For np.add.at, you probably want to look at tf.SparseTensor, which represents a tensor by a list of values and a list of indices (which is more suitable for sparse data, hence the name).

    So for your example:

    np.add.at(dW, self.x.ravel(), dout.reshape(-1, self.D))
    

    that would be (assuming dW, x and dout are tensors):

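    tf.sparse_add(dW, tf.SparseTensor(x, tf.reshape(dout, [-1]), tf.shape(dW, out_type=tf.int64)))
    

    This assumes x has shape [n, nDims] and dtype int64 (i.e. x is a 'list' of n indices, each of dimension nDims), and dout has shape [n]. tf.SparseTensor also takes the dense shape as its third argument, here taken from dW. In TensorFlow 2.x the function is named tf.sparse.add.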
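As a separate sketch, not part of the answer above: for the shapes in the question (x holding row indices into dW, dout reshaped to [-1, D]), tf.tensor_scatter_nd_add may be a closer match, since it sums updates that hit the same index. The sizes V and D and the sample values below are made up for illustration, and the code assumes TensorFlow 2.x with eager execution.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for illustration: dW is [V, D], x holds row indices.
V, D = 5, 3
x = tf.constant([[0, 2], [2, 4]], dtype=tf.int32)   # note the repeated index 2
dout = tf.reshape(tf.range(12, dtype=tf.float32), [2, 2, D])
dW = tf.zeros([V, D], dtype=tf.float32)

# Each index is a length-1 vector because only the first axis is indexed.
indices = tf.reshape(x, [-1, 1])      # shape [4, 1]
updates = tf.reshape(dout, [-1, D])   # shape [4, D]

# Returns a new tensor (dW itself is not modified); repeated indices are summed.
dW_new = tf.tensor_scatter_nd_add(dW, indices, updates)

# NumPy reference using the statement from the question.
expected = np.zeros((V, D), dtype=np.float32)
np.add.at(expected, x.numpy().ravel(), dout.numpy().reshape(-1, D))
```

Unlike np.add.at, this does not update dW in place, so the result has to be assigned back (or applied to a tf.Variable) by the caller.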

  • 2021-01-21 08:05

    Here's an example of what np.add.at does:

    In [324]: a=np.ones((10,))
    In [325]: x=np.array([1,2,3,1,4,5])
    In [326]: b=np.array([1,1,1,1,1,1])
    In [327]: np.add.at(a,x,b)
    In [328]: a
    Out[328]: array([ 1.,  3.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.])
    

    If instead I use +=

    In [331]: a1=np.ones((10,))
    In [332]: a1[x]+=b
    In [333]: a1
    Out[333]: array([ 1.,  2.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.])
    

    Note that a1[1] is 2, not 3: the index 1 appears twice in x, but += applies only one of the two updates.

    If instead I use an iterative solution

    In [334]: a2=np.ones((10,))
    In [335]: for i,j in zip(x,b):
         ...:     a2[i]+=j
         ...:     
    In [336]: a2
    Out[336]: array([ 1.,  3.,  2.,  2.,  2.,  2.,  1.,  1.,  1.,  1.])
    

    it matches.

    If x has no duplicates, += works just fine. With duplicates, np.add.at is required to match the iterative solution.
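The same comparison can be sketched for the 2-D statement from the question, where whole rows of dout are accumulated into rows of dW. The sizes V and D and the sample values are assumptions chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes for illustration.
V, D = 5, 3                       # rows in dW, width of each row
x = np.array([[0, 2], [2, 4]])    # index array with a repeated index (2)
dout = np.arange(x.size * D, dtype=np.float64).reshape(x.shape + (D,))

# Unbuffered accumulation, as in the question.
dW = np.zeros((V, D))
np.add.at(dW, x.ravel(), dout.reshape(-1, D))

# Explicit loop: one row added per index, duplicates included.
dW_loop = np.zeros((V, D))
for i, row in zip(x.ravel(), dout.reshape(-1, D)):
    dW_loop[i] += row

# Fancy-indexed +=: only one contribution survives for the repeated index.
dW_buffered = np.zeros((V, D))
dW_buffered[x.ravel()] += dout.reshape(-1, D)

print(np.array_equal(dW, dW_loop))      # add.at agrees with the loop
print(np.array_equal(dW, dW_buffered))  # += does not, because of the duplicate
```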
