How to Implement Center Loss and Other Running Averages of Labeled Embeddings

说谎 2021-02-09 19:57

A recent paper (here) introduced a secondary loss function that they called center loss. It is based on the distance between the embeddings in a batch and the running average em

2 Answers
  • 2021-02-09 20:24

    The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model is refined. The previous center-finding routine averages all instances seen since the start, so it tracks changes in the expected value very slowly. A moving-window average is preferable. An exponential moving-average variant is as follows:

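    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    # Example sizes; these match the values used in the other answer.
    nclass, ndims = 4, 2

    def get_embed_centers(embed_batch, label_batch):
        ''' Exponential moving average of per-class embedding centers.
        Increase decay (range [0.0, 1.0]) for a longer effective window.
        '''
        decay = 0.95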
        with tf.variable_scope('embed', reuse=True):
            embed_ctrs = tf.get_variable("ctrs")
    
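        # Flatten labels so they can index rows of the centers variable.
        label_batch = tf.reshape(label_batch, [-1])
        old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
        # Move each center a fraction (1 - decay) of the way toward its sample.
        dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
        # scatter_sub accumulates the updates when a label repeats in the batch.
        embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
        # Gathering from the scatter_sub result returns the updated centers.
        embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
        return embed_ctrs_batch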
    
    
    with tf.Session() as sess:
        with tf.variable_scope('embed'):
            embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                            initializer=tf.constant_initializer(0), trainable=False)
        label_batch_ph = tf.placeholder(tf.int32)
        embed_batch_ph = tf.placeholder(tf.float32)
        embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
        sess.run(tf.global_variables_initializer())
        tf.get_default_graph().finalize()
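
To illustrate the point about slow tracking, here is a minimal NumPy sketch of the same update rule, compared against a plain cumulative average. It is not part of the original answer: the helper name `ema_update` and the toy data (a single class whose true value jumps from 0 to 10 halfway through) are assumptions made for illustration only.

```python
import numpy as np

def ema_update(centers, embeds, labels, decay=0.95):
    """Return a copy of centers after one exponential-moving-average step."""
    centers = centers.copy()
    # Same rule as the TensorFlow routine: subtract (1 - decay) * (center - sample).
    dif = (1.0 - decay) * (centers[labels] - embeds)
    # np.subtract.at accumulates over repeated labels, like tf.scatter_sub.
    np.subtract.at(centers, labels, dif)
    return centers

# Hypothetical toy data: one class, one dimension, value jumps from 0 to 10.
centers = np.zeros((1, 1))
total, count = 0.0, 0
for step in range(200):
    x = 0.0 if step < 100 else 10.0
    centers = ema_update(centers, np.array([[x]]), np.array([0]))
    total += x
    count += 1

cumulative = total / count   # average over everything seen since the start
ema = centers[0, 0]          # exponential moving average
print("cumulative average:", cumulative)
print("exponential moving average:", ema)
```

After the jump, the cumulative average is still stuck halfway between the old and new values, while the exponential moving average has moved most of the way to the new value. A larger `decay` would make it respond more slowly.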
    
  • 2021-02-09 20:41

    The get_new_centers() routine below takes labelled embeddings and updates the shared variables center/sums and center/cts, which hold the per-class running sums and counts. It then computes the embedding centers from the updated values and returns them.

    The loop just exercises get_new_centers() and shows that it converges to the expected average embeddings for all classes over time.

    Note that the alpha term used in the original paper isn't included here but should be straightforward to add if needed.

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x graph-mode API

    ndims = 2
    nclass = 4
    nbatch = 100
    
    with tf.variable_scope('center'):
        center_sums = tf.get_variable("sums", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
        center_cts = tf.get_variable("cts", [nclass], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
    
    def get_new_centers(embeddings, indices):
        '''
        Update embedding for selected class indices and return the new average embeddings.
        Only the newly-updated average embeddings are returned corresponding to
        the indices (including duplicates).
        '''
        with tf.variable_scope('center', reuse=True):
            center_sums = tf.get_variable("sums")
            center_cts = tf.get_variable("cts")
    
        # update embedding sums, cts
        if embeddings is not None:
            ones = tf.ones_like(indices, tf.float32)
            center_sums = tf.scatter_add(center_sums, indices, embeddings, name='sa1')
            center_cts = tf.scatter_add(center_cts, indices, ones, name='sa2')
    
        # return updated centers
        num = tf.gather(center_sums, indices)
        denom = tf.reshape(tf.gather(center_cts, indices), [-1, 1])
        # Element-wise division; a class with a zero count yields NaN.
        return num / denom
    
    
    with tf.Session() as sess:
        labels_ph = tf.placeholder(tf.int32)
        embeddings_ph = tf.placeholder(tf.float32)
    
        unq_labels, ul_idxs = tf.unique(labels_ph)
        indices = tf.gather(unq_labels, ul_idxs)
        new_centers_with_update = get_new_centers(embeddings_ph, indices)
        new_centers = get_new_centers(None, indices)
    
        sess.run(tf.global_variables_initializer())
        tf.get_default_graph().finalize()
    
        for i in range(100001):
            embeddings = 100*np.random.randn(nbatch, ndims)
            labels = np.random.randint(0, nclass, nbatch)
            feed_dict = {embeddings_ph:embeddings, labels_ph:labels}
            rval = sess.run([new_centers_with_update], feed_dict)
            if i % 1000 == 0:
                feed_dict = {labels_ph:range(nclass)}
                rval = sess.run(new_centers, feed_dict)
                print('\nFor step ', i)
                for iclass in range(nclass):
                    print('Class %d, center: %s' % (iclass, str(rval[iclass])))
    

    A typical result at step 0 is:

    For step  0
    Class 0, center: [-1.7618252  -0.30574229]
    Class 1, center: [ -4.50493908  10.12403965]
    Class 2, center: [ 3.6156714  -9.94263649]
    Class 3, center: [-4.20281982 -8.28845882]
    

    and the output at step 10,000 demonstrates convergence:

    For step  10000
    Class 0, center: [ 0.00313433 -0.00757505]
    Class 1, center: [-0.03476512  0.04682625]
    Class 2, center: [-0.03865958  0.06585111]
    Class 3, center: [-0.02502561 -0.03370816]
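
For the alpha term mentioned earlier, one possible sketch in NumPy is shown here. It assumes the paper's update has the form delta_j = sum_i [y_i = j] (c_j - x_i) / (1 + count_j), followed by c_j <- c_j - alpha * delta_j. That form is my reading of the paper rather than something quoted in this thread, so check it against the paper before relying on it. The function name `alpha_center_update` and the default `alpha=0.5` are hypothetical.

```python
import numpy as np

def alpha_center_update(centers, embeds, labels, alpha=0.5):
    """Return centers after one alpha-scaled update (assumed form, see above)."""
    diff_sums = np.zeros_like(centers)
    counts = np.zeros(len(centers))
    # Accumulate (center - sample) and sample counts per class;
    # np.add.at handles labels that repeat within the batch.
    np.add.at(diff_sums, labels, centers[labels] - embeds)
    np.add.at(counts, labels, 1.0)
    # The 1 in the denominator keeps classes absent from the batch well defined.
    delta = diff_sums / (1.0 + counts)[:, None]
    return centers - alpha * delta

# Two samples of class 0; class 1 does not appear in the batch.
centers = np.zeros((2, 2))
embeds = np.array([[2.0, 2.0], [4.0, 4.0]])
labels = np.array([0, 0])
new_centers = alpha_center_update(centers, embeds, labels)
print(new_centers)
```

Unlike the running sums and counts above, this variant keeps only the centers themselves as state, and classes that do not appear in a batch are left unchanged.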
    