Avoiding duplicating graph in tensorflow (LSTM model)

青春壹個敷衍的年華 submitted on 2019-12-24 10:39:58

Question


I have the following simplified code (the real code is an unrolled LSTM model):

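import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)

def func(a, b):
    # Every call creates a new Add op under a freshly uniquified scope name.
    with tf.variable_scope('name'):
        res = tf.add(a, b)
    print(res.name)
    return res

func(tf.constant(10), tf.constant(20))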

Whenever I run the last line, it seems to change the graph, but I don't want the graph to change. My actual code is a neural network model that is too large to post, so I've reduced it to the code above. I want to call func without changing the model's graph, but it does change. I have read about variable scope in TensorFlow, but it seems I have not understood it at all.
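The growth described above can be observed directly by counting the operations in the graph after each call. This is only a minimal sketch, assuming a TensorFlow version that exposes the tf.compat.v1 aliases (late 1.x or 2.x); the names op_counts and names are illustrative and not part of the original code.

```python
import tensorflow as tf

def func(a, b):
    # Same structure as the function in the question, via the compat.v1 alias.
    with tf.compat.v1.variable_scope('name'):
        return tf.add(a, b)

graph = tf.Graph()
with graph.as_default():
    op_counts = [len(graph.get_operations())]  # op count before any call
    names = []
    for _ in range(3):
        res = func(tf.constant(10), tf.constant(20))
        names.append(res.name)
        op_counts.append(len(graph.get_operations()))

print(op_counts)  # strictly increasing: every call adds new ops
print(names)      # each result tensor gets a distinct, uniquified name
```

Entering the variable scope does not reuse existing ops: each call builds fresh constant and add ops, and TensorFlow uniquifies the scope name so that the new ops do not collide with the old ones.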


Answer 1:


You should take a look at the source code of tf.nn.dynamic_rnn, specifically the _dynamic_rnn_loop function in python/ops/rnn.py, which solves the same problem. To avoid blowing up the graph, it uses tf.while_loop to reuse the same graph ops for new data. This approach adds some restrictions, however: the shapes of the tensors passed through the loop must stay invariant across iterations unless they are relaxed explicitly with shape_invariants. See the example from the tf.while_loop documentation:

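import tensorflow as tf

i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10  # loop condition
# Loop body: m doubles along axis 0 on every iteration, so its shape changes.
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
# shape_invariants relaxes m's first dimension to None so the growth is allowed.
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])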
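To illustrate why tf.while_loop keeps the graph from growing, the sketch below builds the same loop with two different iteration counts and compares the number of ops in each graph. It assumes a TensorFlow version with the tf.compat.v1 aliases (late 1.x or 2.x); build_and_run is a hypothetical helper written for this illustration, not something from TensorFlow or from the question.

```python
import tensorflow as tf

def build_and_run(num_steps):
    """Hypothetical helper: sum 0..num_steps-1 in a while_loop.

    Returns (number of ops in the graph, computed sum).
    """
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(0)
        acc0 = tf.constant(0)
        cond = lambda i, acc: i < num_steps
        body = lambda i, acc: [i + 1, acc + i]
        _, total = tf.while_loop(cond, body, loop_vars=[i0, acc0])
        n_ops = len(graph.get_operations())
        with tf.compat.v1.Session(graph=graph) as sess:
            value = sess.run(total)
    return n_ops, value

ops_10, value_10 = build_and_run(10)
ops_1000, value_1000 = build_and_run(1000)

# The loop body is traced once, so the op count does not depend on
# how many iterations the loop executes at run time.
print(ops_10, ops_1000)
print(value_10, value_1000)
```

An unrolled Python loop would instead add a new copy of the body's ops for every step, which is the graph growth the question describes.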


Source: https://stackoverflow.com/questions/49114306/avoiding-duplicating-graph-in-tensorflow-lstm-model
