Assuming I have a bunch of summaries defined like:
loss = ...
tf.scalar_summary("loss", loss)
# ...
summaries = tf.m
I had the same problem when I realized I had to iterate over my validation data, because memory was running out and OOM errors were flooding in.
As several of these answers say, tf.metrics has this built in, but I'm not using tf.metrics in my project. So, inspired by that, I made this:
import tensorflow as tf
import numpy as np
def batch_persistent_mean(tensor):
    """Build graph ops that keep a running mean of `tensor` across runs.

    Two non-trainable variables are created: a running sum with the same
    shape as `tensor`, and a scalar count of how many updates were applied.

    Args:
        tensor: Tensor (or variable) to average. It is cast to float32
            before being accumulated.

    Returns:
        A tuple `(output_tensor, update_op, flush_op)`:
            output_tensor: running sum divided by the update count; all
                zeros while no update has been applied since the last flush.
            update_op: adds the current value of `tensor` to the running sum
                and increments the count by one.
            flush_op: resets both the running sum and the count to zero.
    """
    # Accumulate in float32 whatever the input dtype is, so the assign_add
    # below always sees matching dtypes.
    value = tf.cast(tensor, tf.float32)

    # Running sum of every value seen since the last flush. It is bookkeeping
    # state, not a model parameter, hence trainable=False.
    accumulator = tf.Variable(
        initial_value=tf.zeros_like(tensor, dtype=tf.float32),
        trainable=False,
        dtype=tf.float32,
    )
    # Number of updates since the last flush (needed to turn the sum into a
    # mean). A scalar is enough: it broadcasts against the sum in the division.
    batch_nums = tf.Variable(initial_value=0.0, trainable=False, dtype=tf.float32)

    # One update = add the current value and bump the counter.
    accumulate_op = tf.assign_add(accumulator, value)
    step_batch = tf.assign_add(batch_nums, 1.0)
    update_op = tf.group(step_batch, accumulate_op)

    # Zero guard: before the first update the count is 0, so divide by eps
    # instead (0 / eps == 0). For any count >= 1 the maximum returns the count
    # itself, unchanged.
    eps = 1e-5
    output_tensor = accumulator / tf.maximum(batch_nums, eps)

    # Reset with a zeros tensor of the accumulator's own shape so the assign
    # is shape-compatible for non-scalar inputs as well.
    flush_op = tf.group(
        tf.assign(accumulator, tf.zeros_like(accumulator)),
        tf.assign(batch_nums, 0.0),
    )
    return output_tensor, update_op, flush_op
# The quantity whose running mean we want to track
X = tf.Variable(initial_value=0.0, dtype=tf.float32)
# Build the persistent-mean ops: the mean tensor, the update op and the reset op
Xbar, upd, flush = batch_persistent_mean(tensor=X)
Now you send Xbar to your summary, e.g. tf.scalar_summary("mean_of_x", Xbar), and where you'd do sess.run(X) before, you'll do sess.run(upd) instead. Between epochs you'd do sess.run(flush).
### INSERT ABOVE CODE CHUNK IN S.O. ANSWER HERE ###
# A single context-managed session is enough; it is closed automatically when
# the `with` block exits.
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

    # Calculate the mean of 0+1+...+19
    for i in range(20):
        sess.run(upd, {X: i})
    print(sess.run(Xbar), "=", np.mean(np.arange(20)))

    # Keep accumulating without flushing
    for i in range(40):
        sess.run(upd, {X: i})
    # Now Xbar is the mean of (0+1+...+19 + 0+1+...+39):
    print(sess.run(Xbar), "=", np.mean(np.concatenate([np.arange(20), np.arange(40)])))

    # Now flush it
    sess.run(flush)
    print("flushed. Xbar=", sess.run(Xbar))

    # After the flush only the new values count towards the mean
    for i in range(40):
        sess.run(upd, {X: i})
    print(sess.run(Xbar), "=", np.mean(np.arange(40)))