How to average summaries over multiple batches?

刺人心 2020-12-13 09:18

Assuming I have a bunch of summaries defined like:

loss = ...
tf.scalar_summary("loss", loss)
# ...
summaries = tf.m         


        
9 answers
  •  醉梦人生
    2020-12-13 10:07

    I ran into the same problem when I realized I had to iterate over my validation data in batches, because memory was running out and OOM errors were flooding in.

    As several of the other answers say, tf.metrics has this built in, but I'm not using tf.metrics in my project. Inspired by that, I made this:

    import tensorflow as tf
    import numpy as np
    
    
    def batch_persistent_mean(tensor):
        # Variable that keeps the running sum of the tensor across batches
        accumulator = tf.Variable(initial_value=tf.zeros_like(tensor), dtype=tf.float32)
        # Variable that counts the batches seen so far (needed to compute the mean)
        batch_nums = tf.Variable(initial_value=tf.zeros_like(tensor), dtype=tf.float32)
        # Ops that add the current value to the sum and increment the batch count
        accumulate_op = tf.assign_add(accumulator, tensor)
        step_batch = tf.assign_add(batch_nums, tf.ones_like(batch_nums))
        update_op = tf.group(step_batch, accumulate_op)
        eps = 1e-5
        # The tf.nn.relu is a hacky zero guard: the denominator is eps while
        # batch_nums is zero, and exactly batch_nums once it is at least 1
        output_tensor = accumulator / (tf.nn.relu(batch_nums - eps) + eps)
        # Op that resets both the sum and the count (zeros_like keeps the shape)
        flush_op = tf.group(tf.assign(accumulator, tf.zeros_like(accumulator)),
                            tf.assign(batch_nums, tf.zeros_like(batch_nums)))
        return output_tensor, update_op, flush_op
    
    # Make a variable that we want to accumulate
    X = tf.Variable(0., dtype=tf.float32)
    # Make our persistent mean operations
    Xbar, upd, flush = batch_persistent_mean(X)
    

    Now pass Xbar to your summary, e.g. tf.scalar_summary("mean_of_x", Xbar) (this is tf.summary.scalar in TensorFlow 1.0 and later). Where you previously ran sess.run(X), run sess.run(upd) instead, and between epochs run sess.run(flush).
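
    To make the bookkeeping easier to follow outside a TensorFlow graph, here is a minimal plain-Python sketch of the same update / read / flush cycle. The class name RunningMean and its method names are hypothetical and not part of TensorFlow, and the optional weight argument is an addition of this sketch: averaging per-batch values equals the mean over all samples only when every batch has the same size, so passing the batch size as the weight is one way to handle uneven batches.

```python
class RunningMean:
    """Hypothetical plain-Python mirror of batch_persistent_mean."""

    def __init__(self):
        self.total = 0.0  # plays the role of `accumulator`
        self.count = 0.0  # plays the role of `batch_nums`

    def update(self, value, weight=1.0):
        # Like sess.run(upd): add the new value and advance the count.
        # `weight` is an assumption of this sketch (e.g. the batch size).
        self.total += value * weight
        self.count += weight

    def result(self):
        # Like sess.run(Xbar): return 0.0 before any update (zero guard).
        return self.total / self.count if self.count > 0 else 0.0

    def flush(self):
        # Like sess.run(flush): reset the state between epochs.
        self.total = 0.0
        self.count = 0.0


meter = RunningMean()
for i in range(20):
    meter.update(float(i))
print(meter.result())  # 9.5, the mean of 0, 1, ..., 19
meter.flush()
print(meter.result())  # 0.0 after the reset
```

    With the default weight of 1.0 this counts batches exactly as the graph version above does.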

    Testing behaviour:

    # Paste the code chunk above here first.
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Accumulate 0, 1, ..., 19 and compare with the NumPy mean
        for i in range(20):
            sess.run(upd, {X: i})
        print(sess.run(Xbar), "=", np.mean(np.arange(20)))
        for i in range(40):
            sess.run(upd, {X: i})
        # Now Xbar is the mean of 0, ..., 19 followed by 0, ..., 39
        print(sess.run(Xbar), "=", np.mean(np.concatenate([np.arange(20), np.arange(40)])))
        # Now flush it
        sess.run(flush)
        print("flushed. Xbar=", sess.run(Xbar))
        # After the flush, only 0, ..., 39 contribute to the mean
        for i in range(40):
            sess.run(upd, {X: i})
        print(sess.run(Xbar), "=", np.mean(np.arange(40)))
    
