How does binary cross entropy loss work on autoencoders?

Submitted by 放荡痞女 on 2019-12-31 10:45:15

Question


I wrote a vanilla autoencoder using only Dense layers. Below is my code:

from keras.datasets import mnist
from keras.layers import Dense, Input
from keras.models import Model

# Encoder: 784 -> 128 -> 64 -> 28, decoder: 28 -> 64 -> 128 -> 784
iLayer = Input((784,))
layer1 = Dense(128, activation='relu')(iLayer)
layer2 = Dense(64, activation='relu')(layer1)
layer3 = Dense(28, activation='relu')(layer2)
layer4 = Dense(64, activation='relu')(layer3)
layer5 = Dense(128, activation='relu')(layer4)
layer6 = Dense(784, activation='softmax')(layer5)
model = Model(iLayer, layer6)
model.compile(loss='binary_crossentropy', optimizer='adam')

(trainX, trainY), (testX, testY) = mnist.load_data()
print("shape of the trainX", trainX.shape)
# Flatten each 28x28 image into a 784-element vector
trainX = trainX.reshape(trainX.shape[0], trainX.shape[1] * trainX.shape[2])
print("shape of the trainX", trainX.shape)
# The input is also the target, since the autoencoder reconstructs its input
model.fit(trainX, trainX, epochs=5, batch_size=100)

Questions:

1) softmax provides a probability distribution. Understood. This means I would get a vector of 784 values, each a probability between 0 and 1, for example [0.02, 0.03, ... up to 784 items], and summing all 784 elements gives 1.

2) I don't understand how the binary crossentropy works with these values. Binary cross entropy is for two values of output, right?


Answer 1:


In an autoencoder, the input and the target output of the model are the same. So if the input values are in the range [0,1], it is acceptable to use sigmoid as the activation function of the last layer. Otherwise, you need to use an appropriate activation function for the last layer (e.g. linear, which is the default).
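To illustrate why sigmoid, rather than the softmax used in the question, suits a 784-pixel output, here is a minimal sketch in plain Python. The logits are hypothetical values chosen for illustration, not outputs of the model above. It shows that softmax forces all 784 outputs to share a total of 1, while sigmoid squashes each output independently:

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the outputs always sum to 1.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    # Squashes a single value into (0, 1), independently of the other outputs.
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical pre-activations: the same value for every one of the 784 pixels.
logits = [2.0] * 784
soft = softmax(logits)
sig = [sigmoid(z) for z in logits]

print(sum(soft))  # about 1.0: the 784 outputs share a single unit of mass
print(max(soft))  # 1/784, roughly 0.0013: no pixel can be bright
print(sig[0])     # about 0.88, and every pixel can take this value at once
```

With softmax, an image with many bright pixels cannot be reconstructed, because the outputs cannot all be close to 1 at the same time.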

As for the loss function, it again depends on the values of the input data. If the input data consist only of zeros and ones (with no values in between), then binary_crossentropy is acceptable as the loss function. Otherwise, you would normally use another loss function such as 'mse' (mean squared error) or 'mae' (mean absolute error). Note, however, that when the input values lie in the range [0,1] you can still use binary_crossentropy, as is commonly done (e.g. Keras autoencoder tutorial and this paper). Just don't expect the loss value to reach zero: binary_crossentropy does not return zero when the prediction and the label are not exactly zero or one, even if they are equal. Here is a video from Hugo Larochelle where he explains the loss functions used in autoencoders (the part about using binary_crossentropy with inputs in the range [0,1] starts at 5:30).
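The point about the loss not reaching zero can be checked with a small sketch in plain Python. It also speaks to the second question: binary cross-entropy is applied to each output unit as its own two-outcome problem and the results are averaged. The helper names, the clipping constant eps and the pixel values are assumptions made for illustration, not Keras internals:

```python
import math

def bce(y, p, eps=1e-7):
    # Clip the prediction away from 0 and 1 so log() stays finite.
    # eps is an illustrative choice.
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def mean_bce(ys, ps):
    # Each output unit is scored as its own binary problem, then averaged.
    return sum(bce(y, p) for y, p in zip(ys, ps)) / len(ys)

pixels = [0.0, 1.0, 0.5, 0.2]  # hypothetical normalized pixel values

print(mean_bce([0.0, 1.0], [0.0, 1.0]))  # close to 0: hard labels, perfect match
print(bce(0.5, 0.5))                     # log(2), about 0.693, although p == y
print(mean_bce(pixels, pixels))          # positive even for a perfect copy
```

So a non-zero loss at the end of training does not by itself mean the reconstruction is poor when the targets are gray levels rather than hard 0/1 labels.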

Concretely, in your example you are using the MNIST dataset, whose values are by default integers in the range [0, 255]. Usually you need to normalize them first:

trainX = trainX.astype('float32')
trainX /= 255.

Now the values are in the range [0,1], so sigmoid can be used as the activation function of the last layer and either binary_crossentropy or mse as the loss function.


Why can binary_crossentropy be used even when the true label values (i.e. ground truth) are continuous values in the range [0,1]?

Note that we are trying to minimize the loss function during training. So if the loss function we use reaches its minimum value (which is not necessarily zero) when the prediction is equal to the true label, then it is an acceptable choice. Let's verify that this is the case for binary cross-entropy, which is defined as follows:

bce_loss = -y*log(p) - (1-y)*log(1-p)

where y is the true label and p is the predicted value. Let's consider y as fixed and see what value of p minimizes this function: we need to take the derivative with respect to p (I have assumed the log is the natural logarithm function for simplicity of calculations):

bce_loss_derivative = -y*(1/p) - (1-y)*(-1/(1-p)) = 0 =>
                      -y/p + (1-y)/(1-p) = 0 =>
                      -y*(1-p) + (1-y)*p = 0 =>
                      -y + y*p + p - y*p = 0 =>
                       p - y = 0 => y = p

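As you can see, binary cross-entropy has a stationary point at y=p. Since its second derivative with respect to p, y/p^2 + (1-y)/(1-p)^2, is positive for 0 < p < 1, that point is the minimum. In other words, the loss is minimized when the prediction equals the true label, which is exactly what we are looking for.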
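The derivation can also be checked numerically. The following is a minimal sketch in plain Python: for a few hypothetical label values y, it searches a grid of candidate predictions p and reports the one with the smallest loss. The helper name argmin_p and the grid resolution are assumptions made for illustration:

```python
import math

def bce(y, p):
    # Same formula as bce_loss in the derivation, using the natural logarithm.
    return -y * math.log(p) - (1.0 - y) * math.log(1.0 - p)

# Candidate predictions 0.001, 0.002, ..., 0.999 (open interval, so log() is finite).
grid = [i / 1000 for i in range(1, 1000)]

def argmin_p(y):
    # Return the grid point with the lowest loss for a fixed label y.
    return min(grid, key=lambda p: bce(y, p))

for y in (0.1, 0.3, 0.5, 0.9):
    best_p = argmin_p(y)
    # The minimizing p matches y, and the loss there is still above zero.
    print(y, best_p, round(bce(y, best_p), 4))
```

In each case the minimizing p coincides with y, while the loss at that point stays positive, which agrees with both the derivation and the earlier remark that the loss does not reach zero.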



Source: https://stackoverflow.com/questions/52441877/how-does-binary-cross-entropy-loss-work-on-autoencoders
