RNN model predicting only one class?

心不动则不痛, submitted 2020-02-07 05:42:24

Question


I am trying to use GloVe embeddings to train an RNN model based on this article. I have labeled data: the text (tweets) in one column and the labels (hate, offensive, or neither) in another. However, the model seems to predict only one class.

This is the RNN model (it uses GRU layers):

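import time
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dropout, Dense

# dropout and nclasses are assumed to be arguments of Build_Model_RNN_Text
model = Sequential()
hidden_layer = 3   # stacked GRU layers that return full sequences
gru_node = 32      # units per GRU layer

# model embedding matrix here....

for i in range(0, hidden_layer):
    model.add(GRU(gru_node, return_sequences=True, recurrent_dropout=0.2))
    model.add(Dropout(dropout))
# the last GRU returns only its final hidden state
model.add(GRU(gru_node, recurrent_dropout=0.2))
model.add(Dropout(dropout))
# hidden layer uses ReLU: a softmax here would force its 64 outputs to sum to 1
model.add(Dense(64, activation='relu'))
# output layer: one probability per class
model.add(Dense(nclasses, activation='softmax'))
start = time.time()
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])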
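The model is compiled with `sparse_categorical_crossentropy`, which expects each label to be an integer class id from 0 to `nclasses - 1` rather than a string or a one-hot vector. If the labels are stored as text, a minimal encoding sketch could look like this; the label order in `LABELS` and the helper `encode_labels` are hypothetical, and the output layer's `nclasses` should then equal `len(LABELS)`.

```python
# Hypothetical label order taken from the question; if the dataset already
# stores integer ids 0..2, this step is unnecessary.
LABELS = ["hate", "offensive", "neither"]


def encode_labels(raw_labels, label_names=LABELS):
    """Map each string label to an integer id in 0..len(label_names)-1,
    the format sparse_categorical_crossentropy expects."""
    to_id = {name: i for i, name in enumerate(label_names)}
    return [to_id[label] for label in raw_labels]


encoded = encode_labels(["neither", "hate", "offensive", "neither"])
print(encoded)      # [2, 0, 1, 2]
print(len(LABELS))  # 3, the value nclasses should take
```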

Fitting the model:

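import numpy as np
from sklearn import metrics
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
X_train_Glove, X_test_Glove, word_index, embeddings_index = loadData_Tokenizer(X_train, X_test)

# third argument assumed to be nclasses: the data has 3 labels, not 20
model_RNN = Build_Model_RNN_Text(word_index, embeddings_index, 3)
model_RNN.fit(X_train_Glove, y_train,
              validation_data=(X_test_Glove, y_test),
              epochs=4,
              batch_size=128,
              verbose=2)
# predict_classes() was removed from recent Keras versions;
# take the argmax of the softmax outputs instead
y_preds = np.argmax(model_RNN.predict(X_test_Glove), axis=1)
print(metrics.classification_report(y_test, y_preds))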

Results:

  1. Classification report (image)

  2. Confusion matrix (image)
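A confusion matrix makes this symptom easy to spot: if the model predicts a single class, every count lands in one column. The pure-Python sketch below follows the same convention as scikit-learn's `confusion_matrix` (rows are true classes, columns are predicted classes); the helper names and the toy labels are made up for illustration.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def predicted_class_counts(matrix):
    """Column sums: how often each class was predicted."""
    return [sum(row[j] for row in matrix) for j in range(len(matrix))]


# Toy example (made-up labels): a model that always predicts class 1.
y_true = [0, 1, 1, 2, 1, 0]
y_pred = [1, 1, 1, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred, 3)
print(cm)                          # [[0, 2, 0], [0, 3, 0], [0, 1, 0]]
print(predicted_class_counts(cm))  # [0, 6, 0]
```

When only one entry of `predicted_class_counts` is non-zero, the classifier has collapsed to a single class regardless of what the overall accuracy says.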

Am I missing something here?

Update: this is what the class distribution looks like (image),

and this is roughly the model summary (image).
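Class imbalance matters here because a model that always outputs the most frequent class already scores an accuracy equal to that class's share of the data. The sketch below computes both numbers from a list of integer labels; the toy labels and the 0/1/2 mapping to hate/offensive/neither are hypothetical, not the asker's data.

```python
from collections import Counter


def class_distribution(labels):
    """Return {class id: fraction of samples}, sorted by class id."""
    counts = Counter(labels)
    total = len(labels)
    return {cls: counts[cls] / total for cls in sorted(counts)}


def majority_baseline_accuracy(labels):
    """Accuracy of a model that always predicts the most frequent class."""
    return max(Counter(labels).values()) / len(labels)


# Made-up, imbalanced toy labels (0 = hate, 1 = offensive, 2 = neither).
labels = [1] * 7 + [2] * 2 + [0] * 1
print(class_distribution(labels))          # {0: 0.1, 1: 0.7, 2: 0.2}
print(majority_baseline_accuracy(labels))  # 0.7
```

If the reported test accuracy is close to `majority_baseline_accuracy(y_test)` and the predictions all fall into one class, the network has likely learned little beyond the class frequencies.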


Answer 1:


What does the distribution of your data look like? My first suggestion is to stratify the train/test split (see the `stratify` parameter in the documentation for train_test_split).
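With scikit-learn this is a one-argument change, `train_test_split(X, y, test_size=0.2, random_state=1, stratify=y)`, which keeps the class proportions of `y` (approximately) the same in both splits. The stdlib-only sketch below just illustrates the idea: group sample indices by class, then draw the same fraction from each group. `stratified_split` is a hypothetical helper, and its rounding may differ slightly from scikit-learn's.

```python
import random
from collections import Counter, defaultdict


def stratified_split(X, y, test_size=0.2, seed=1):
    """Split X, y so each class keeps roughly the same proportion
    in the train and test sets (the idea behind stratify=y)."""
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)
    rng = random.Random(seed)
    test_idx = []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # same fraction per class
        test_idx.extend(indices[:n_test])
    test_set = set(test_idx)
    train_idx = [i for i in range(len(y)) if i not in test_set]
    return ([X[i] for i in train_idx], [X[i] for i in test_idx],
            [y[i] for i in train_idx], [y[i] for i in test_idx])


# Made-up toy data: 70 / 20 / 10 samples of classes 1 / 2 / 0.
X = list(range(100))
y = [1] * 70 + [2] * 20 + [0] * 10
X_train, X_test, y_train, y_test = stratified_split(X, y, test_size=0.2)
print(sorted(Counter(y_test).items()))  # [(0, 2), (1, 14), (2, 4)]
```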

The second question is how much data you have compared with the complexity of the model. Your model may be complex enough that it simply overfits. You can call model.summary() to see the number of trainable parameters.
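To compare model size with the amount of training data, the count that `model.summary()` reports can also be worked out by hand. The sketch below assumes TensorFlow 2's default `reset_after=True` for GRU layers, a hypothetical embedding dimension of 50 (the question elides the embedding setup) and three output classes. It leaves out the embedding layer, because the vocabulary size and whether that layer is trainable are not shown; if it is trainable it adds vocabulary size × embedding dimension parameters.

```python
def gru_params(input_dim, units, reset_after=True):
    """Trainable weights of one Keras GRU layer.

    kernel: input_dim x 3*units, recurrent kernel: units x 3*units,
    bias: 3*units, doubled to 2 x 3*units when reset_after=True.
    """
    bias = (6 if reset_after else 3) * units
    return 3 * units * input_dim + 3 * units * units + bias


def dense_params(input_dim, units):
    """Weight matrix plus one bias per unit."""
    return input_dim * units + units


EMBEDDING_DIM = 50  # hypothetical: the question elides the embedding setup
GRU_UNITS = 32      # gru_node in the question
HIDDEN_LAYER = 3    # hidden_layer in the question
N_CLASSES = 3       # hate, offensive, neither

# The first GRU in the loop reads the embeddings; the remaining
# HIDDEN_LAYER - 1 loop GRUs plus the final GRU read 32-dim GRU outputs.
total = gru_params(EMBEDDING_DIM, GRU_UNITS)
total += HIDDEN_LAYER * gru_params(GRU_UNITS, GRU_UNITS)
total += dense_params(GRU_UNITS, 64) + dense_params(64, N_CLASSES)
print(total)  # 29379
```

Whether roughly 29k weights is "too complex" depends on how many labeled tweets end up in the training set.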



Source: https://stackoverflow.com/questions/57748144/rnn-model-predicting-only-one-class
