How to Implement Center Loss and Other Running Averages of Labeled Embeddings

说谎 2021-02-09 19:57

A recent paper (here) introduced a secondary loss function that they called center loss. It is based on the distance between the embeddings in a batch and the running average embedding of each embedding's class. How can these running average embeddings be computed and updated in TensorFlow?
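
For concreteness, here is a minimal sketch of how I understand the loss term, assuming the running center for each embedding's class has already been gathered (the function name is mine, not from the paper):

    import tensorflow as tf

    def center_loss_term(embed_batch, ctrs_batch):
        # Half the squared Euclidean distance between each embedding and the
        # running-average center of its class, averaged over the batch.
        return 0.5 * tf.reduce_mean(
            tf.reduce_sum(tf.square(embed_batch - ctrs_batch), axis=1))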

2 Answers
  •  忘了有多久
    2021-02-09 20:24

    The previously posted method is too simple for a case like center loss, where the expected value of the embeddings changes over time as the model is refined. The earlier center-finding routine averages every instance seen since the start of training, so it tracks a drifting expected value very slowly. A moving-window average is preferred instead; an exponential moving-window variant follows:

    import tensorflow as tf

    nclass = 10  # number of distinct labels -- set to match your data
    ndims = 64   # embedding dimensionality -- set to match your model

    def get_embed_centers(embed_batch, label_batch):
        ''' Exponential moving-window average of the per-class embedding
            centers. Increase decay for a longer window; valid range [0.0, 1.0].
        '''
        decay = 0.95
        with tf.variable_scope('embed', reuse=True):
            embed_ctrs = tf.get_variable("ctrs")

        label_batch = tf.reshape(label_batch, [-1])
        # Current centers for the labels that appear in this batch.
        old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
        # Move each center a fraction (1 - decay) of the way toward its
        # batch embeddings; scatter_sub accumulates duplicate labels.
        dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
        embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
        embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
        return embed_ctrs_batch


    with tf.Session() as sess:
        with tf.variable_scope('embed'):
            embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                                         initializer=tf.constant_initializer(0),
                                         trainable=False)
        label_batch_ph = tf.placeholder(tf.int32)
        embed_batch_ph = tf.placeholder(tf.float32)
        embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
        sess.run(tf.global_variables_initializer())
        tf.get_default_graph().finalize()
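
    To sanity-check the update, the following lines can be appended inside the with tf.Session() block above. The feed values are arbitrary and purely illustrative, and they assume the nclass/ndims settings from the snippet:

        import numpy as np

        embeds = np.random.randn(4, ndims).astype(np.float32)
        labels = np.array([1, 1, 3, 7], dtype=np.int32)
        # Duplicate labels (the two 1s) each contribute to their center's update.
        ctrs_batch = sess.run(embed_ctrs_batch,
                              feed_dict={embed_batch_ph: embeds,
                                         label_batch_ph: labels})
        assert ctrs_batch.shape == (4, ndims)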
    
