How to Implement Center Loss and Other Running Averages of Labeled Embeddings

说谎 2021-02-09 19:57

A recent paper (here) introduced a secondary loss function that they called center loss. It is based on the distance between the embeddings in a batch and the running average embeddings of their respective classes.
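For reference, here is a minimal NumPy sketch of such a loss. It assumes the usual center-loss form, half the summed squared Euclidean distance between each embedding and its class center, L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2. The function name `center_loss` is hypothetical, and `centers` is assumed to hold one row per class (for example, the running averages discussed here).

```python
import numpy as np

def center_loss(embeddings, labels, centers):
    """Hypothetical helper: half the summed squared distance to class centers.

    Assumed form: L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2, where `centers`
    has shape (nclass, ndims) and `labels` holds integer class ids.
    """
    diffs = embeddings - centers[labels]  # row i is x_i - c_{y_i}
    return 0.5 * np.sum(diffs ** 2)

# Small hand-checkable example
centers = np.array([[0.0, 0.0], [1.0, 1.0]])
embeddings = np.array([[1.0, 0.0], [1.0, 3.0]])
labels = np.array([0, 1])
print(center_loss(embeddings, labels, centers))  # prints 2.5
```

The squared distances are 1 and 4, so the loss is 0.5 * 5 = 2.5. The open question is how to keep `centers` up to date as batches arrive.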

2 Answers
  •  太阳男子
    2021-02-09 20:41

    The get_new_centers() routine below takes labelled embeddings, updates the shared variables center/sums and center/cts, and then uses the updated values to compute and return the embedding centers for the requested classes.
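    The same sums/counts bookkeeping can be sketched in plain NumPy, which makes the one subtle point easy to see: repeated labels within a batch must all be accumulated. The class name `RunningCenters` is hypothetical; `np.add.at` plays the role that tf.scatter_add plays in the TensorFlow version.

```python
import numpy as np

class RunningCenters:
    """Hypothetical NumPy analogue of the center/sums and center/cts variables."""

    def __init__(self, nclass, ndims):
        self.sums = np.zeros((nclass, ndims))  # per-class embedding sums
        self.cts = np.zeros(nclass)            # per-class sample counts

    def update(self, embeddings, labels):
        # np.add.at is unbuffered, so every occurrence of a repeated label
        # is accumulated; `self.sums[labels] += embeddings` would keep only one.
        np.add.at(self.sums, labels, embeddings)
        np.add.at(self.cts, labels, 1.0)

    def centers(self, labels):
        # Average embedding for each requested label (duplicates allowed).
        return self.sums[labels] / self.cts[labels][:, None]

rc = RunningCenters(nclass=2, ndims=2)
rc.update(np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]]), np.array([0, 0, 1]))
print(rc.centers(np.array([0, 1, 0])))
```

    Class 0 has sums [4, 6] over 2 samples and class 1 has [10, 0] over 1 sample, so the requested centers are [2, 3], [10, 0], [2, 3]. As in the TensorFlow code, asking for a class that has never been seen divides by a zero count.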

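    The loop simply exercises get_new_centers() and shows that, over time, the centers converge to the expected class averages. Here those averages are all [0, 0], because every embedding is drawn from the same zero-mean distribution (100*np.random.randn) regardless of its label.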

    Note that the alpha term used in the original paper isn't included here but should be straightforward to add if needed.
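    For anyone who wants the alpha term, here is a separate NumPy sketch (not part of the TensorFlow code that follows). It assumes the paper's damped update has the form delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j) and c_j <- c_j - alpha * delta_c_j, where n_j is the number of batch samples labelled j. The function name and the default alpha = 0.5 are illustrative assumptions.

```python
import numpy as np

def alpha_center_update(centers, embeddings, labels, alpha=0.5):
    """Hypothetical helper returning alpha-damped center updates.

    Assumed rule: delta_c_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j)
                  c_j      <- c_j - alpha * delta_c_j
    The +1 in the denominator keeps classes absent from the batch
    (n_j = 0) unchanged instead of dividing by zero.
    """
    nclass = centers.shape[0]
    counts = np.bincount(labels, minlength=nclass).astype(centers.dtype)
    diff_sums = np.zeros_like(centers)
    # Accumulate (c_{y_i} - x_i) per class; np.add.at handles repeated labels.
    np.add.at(diff_sums, labels, centers[labels] - embeddings)
    delta = diff_sums / (1.0 + counts)[:, None]
    return centers - alpha * delta

centers = np.zeros((2, 2))
embeddings = np.array([[2.0, 2.0], [4.0, 4.0]])
labels = np.array([0, 0])
print(alpha_center_update(centers, embeddings, labels, alpha=0.5))
```

    Here class 0 sees two samples, so delta_c_0 = ([-2, -2] + [-4, -4]) / 3 = [-2, -2] and the new center is [1, 1]; class 1 is absent and stays at [0, 0]. Unlike the plain running average, this moves each center only part of the way toward the batch.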

    import numpy as np
    import tensorflow.compat.v1 as tf  # TF1-style graph API; also works under TF 2.x
    tf.disable_v2_behavior()

    ndims = 2     # embedding dimensionality
    nclass = 4    # number of classes
    nbatch = 100  # samples per batch

    # Shared, non-trainable accumulators: per-class embedding sums and counts.
    with tf.variable_scope('center'):
        center_sums = tf.get_variable("sums", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
        center_cts = tf.get_variable("cts", [nclass], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)

    def get_new_centers(embeddings, indices):
        '''
        Update the running sums/counts for the given class indices (skipped if
        embeddings is None) and return the average embedding for each index.
        Only the centers for the requested indices are returned, one row per
        index (including duplicates).
        '''
        with tf.variable_scope('center', reuse=True):
            center_sums = tf.get_variable("sums")
            center_cts = tf.get_variable("cts")

        # Accumulate sums and counts; scatter_add also sums repeated indices.
        if embeddings is not None:
            ones = tf.ones_like(indices, tf.float32)
            center_sums = tf.scatter_add(center_sums, indices, embeddings, name='sa1')
            center_cts = tf.scatter_add(center_cts, indices, ones, name='sa2')

        # Center = sum / count for each requested class.
        num = tf.gather(center_sums, indices)
        denom = tf.reshape(tf.gather(center_cts, indices), [-1, 1])
        return tf.divide(num, denom)


    labels_ph = tf.placeholder(tf.int32, [None])
    embeddings_ph = tf.placeholder(tf.float32, [None, ndims])
    indices = labels_ph  # class labels are used directly as row indices

    new_centers_with_update = get_new_centers(embeddings_ph, indices)
    new_centers = get_new_centers(None, indices)  # read-only, no update

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.get_default_graph().finalize()

        for i in range(100001):
            embeddings = 100*np.random.randn(nbatch, ndims)
            labels = np.random.randint(0, nclass, nbatch)
            feed_dict = {embeddings_ph: embeddings, labels_ph: labels}
            sess.run(new_centers_with_update, feed_dict)
            if i % 1000 == 0:
                # Read the current center of every class without updating.
                rval = sess.run(new_centers, {labels_ph: np.arange(nclass)})
                print('\nFor step ', i)
                for iclass in range(nclass):
                    print('Class %d, center: %s' % (iclass, str(rval[iclass])))
    

    A typical result at step 0 is:

    For step  0
    Class 0, center: [-1.7618252  -0.30574229]
    Class 1, center: [ -4.50493908  10.12403965]
    Class 2, center: [ 3.6156714  -9.94263649]
    Class 3, center: [-4.20281982 -8.28845882]
    

    and the output at step 10,000 demonstrates convergence:

    For step  10000
    Class 0, center: [ 0.00313433 -0.00757505]
    Class 1, center: [-0.03476512  0.04682625]
    Class 2, center: [-0.03865958  0.06585111]
    Class 3, center: [-0.02502561 -0.03370816]
    
