Weighted cost function in tensorflow

说谎 2021-01-25 19:02

I'm trying to introduce weighting into the following cost function:

_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels         


        
2 Answers
  •  爱一瞬间的悲伤
    2021-01-25 19:18

    The weight masks can be computed using tf.where. Here is the weighted cost example:

    import tensorflow as tf  # TensorFlow 1.x API

    batch_size = 100
    y1Weight = 0.25
    y0Weight = 0.75

    # Random logits and labels, for demonstration only
    _logits = tf.Variable(tf.random_normal(shape=(batch_size, 2), stddev=1.))
    y = tf.random_uniform(shape=(batch_size,), maxval=2, dtype=tf.int32)

    # Per-example (unreduced) cross-entropy, shape (batch_size,)
    _cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=_logits, labels=y)

    # Weight mask: y1Weight where label == 1, y0Weight where label == 0
    y_w = tf.where(tf.cast(y, tf.bool), tf.ones((batch_size,))*y1Weight, tf.ones((batch_size,))*y0Weight)

    # New weighted cost
    cost_w = tf.reduce_mean(tf.multiply(_cost, y_w))
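
The weighting amounts to scaling each example's loss by the weight of its class before averaging. Here is a minimal plain-Python sketch of that arithmetic. The four-example batch and the helper softmax_cross_entropy are hypothetical and not from the original post; only the 0.75 / 0.25 class weights are taken from the answer.

```python
import math

def softmax_cross_entropy(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Hypothetical toy batch: 4 examples, 2 classes.
logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.0, 0.0]]
labels = [0, 1, 1, 0]
class_weights = [0.75, 0.25]  # [y0Weight, y1Weight] from the answer

# Unweighted per-example losses.
per_example = [softmax_cross_entropy(z, y) for z, y in zip(logits, labels)]

# Weight mask: one weight per example, looked up from its label.
weights = [class_weights[y] for y in labels]

# Plain mean versus weighted mean over the batch.
cost = sum(per_example) / len(per_example)
cost_w = sum(w * l for w, l in zip(weights, per_example)) / len(per_example)

print(weights)
print(cost, cost_w)
```

Because both averages divide by the batch size rather than by the sum of the weights, weights that do not average to 1 also rescale the overall loss, not just the balance between classes.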
    

    As suggested by @user1761806, a simpler solution is to use tf.losses.sparse_softmax_cross_entropy(), which accepts a weights argument. Passing per-example weights such as y_w there weights the classes without a separate multiply and reduce_mean.
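
A sketch of that simpler route, under some assumptions: the toy batch and the names class_weights and example_weights are hypothetical, and the function is reached through tf.compat.v1 so the snippet also runs under TensorFlow 2.x (in TensorFlow 1.x, tf.losses.sparse_softmax_cross_entropy is the same function). tf.gather is used here as an alternative to tf.where for building the per-example weights.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API surface

# Hypothetical toy batch, not from the original post: 4 examples, 2 classes.
logits = tf.constant([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.0, 0.0]])
labels = tf.constant([0, 1, 1, 0], dtype=tf.int32)
class_weights = tf.constant([0.75, 0.25])  # [y0Weight, y1Weight]

# Look up one weight per example from its label.
example_weights = tf.gather(class_weights, labels)

# Weighted loss, reduced to a scalar by the function itself.
cost_w = tf1.losses.sparse_softmax_cross_entropy(
    labels=labels, logits=logits, weights=example_weights)
```

With the default reduction, the weighted losses are summed and divided by the number of nonzero weights, so when every weight is nonzero the result should match reduce_mean of the weighted per-example losses.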
