ResNet for Binary classification- Just 2 values of cross-validation accuracy

Submitted by 随声附和 on 2019-12-11 15:32:28

Question


I am new to Python and Keras. I am trying to do binary classification using transfer learning from ResNet. My dataset is very small, but I am using image augmentation. My cross-validation accuracy is always one of just two values, 0.3442 or 0.6558, for every epoch. Can anyone tell me why this happens? Also, when I predict (0 or 1), it labels all images as one class (0). Here is my code:

from keras.preprocessing.image import ImageDataGenerator, load_img
from keras.models import Sequential,Model,load_model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense, GlobalMaxPooling2D
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.applications.resnet50 import ResNet50
from keras.optimizers import SGD, Adam
from keras.utils import plot_model
import matplotlib.pyplot as plt
import os, os.path
import glob
import cv2
import time
from keras.utils import np_utils
from keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping
import numpy as np
import pydot
import graphviz

batch_size = 32
nb_classes = 2
data_augmentation = True

img_rows, img_cols = 224,224
img_channels = 3

#Creating array of training samples
train_path = "D:/data/train/*.*"
training_data=[]
for file in glob.glob(train_path):
    print(file)
    train_array= cv2.imread(file)
    train_array = cv2.resize(train_array, (img_rows, img_cols))
    training_data.append(train_array)

x_train=np.array(training_data)

#Creating array of validation samples
valid_path = "D:/data/valid/*.*"
valid_data=[]
for file in glob.glob(valid_path):
    print(file)
    valid_array= cv2.imread(file)
    valid_array = cv2.resize(valid_array, (img_rows, img_cols))
    valid_data.append(valid_array)

x_valid=np.array(valid_data)

x_train = np.array(x_train, dtype="float")/255.0
x_valid = np.array(x_valid, dtype="float")/255.0

#Creating array for labels
num_trainsamples = len(x_train)
y_train = np.ones((num_trainsamples,), dtype=int)
y_train[0:224]=0 #Class1=0
y_train[225:363]=1 #Class2=1
print(y_train)

num_validsamples = len(x_valid)
y_valid = np.ones((num_validsamples,), dtype=int)
y_valid[0:101]=0 
y_valid[102:155]=1 
print(y_valid)

y_train = np_utils.to_categorical(y_train,nb_classes,dtype='int32')
y_valid = np_utils.to_categorical(y_valid,nb_classes,dtype='int32')

base_model=ResNet50(weights='imagenet',include_top=False)

x = base_model.output
x = GlobalMaxPooling2D()(x)
x=Dense(1024,activation='relu')(x) 
x=Dense(1024,activation='relu')(x) 
x=Dense(512,activation='relu')(x) 
x=Dense(2, activation= 'sigmoid')(x)
model = Model(inputs = base_model.input, outputs = x)

for i,layer in enumerate(model.layers):
  print(i,layer.name)

for layer in model.layers[:75]:
    layer.trainable=False
for layer in model.layers[75:]:
    layer.trainable=True

adam = Adam(lr=0.0001)
model.compile(optimizer= adam, loss='binary_crossentropy', metrics=['accuracy'])

train_datagen = ImageDataGenerator(
    brightness_range=(0.2,2.5),
    rotation_range=180,
    zoom_range=0.5,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True)

train_datagen.fit(x_train)

tensorboard = TensorBoard(log_dir='./logs')

history = model.fit_generator(
    train_datagen.flow(x_train, y_train, batch_size=10, shuffle=True),
    steps_per_epoch=len(x_train),
    epochs=500,
    shuffle=True,
    validation_data=(x_valid, y_valid),
    validation_steps=num_validsamples // batch_size,
    callbacks=[tensorboard])

eval = model.evaluate(x_valid, y_valid)
print ("Loss = " + str(eval[0]))
print ("Test Accuracy = " + str(eval[1]))

predictions= model.predict(x_valid)
print(predictions)

The training result is as follows:

Epoch 1/500
362/362 [==============================] - 34s 93ms/step - loss: 0.6060 - acc: 0.7257 - val_loss: 0.7747 - val_acc: 0.3442
Epoch 2/500
362/362 [==============================] - 30s 82ms/step - loss: 0.4353 - acc: 0.7722 - val_loss: 0.7658 - val_acc: 0.5000
Epoch 3/500
362/362 [==============================] - 30s 82ms/step - loss: 0.4391 - acc: 0.7863 - val_loss: 0.7949 - val_acc: 0.3442
Epoch 4/500
362/362 [==============================] - 30s 82ms/step - loss: 0.4007 - acc: 0.7992 - val_loss: 0.6540 - val_acc: 0.6558
Epoch 5/500
362/362 [==============================] - 30s 82ms/step - loss: 0.3638 - acc: 0.8226 - val_loss: 0.6460 - val_acc: 0.6558
Epoch 6/500
362/362 [==============================] - 30s 82ms/step - loss: 0.3509 - acc: 0.8294 - val_loss: 0.7875 - val_acc: 0.3442
Epoch 7/500
362/362 [==============================] - 30s 82ms/step - loss: 0.3406 - acc: 0.8359 - val_loss: 0.7667 - val_acc: 0.3442
Epoch 8/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3410 - acc: 0.8365 - val_loss: 0.6900 - val_acc: 0.6558
Epoch 9/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3297 - acc: 0.8366 - val_loss: 0.7292 - val_acc: 0.3442
Epoch 10/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3262 - acc: 0.8412 - val_loss: 0.6829 - val_acc: 0.6558
Epoch 11/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3168 - acc: 0.8457 - val_loss: 0.7032 - val_acc: 0.3442
Epoch 12/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3195 - acc: 0.8452 - val_loss: 0.6985 - val_acc: 0.5000
Epoch 13/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3030 - acc: 0.8432 - val_loss: 0.6740 - val_acc: 0.6558
Epoch 14/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3191 - acc: 0.8405 - val_loss: 0.6896 - val_acc: 0.6558
Epoch 15/500
362/362 [==============================] - 29s 80ms/step - loss: 0.3084 - acc: 0.8437 - val_loss: 0.7114 - val_acc: 0.3442

Answer 1:


When you implement binary classification with a CNN, it is better to use a single unit in the final dense layer. With two independent sigmoid units, the two outputs are not coupled into one probability distribution (they need not sum to 1), whereas a single sigmoid unit with `binary_crossentropy` treats the output directly as P(class 1):

x = Dense(2, activation='sigmoid')(x)  --->  x = Dense(1, activation='sigmoid')(x)

This should give you better performance than what you have now. Note that with a single output unit, the labels must stay as a plain 0/1 vector, so the `np_utils.to_categorical` step is no longer needed.
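The difference is easy to see with plain NumPy. A minimal sketch (the logit values here are made up for illustration, not taken from the model above):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Original head: two independent sigmoid units. Their outputs are
# uncoupled, so together they do not form a probability distribution.
logits_two = np.array([1.2, 0.7])   # hypothetical pre-activations
p_two = sigmoid(logits_two)
print(p_two.sum())                  # generally != 1.0

# Suggested fix: one sigmoid unit. p and 1 - p form a valid
# two-class distribution by construction.
logit_one = 1.2
p = sigmoid(logit_one)
dist = np.array([1.0 - p, p])
print(dist.sum())                   # exactly 1.0

# Binary cross-entropy for a single example with 0/1 label y:
y = 1.0
bce = -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
print(bce)
```

This is why the single-unit head pairs naturally with `binary_crossentropy` and raw 0/1 labels, while the two-unit variant would instead call for `softmax` plus `categorical_crossentropy` with one-hot labels.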



Source: https://stackoverflow.com/questions/54797065/resnet-for-binary-classification-just-2-values-of-cross-validation-accuracy
