How can I assign/update subset of tensor shared variable in Theano?

二次信任 submitted on 2019-12-17 23:03:14

Question


When compiling a function in Theano, a shared variable (say X) can be updated by specifying updates=[(X, new_value)]. Now I am trying to update only a subset of a shared variable:

from theano import tensor as T
from theano import function, shared
import numpy

X = shared(numpy.array([0,1,2,3,4]))
Y = T.vector()
# raises TypeError: 'update target must be a SharedVariable'
f = function([Y], updates=[(X[2:4], Y)])

This raises the error "update target must be a SharedVariable". I guess that means an update target can't be a non-shared variable such as the slice X[2:4]. So is there any way to compile a function that updates just a subset of a shared variable?
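To see why the slice is rejected, here is a toy pure-Python model of the updates contract. It is an illustrative assumption, not Theano's implementation: each update pairs the shared container itself with a complete new value, and a slice of the current value is just a derived array, not a container. `ToyShared` and `apply_updates` are hypothetical names.

```python
import numpy as np

class ToyShared:
    """Hypothetical stand-in for a shared variable: a container whose whole
    value can be swapped out. Not Theano's real SharedVariable class."""
    def __init__(self, value):
        self.value = np.asarray(value)

def apply_updates(updates):
    """Toy model of updates=[(target, new_value_fn), ...].

    Every target must be the container itself. All new values are computed
    from the current state first, then every container is overwritten.
    """
    for target, _ in updates:
        if not isinstance(target, ToyShared):
            raise TypeError("update target must be a ToyShared")
    new_values = [fn() for _, fn in updates]
    for (target, _), value in zip(updates, new_values):
        target.value = value

X = ToyShared([0, 1, 2, 3, 4])

# A slice of the current value is a plain array, not the container,
# so it is rejected, mirroring "update target must be a SharedVariable".
try:
    apply_updates([(X.value[2:4], lambda: np.array([100, 10]))])
except TypeError as err:
    print(err)  # update target must be a ToyShared

# The accepted form: target the container and supply a full-size new value.
def full_new_value():
    new = X.value.copy()
    new[2:4] = [100, 10]
    return new

apply_updates([(X, full_new_value)])
print(X.value.tolist())  # [0, 1, 100, 10, 4]
```

The second call is the shape of the answers below: the target stays the shared variable, and the partial change is expressed inside the new value.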


Answer 1:


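Use set_subtensor or inc_subtensor. Neither modifies X in place; each returns a new tensor with the same shape as X in which the slice is overwritten (set_subtensor) or incremented (inc_subtensor), and that full-size tensor can be used as the update value for X: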

from theano import tensor as T
from theano import function, shared
import numpy

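X = shared(numpy.array([0,1,2,3,4], dtype='int64'))
Y = T.lvector()  # int64 input, matching X's dtype
# set_subtensor returns a new full-size tensor: X with X[2:4] replaced by Y
X_update = (X, T.set_subtensor(X[2:4], Y))
f = function([Y], updates=[X_update])
f([100,10])
print(X.get_value()) # [0 1 100 10 4]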

There's now a page about this in the Theano FAQ: http://deeplearning.net/software/theano/tutorial/faq_tutorial.html
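As a rough mental model, the value these ops compute can be mimicked in plain NumPy. This is a sketch of the semantics only: `set_subtensor_np` and `inc_subtensor_np` are hypothetical helpers taking `(x, index, y)`, whereas Theano's own calls take the symbolic slice directly, as in `T.set_subtensor(X[2:4], Y)`.

```python
import numpy as np

def set_subtensor_np(x, index, y):
    """NumPy sketch of the value of T.set_subtensor(x[index], y):
    a new full-size array equal to x except that x[index] becomes y.
    The input array is left untouched."""
    out = np.array(x, copy=True)
    out[index] = y
    return out

def inc_subtensor_np(x, index, y):
    """NumPy sketch of the value of T.inc_subtensor(x[index], y):
    like set_subtensor_np, but y is added to x[index] instead."""
    out = np.array(x, copy=True)
    out[index] += y
    return out

X = np.array([0, 1, 2, 3, 4])
print(set_subtensor_np(X, slice(2, 4), [100, 10]).tolist())  # [0, 1, 100, 10, 4]
print(inc_subtensor_np(X, slice(2, 4), [100, 10]).tolist())  # [0, 1, 102, 13, 4]
print(X.tolist())  # [0, 1, 2, 3, 4] (input unchanged)
```

Because the result is a fresh full-size value rather than a mutation, it fits the `(shared_variable, new_value)` form that `updates` expects.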




Answer 2:


This code should solve your problem:

from theano import tensor as T
from theano import function, shared
import numpy

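X = shared(numpy.array([0,1,2,3,4], dtype='int64'))
Y = T.lvector()  # int64 input, matching X's dtype
# The update value X[2:4] + Y is a 2-element vector, so X is replaced
# by it entirely (X goes from 5 elements to 2)
X_update = (X, X[2:4]+Y)
f = function(inputs=[Y], updates=[X_update])
f([100,10])
print(X.get_value())
# output: [102 13]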
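The shape difference between this update and an inc_subtensor-style update can be checked with plain NumPy. This sketch only mimics the values each update would assign to X under the same inputs; it does not run Theano.

```python
import numpy as np

X = np.array([0, 1, 2, 3, 4], dtype=np.int64)
Y = np.array([100, 10], dtype=np.int64)

# Value assigned by (X, X[2:4] + Y): only the slice sum, so X shrinks.
replaced = X[2:4] + Y

# Value an inc_subtensor-style update would assign: full size, slice incremented.
incremented = X.copy()
incremented[2:4] += Y

print(replaced.tolist(), replaced.shape)        # [102, 13] (2,)
print(incremented.tolist(), incremented.shape)  # [0, 1, 102, 13, 4] (5,)
```

If the goal is to keep the other elements of X, the full-size form is the one to use as the update value.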

See also the introduction to shared variables in the official Theano tutorial.

Please ask if you have further questions!



Source: https://stackoverflow.com/questions/15917849/how-can-i-assign-update-subset-of-tensor-shared-variable-in-theano
