Question:
I understand that there are advantages to using TensorFlow's new `Dataset` as the idiom for my data-feeding pipeline, especially as I expand the scope of the models I build and the size of the datasets they work on. However, I'm having trouble mapping my existing `feed_dict`-based code to this new model.
One problem I face is that I can't sort out how batching and epochs interact, or how these interleave with the logging and validation that I often do.
For example, how would something like the following map to using `Dataset`?
# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output,
# where N is the number of examples and C_ is the number of channels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)
# Feeds used to evaluate statistics/summaries over a whole dataset (is_train disabled).
train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}

with tf.Session(config=tf.ConfigProto(...)) as sess:
    # tf.initialize_all_variables() is deprecated; this is its drop-in replacement.
    sess.run(tf.global_variables_initializer())

    # Some analysis; not always done but the code needs to support it
    train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
    test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)
    test_writer.add_summary(sess.run(gs_summary), 0)
    print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
                         float(sess.run(loss, feed_dict=valid_stats_feed))))

    for ep in range(epochs):
        # Slice the training data into random batches. Clamp to at least one batch:
        # np.array_split raises ValueError for 0 sections, which the plain
        # int(train_size / mb_size) produces whenever train_size < mb_size.
        num_batches = max(1, int(train_size / mb_size))
        batch_indices = np.array_split(np.random.permutation(train_size), num_batches)

        for mini_batch_indices in batch_indices:
            sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
                                            correct_output: train_y[mini_batch_indices],
                                            is_train: True})

            # Periodic logging, keyed on the optimizer-maintained global step.
            gs = int(sess.run(global_step))
            if gs % log_steps == 0:
                test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
                train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)

                acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
                # Store the Python-side accuracy in the graph variable.
                sess.run(validation_accuracy.assign(acc))

                print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))

        # End of epoch.
        print(ep_fmt.format(ep + 2))
        test_writer.add_summary(sess.run(gs_summary), ep + 1)
Some of the less obvious definitions for the above, if needed:
# Preliminaries
# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
# Non-trainable step counter; passed to minimize() below, which increments it on each train_step run.
global_step = tf.Variable(0, trainable=False, name='global_step')
# Non-trainable float holder for a validation accuracy computed outside the graph
# (written via validation_accuracy.assign(...)); presumably read by a summary defined elsewhere — TODO confirm.
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)
# Scalar boolean flag fed per run; presumably toggles training-only behavior in the model — TODO confirm.
is_train = tf.placeholder(tf.bool, [], name='is_train')
# Batch dimension is None so any number of examples can be fed at once.
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')
# Named alias of the final layer's activations (out_activations is built elsewhere).
network_output = tf.identity(out_activations)
# correct_fn / cost_fn are user-supplied; accuracy is the mean of the per-example correctness values.
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
# Total loss = task error + weighted L2 penalty summed over all layer weight tensors.
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)
# Momentum SGD step; passing global_step makes each run increment the step counter.
train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)
# Logging
# Separate writers so training and validation curves appear as distinct runs; both also record the graph.
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
# merge_all() only collects summary ops that already exist, so it must stay after every
# summary definition (including gs_summary just above).
merged = tf.summary.merge_all()
Answer 1:
Here are a few lines for training to get you started. The same logic applies to validation.
# Sketch of a tf.data-based training pipeline. Names not defined in this snippet
# (train_x, train_y, some_model, learning_rate, momentum, global_step) are assumed
# to come from your own model code, e.g. the question's preliminaries.

# Define placeholders for the input data and labels. The first (example) dimension
# is left as None so the same iterator can later be re-initialized with a dataset
# of a different size, such as the validation set.
inputs_placeholder = tf.placeholder(train_x.dtype, [None] + list(train_x.shape[1:]))
labels_placeholder = tf.placeholder(train_y.dtype, [None] + list(train_y.shape[1:]))

# Define a Dataset object using the above placeholders.
# (In TF 1.3 this class lived at tf.contrib.data.Dataset; from TF 1.4 on it is tf.data.Dataset.)
dataset = tf.data.Dataset.from_tensor_slices((inputs_placeholder, labels_placeholder))
# Shuffle the examples, mirroring np.random.permutation in the feed_dict version;
# a buffer as large as the dataset gives a full shuffle.
dataset = dataset.shuffle(buffer_size=len(train_x))

# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)

# An initializable iterator can be re-initialized at the start of every epoch.
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()

# Calculate the loss from the model function you are using
loss = some_model(next_example, next_label)
# An optimizer step is required; evaluating `loss` alone would not train anything.
train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(
    loss, global_step=global_step)

# Set number of Epochs here
num_epochs = 100
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_epochs):
        # Re-initializing rewinds (and reshuffles) the data for a new epoch.
        sess.run(iterator.initializer,
                 feed_dict={inputs_placeholder: train_x, labels_placeholder: train_y})
        while True:
            try:
                _, _loss = sess.run([train_step, loss])
            except tf.errors.OutOfRangeError:
                # The iterator is exhausted: this epoch is finished.
                break
Source: https://stackoverflow.com/questions/45399907/how-do-i-convert-my-basic-feed-based-tensorflow-code-to-use-dataset