When using TensorFlow's Dataset API Iterator, my goal is to define an RNN that operates on the iterator's get_next() tensors as its input (see (1) in the code below).
It turns out the mysterious error is likely a bug in TensorFlow; see https://github.com/tensorflow/tensorflow/issues/14729. More specifically, the error actually comes from feeding the wrong data type: in my example above, the data array contains int32 values, but it should contain floats.
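To see where the int32 comes from, the short NumPy check below (an illustration, not TensorFlow's own conversion path) inspects the dtype that a nested list of Python ints ends up with. NumPy picks a platform-dependent integer width, so the sketch only checks that the kind is "signed integer"; TensorFlow converts the same list to a tf.int32 tensor, which matches the int32 in the error. Passing dtype=np.float32 explicitly is the fix used below.

```python
import numpy as np

# The nested list from the question: every leaf is a Python int.
raw = [[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]]

# Without an explicit dtype, NumPy infers a signed integer type
# (int64 on most Linux/macOS builds, int32 on some Windows builds).
inferred = np.array(raw)
print(inferred.dtype.kind)  # 'i' means signed integer

# Requesting float32 up front gives the dtype the LSTM expects.
fixed = np.asarray(raw, dtype=np.float32)
print(fixed.dtype, fixed.shape)  # float32 (3, 3, 1)
```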
Instead of raising the error
ValueError: Initializer for variable rnn/basic_lstm_cell/kernel/ is from inside a control-flow construct
TensorFlow should report:
TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [int32, float32] that don't all match.
(see the linked issue).
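The TypeError mentions ConcatV2 because an LSTM cell concatenates each input step with its previous hidden state (which is float32) before multiplying by the kernel, so an int32 input cannot be joined with it. The NumPy sketch below approximates one step of that computation; lstm_step is a hypothetical helper, and the gate split (i, j, f, o) with forget_bias=1.0 is my understanding of BasicLSTMCell's defaults rather than its actual source. Since NumPy would silently upcast mixed dtypes, the sketch adds an explicit check to mimic ConcatV2's requirement.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h, c, kernel, bias, forget_bias=1.0):
    """One LSTM step in the style of BasicLSTMCell (hypothetical sketch)."""
    # ConcatV2 requires all inputs to share one dtype; NumPy would upcast
    # silently, so check explicitly to mimic TensorFlow's TypeError.
    if x.dtype != h.dtype:
        raise TypeError(f"concat inputs have types [{x.dtype}, {h.dtype}] that don't all match")
    gates = np.concatenate([x, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(gates, 4, axis=1)  # input, candidate, forget, output
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c


num_units, batch, depth = 8, 2, 1
rng = np.random.default_rng(0)
kernel = rng.standard_normal((depth + num_units, 4 * num_units)).astype(np.float32)
bias = np.zeros(4 * num_units, dtype=np.float32)
h = np.zeros((batch, num_units), dtype=np.float32)
c = np.zeros_like(h)

# First time step of two sequences, stored as int32 like the original data.
x_int = np.array([[1], [4]], dtype=np.int32)
raised = False
try:
    lstm_step(x_int, h, c, kernel, bias)
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Casting the input to float32 lets the concatenation go through.
new_h, new_c = lstm_step(x_int.astype(np.float32), h, c, kernel, bias)
print(new_h.shape, new_h.dtype)  # (2, 8) float32
```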
To fix this problem, simply change
data = [ [[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]] ]
to
data = np.array([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]], dtype=np.float32)
after which the following code works properly:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell

# float32 input: 3 sequences, each with 3 time steps of 1 feature
data = np.array([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]], dtype=np.float32)

dataset = tf.data.Dataset.from_tensor_slices(data)
dataset = dataset.batch(2)

iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                           dataset.output_shapes)
next_batch = iterator.get_next()
iterator_init = iterator.make_initializer(dataset)

# (2): alternative, feed batches through a placeholder instead
# X = tf.placeholder(tf.float32, shape=(None, 3, 1))

cell = BasicLSTMCell(num_units=8)

# (1): use the iterator's get_next() tensor directly as the RNN input
outputs, states = tf.nn.dynamic_rnn(cell, next_batch, dtype=tf.float32)
# (2):
# outputs, states = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    sess.run(iterator_init)
    # (1): each run pulls the next batch from the iterator
    o, s = sess.run([outputs, states])  # batch of 2 sequences
    o, s = sess.run([outputs, states])  # remaining batch of 1 sequence
    # (2):
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
    # o, s = sess.run([outputs, states], feed_dict={X: next_batch.eval()})
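It may help to see what next_batch evaluates to on each sess.run call. The NumPy sketch below mimics Dataset.from_tensor_slices(data).batch(2) with a hypothetical helper, batched_slices (not a TensorFlow function): slicing along the first axis yields one sequence per element, and batching groups them two at a time, keeping a smaller final batch. That is why the session can run exactly twice, why the placeholder in (2) uses None for the batch dimension, and why a third sess.run on the exhausted iterator would raise tf.errors.OutOfRangeError.

```python
import numpy as np


def batched_slices(data, batch_size):
    """Yield consecutive slices along axis 0, like from_tensor_slices(...).batch(...)."""
    # The last batch may be smaller than batch_size (nothing is dropped).
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


data = np.array([[[1], [2], [3]], [[4], [5], [6]], [[1], [2], [3]]], dtype=np.float32)
shapes = [batch.shape for batch in batched_slices(data, 2)]
print(shapes)  # [(2, 3, 1), (1, 3, 1)]
```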