How to use tensorflow's Dataset API Iterator as an input of a (recurrent) neural network?

闹比i 2021-01-03 12:50

When using TensorFlow's Dataset API Iterator, my goal is to define an RNN that operates on the iterator's get_next() tensors as its input (see (1)).

1 Answer
  • 2021-01-03 13:46

    It turns out the misleading error message is most likely a bug in TensorFlow; see https://github.com/tensorflow/tensorflow/issues/14729. The actual cause is feeding the wrong data type: in my example above, the data array contains int32 values, but it should contain floats.

    Instead of raising ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct,
    TensorFlow should report:
    TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, float32] that don't all match. (see 1).
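Because the real problem is only a dtype mismatch, it can help to fail fast before the data ever reaches TensorFlow. The following is a minimal sketch, assuming the input is anything NumPy can convert; `require_float32` is a hypothetical helper, not part of TensorFlow's API. Note that NumPy infers a platform integer dtype for nested Python ints, while TensorFlow converts the same list to int32; either way it is not float32.

```python
import numpy as np


def require_float32(data):
    """Convert `data` to an array and reject anything that is not float32.

    Hypothetical helper for illustration; TensorFlow does not provide it.
    """
    arr = np.asarray(data)
    if arr.dtype != np.float32:
        # Clear message in the spirit of the TypeError TensorFlow should raise.
        raise TypeError(
            "RNN input must be float32 to match the LSTM weights, got %s" % arr.dtype
        )
    return arr


try:
    # Nested Python ints are inferred as an integer dtype, so this is rejected.
    require_float32([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]])
except TypeError as err:
    print(err)
```

Calling the helper on the corrected `np.array(..., dtype=np.float32)` simply returns the array unchanged.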

    To fix this problem, simply change
    data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
    to
    data = np.array([[ [1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ], dtype=np.float32)

    and the following code then works as intended:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x (uses tf.Session and tf.contrib)
    
    from tensorflow.contrib.rnn import BasicLSTMCell
    
    # Three sequences of length 3 with one feature each; float32 matches the LSTM weights
    data = np.array([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]], dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.batch(2)
    
    # Reinitializable iterator; get_next() yields one (batch, 3, 1) tensor per run
    iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                               dataset.output_shapes)
    next_batch = iterator.get_next()
    iterator_init = iterator.make_initializer(dataset)
    
    # (2):
    # X = tf.placeholder(tf.float32, shape=(None, 3, 1))
    
    cell = BasicLSTMCell(num_units=8)
    
    # (1): feed the iterator's output tensor directly into the RNN
    outputs, states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)
    
    # (2): feed the placeholder instead
    # outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
    
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(iterator_init)
    
        # (1): each run pulls the next batch from the iterator
        o, s = sess.run([outputs, states])
        o, s = sess.run([outputs, states])
    
        # (2): evaluate the next batch, then feed it to the placeholder
        # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
        # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
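Why exactly two `sess.run` calls under (1)? `dataset.batch(2)` over three sequences yields one batch of 2 and a final, smaller batch of 1, because by default the partial last batch is kept; a third call would raise `tf.errors.OutOfRangeError`. With `num_units=8`, `o` therefore has shape (2, 3, 8) on the first call and (1, 3, 8) on the second. The NumPy sketch below only imitates that slicing to show the input shapes `dynamic_rnn` receives; `split_into_batches` is a hypothetical name, not a TensorFlow function.

```python
import numpy as np


def split_into_batches(data, batch_size):
    """Slice the leading axis into consecutive batches, keeping a short final batch.

    Hypothetical helper that mimics the default behaviour of Dataset.batch.
    """
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


data = np.array([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]], dtype=np.float32)
batches = split_into_batches(data, 2)

for b in batches:
    # Each batch is (batch, time, features); dynamic_rnn maps it to (batch, time, num_units).
    print(b.shape)
# (2, 3, 1)
# (1, 3, 1)
```

Since the placeholder in (2) and the iterator's output shape both leave the batch dimension as `None`, the RNN graph accepts the smaller final batch without changes.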
    