Multi-label classification with class weights in Keras

Submitted by 微笑、不失礼 on 2019-12-20 12:39:04

Question


I have been stuck on multi-label classification for some time (I must say I am quite new to neural networks). First I will explain the network I am trying to train. I have 1000 classes in the network and they have multi-label outputs. For each training example, the number of positive outputs is the same (i.e. 10), but they can be assigned to any of the 1000 classes. So 10 classes have output 1 and the remaining 990 have output 0. For the multi-label classification, I am using binary cross-entropy as the cost function and sigmoid as the activation function. When I applied the usual rule of 0.5 as the cut-off for 1 or 0, all of the predictions were 0. I understand this is a class-imbalance problem. From this link, I understand that I might have to create extra output labels. Unfortunately, I haven't been able to figure out how to incorporate that into a simple neural network in Keras.

import keras
from keras.layers import Input, Dense
from keras.models import Model
import matplotlib.pyplot as plt

nclasses = 1000

# per-class weights for the imbalance problem (inverse class frequency); not used yet
#class_weight = {k: len(Y_train)/(nclasses*(Y_train==k).sum()) for k in range(nclasses)}

#print(class_weight)
# building neural network model
# building neural network model
inp = Input(shape=[X_train.shape[1]])
x = Dense(5000, activation='relu')(inp)

x = Dense(4000, activation='relu')(x)

x = Dense(3000, activation='relu')(x)
x = Dense(2000, activation='relu')(x)
x = Dense(nclasses, activation='sigmoid')(x)
model = Model(inputs=[inp], outputs=[x])
model.summary()

adam = keras.optimizers.Adam(lr=0.00001)
model.compile(optimizer=adam, loss='binary_crossentropy')
history = model.fit(
    X_train, Y_train, batch_size=32, epochs=50,verbose=0,shuffle=False)

plt.plot(history.history['loss'])
#plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train'], loc='upper left')
plt.show()

model.save('model2.h5')

Could anyone help me with the code here? I would also highly appreciate it if you could suggest a good accuracy metric for this problem.

Thanks a lot :) :)


Answer 1:


I have a similar problem and unfortunately have no answer for most of the questions, especially the class-imbalance problem.

In terms of metrics there are several possibilities: in my case I use the top 1/2/3/4/5 results and check whether one of them is right. Because in your case you always have the same number of labels equal to 1, you could take your top 10 results and see what percentage of them are right, averaging this result over your batch. I didn't find a way to include this algorithm as a Keras metric. Instead, I wrote a callback which calculates the metric at epoch end on my validation data set.
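A minimal sketch of that metric (the function name `top_k_precision` is hypothetical; it could be called from a custom callback's `on_epoch_end` with `self.model.predict` on the validation data):

```python
import numpy as np

def top_k_precision(y_true, probs, k=10):
    """Fraction of the k highest-scoring classes that are actually
    true labels, averaged over all samples."""
    # indices of the k highest-scoring classes per sample
    top_k = np.argsort(probs, axis=1)[:, -k:]
    # 1 where a top-k prediction is a true label, 0 otherwise
    hits = np.take_along_axis(y_true, top_k, axis=1)
    return hits.mean()
```

With 10 positive labels per sample, this is exactly the "how many of my top 10 are right" number described above.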

Also, if you predict the top n results on a test dataset, check how many times each class is predicted. The Counter class is really convenient for this purpose.
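For example (with a random stand-in for the prediction array, which in practice would come from `model.predict`):

```python
import numpy as np
from collections import Counter

rng = np.random.default_rng(0)
probs = rng.random((100, 1000))          # stand-in for model.predict(X_test)

# class indices of the top-10 predictions per sample
top_10 = np.argsort(probs, axis=1)[:, -10:]

# how often each class appears among the top-10 predictions
counts = Counter(top_10.ravel().tolist())
print(counts.most_common(5))             # five most frequently predicted classes
```

If a handful of classes dominate these counts, the model has likely collapsed onto the majority classes.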

Edit: I found a method to include class weights without splitting the output. You need a NumPy 2D array containing weights with shape [number of classes to predict, 2] (background and signal). Such an array can be calculated with this function:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

def calculating_class_weights(y_true):
    # one pair of weights (for label 0 and label 1) per output column
    number_dim = np.shape(y_true)[1]
    weights = np.empty([number_dim, 2])
    for i in range(number_dim):
        weights[i] = compute_class_weight(class_weight='balanced',
                                          classes=np.array([0., 1.]),
                                          y=y_true[:, i])
    return weights
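To see what 'balanced' does on a single output column: the weight for label v is n_samples / (2 * count(v)), so the rarer label gets the larger weight. A toy check of that arithmetic (hypothetical numbers, NumPy only):

```python
import numpy as np

y_col = np.array([1., 1., 1., 0.])            # one output column: three signal, one background
n = len(y_col)
w_background = n / (2 * (y_col == 0.).sum())  # weight for label 0: 4 / (2*1) = 2.0
w_signal = n / (2 * (y_col == 1.).sum())      # weight for label 1: 4 / (2*3) ≈ 0.667
print(w_background, w_signal)
```

In your case the positives are rare (10 of 1000), so the signal weight per class would be much larger than the background weight.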

The solution is then to build your own binary cross-entropy loss function in which you multiply in the weights yourself:

from keras import backend as K

def get_weighted_loss(weights):
    def weighted_loss(y_true, y_pred):
        # weights[:, 0] applies where y_true == 0, weights[:, 1] where y_true == 1
        return K.mean(
            (weights[:, 0] ** (1 - y_true)) * (weights[:, 1] ** y_true)
            * K.binary_crossentropy(y_true, y_pred),
            axis=-1)
    return weighted_loss
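The exponent trick in the loss above simply selects the right weight per label: where y_true is 0 the factor reduces to the background weight, and where it is 1 it reduces to the signal weight. A quick NumPy check with toy weights:

```python
import numpy as np

w0, w1 = 2.0, 0.5                      # background / signal weights for one class
y = np.array([0., 1., 1., 0.])
factor = (w0 ** (1 - y)) * (w1 ** y)   # w0 where y == 0, w1 where y == 1
print(factor)
```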

weights[:,0] is an array with all the background weights and weights[:,1] contains all the signal weights.

All that is left is to compute the weights from your training labels and pass the loss to the compile function:

class_weights = calculating_class_weights(Y_train)
model.compile(optimizer=Adam(), loss=get_weighted_loss(class_weights))


Source: https://stackoverflow.com/questions/48485870/multi-label-classification-with-class-weights-in-keras
