A recent paper (here) introduced a secondary loss function called center loss. It is based on the distance between each embedding in a batch and the running-average embedding (the "center") of that sample's class.
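For reference, the center loss term is commonly written as L_C = 1/2 * sum_i ||x_i - c_{y_i}||^2, where x_i is a sample's embedding and c_{y_i} is the center of its class. The sketch below computes that sum in plain Python for a small batch. The function name `center_loss` and the toy values are illustrative assumptions, and many implementations average over the batch instead of summing.

```python
from typing import Sequence


def center_loss(embeds: Sequence[Sequence[float]],
                labels: Sequence[int],
                centers: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: 0.5 * sum_i ||x_i - c_{y_i}||^2 over the batch."""
    total = 0.0
    for x, y in zip(embeds, labels):
        c = centers[y]  # center of this sample's class
        total += sum((xi - ci) ** 2 for xi, ci in zip(x, c))
    return 0.5 * total


# Toy batch: two classes, 2-D embeddings.
centers = [[0.0, 0.0], [1.0, 1.0]]
embeds = [[0.0, 1.0], [1.0, 3.0]]
labels = [0, 1]
print(center_loss(embeds, labels, centers))  # 2.5
```

The squared distances are 1 and 4, so the loss is 0.5 * 5 = 2.5.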
The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model is refined. The previous center-finding routine averages every instance seen since the start of training, so it tracks changes in the expected value very slowly. A moving-window average is preferable. An exponential moving-window variant is as follows:
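To make the difference concrete, here is a small plain-Python sketch comparing an all-history (cumulative) mean with an exponential moving average on a hypothetical 1-D "embedding" that jumps from 0.0 to 1.0 partway through. With `decay = 0.95` the EMA has an effective window of roughly 1 / (1 - decay) = 20 steps. The stream and helper names are assumptions for illustration only.

```python
def cumulative_mean(values):
    """Average of every value seen so far (the slow, all-history estimate)."""
    mean = 0.0
    for n, v in enumerate(values, start=1):
        mean += (v - mean) / n  # incremental form of sum(values[:n]) / n
    return mean


def ema(values, decay=0.95, init=0.0):
    """Exponential moving average: recent values dominate, old ones fade."""
    avg = init
    for v in values:
        avg = decay * avg + (1 - decay) * v
    return avg


# Hypothetical drift: 200 steps at 0.0, then 100 steps at 1.0.
stream = [0.0] * 200 + [1.0] * 100
print(round(cumulative_mean(stream), 3))  # 0.333
print(round(ema(stream), 3))              # 0.994
```

After 100 steps at the new value, the cumulative mean is still stuck at one third, while the EMA has essentially caught up (1 - 0.95^100).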
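```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def get_embed_centers(embed_batch, label_batch):
    '''Update and return per-class embedding centers (exponential moving average).

    Each center c moves toward the embedding x of a sample of its class:
    c <- decay * c + (1 - decay) * x. `decay` must lie in [0.0, 1.0];
    values closer to 1.0 give a longer averaging window.
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")
    label_batch = tf.reshape(label_batch, [-1])
    # Current centers for each sample's class, read before the update.
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    # In-place update of the centers; duplicate labels in a batch add up.
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    # scatter_sub returns the updated variable, so this gathers the new centers.
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch
```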
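The update inside `get_embed_centers` can be mirrored in plain Python to check the arithmetic without building a graph. This is a sketch under the assumption that centers are lists of floats; `update_centers` is a hypothetical name. Like the TensorFlow version, every difference is computed from the old centers (the `tf.gather` runs before `tf.scatter_sub`), and `tf.scatter_sub` sums the contributions of duplicate indices.

```python
def update_centers(centers, embeds, labels, decay=0.95):
    """Pure-Python mirror of the gather / scatter_sub update.

    For each sample i with label y, subtract (1 - decay) * (c_y - x_i)
    from row y. Differences use the OLD centers, and duplicate labels
    accumulate, matching tf.scatter_sub semantics.
    """
    new = [row[:] for row in centers]  # copy so the old centers stay intact
    for x, y in zip(embeds, labels):
        old = centers[y]  # value "gathered" before the update
        for d in range(len(old)):
            new[y][d] -= (1 - decay) * (old[d] - x[d])
    return new


# One sample of class 0: the update equals decay * c + (1 - decay) * x.
print(update_centers([[0.0, 0.0]], [[2.0, 4.0]], [0], decay=0.5))  # [[1.0, 2.0]]

# Two samples of class 0 in one batch: both differences are summed.
print(update_centers([[0.0]], [[2.0], [2.0]], [0, 0], decay=0.5))  # [[2.0]]
```

One consequence of the summing: with k samples of the same class in a batch, the effective step is k * (1 - decay), which overshoots the batch embeddings once it exceeds 1 (three samples above would move the center to 3.0).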
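```python
# nclass (number of classes) and ndims (embedding size) are assumed defined elsewhere.
with tf.Session() as sess:
    # Non-trainable variable holding one center per class, initialized to zero.
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                                     initializer=tf.constant_initializer(0),
                                     trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.global_variables_initializer())  # replaces deprecated initialize_all_variables()
    tf.get_default_graph().finalize()
    # Each sess.run(embed_ctrs_batch, feed_dict=...) applies one center update.
```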
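If the duplicate-label accumulation of `tf.scatter_sub` is a concern, one alternative is the count-normalized update described in the center loss paper (this is a swapped-in rule, not the exponential moving average used above): delta_j = sum over samples of class j of (c_j - x_i), divided by (1 + n_j), where n_j is the number of class-j samples in the batch, followed by c_j <- c_j - alpha * delta_j. The plain-Python sketch below assumes list-of-floats centers; `update_centers_normalized` and `alpha` are hypothetical names.

```python
def update_centers_normalized(centers, embeds, labels, alpha=0.5):
    """Count-normalized center update (sketch of the paper's rule).

    delta_j = sum_{i: y_i = j} (c_j - x_i) / (1 + n_j), then
    c_j <- c_j - alpha * delta_j. Classes absent from the batch are unchanged.
    """
    ndims = len(centers[0])
    sums = {}    # class -> accumulated (c_j - x_i)
    counts = {}  # class -> number of samples of that class in the batch
    for x, y in zip(embeds, labels):
        acc = sums.setdefault(y, [0.0] * ndims)
        for d in range(ndims):
            acc[d] += centers[y][d] - x[d]
        counts[y] = counts.get(y, 0) + 1
    new = [row[:] for row in centers]
    for y, acc in sums.items():
        for d in range(ndims):
            new[y][d] -= alpha * acc[d] / (1 + counts[y])
    return new


# Three samples of class 0 at 2.0, center at 0.0, alpha = 1.0:
# delta = 3 * (0 - 2) / (1 + 3) = -1.5, so the center moves to 1.5.
print(update_centers_normalized([[0.0]], [[2.0]] * 3, [0, 0, 0], alpha=1.0))  # [[1.5]]
```

Because the step is alpha * n_j / (1 + n_j) of the gap to the batch mean, it stays below the full gap for alpha <= 1, so a class that appears many times in one batch cannot overshoot.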