Question
I'll start by disclosing that I'm a machine learning and Keras novice and don't know much beyond general CNN binary classifiers. I'm trying to perform pixelwise multi-class classification using a U-Net architecture (TF backend) on many 256x256 images. In other words, I input a 256x256 image, and I want it to output a 256x256 "mask" (or label image) where the values are integers from 0-30 (each integer represents a unique class). I'm training on two NVIDIA GTX 1080 Ti GPUs.
When I attempt to perform one-hot encoding, I get an OOM error, which is why I'm using sparse categorical cross entropy as my loss function instead of regular categorical cross entropy. However, when training my U-Net, my loss value is "nan" from start to finish (it initializes as nan and never changes). When I normalize my "masks" by dividing all values by 30 (so they go from 0-1), I get ~0.97 accuracy, which I'm guessing is because most of the labels in my image are 0 (and it's just outputting a bunch of 0s).
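To see why one-hot encoding runs out of memory, here is a back-of-the-envelope calculation. It assumes 31 classes (labels 0 to 30), the 2885 images implied by the training log (2308 train + 577 validation), and that the label array for the whole dataset is held in memory at once; these are assumptions, not measurements.

```python
# Back-of-the-envelope memory estimate for the label arrays.
# Assumptions: 31 classes (labels 0..30), 2885 images of 256x256,
# whole dataset held in memory at once.
NUM_IMAGES = 2308 + 577   # train + validation samples from the training log
HEIGHT, WIDTH = 256, 256
NUM_CLASSES = 31

pixels = NUM_IMAGES * HEIGHT * WIDTH

sparse_uint8_bytes = pixels * 1                    # one integer label per pixel
one_hot_uint8_bytes = pixels * NUM_CLASSES * 1     # one byte per class per pixel
one_hot_float32_bytes = pixels * NUM_CLASSES * 4   # four bytes per class per pixel

for name, n in [("sparse uint8", sparse_uint8_bytes),
                ("one-hot uint8", one_hot_uint8_bytes),
                ("one-hot float32", one_hot_float32_bytes)]:
    print(f"{name:>16}: {n / 1e9:.2f} GB")
```

This prints roughly 0.19 GB for sparse labels, 5.86 GB for a uint8 one-hot array and 23.44 GB for a float32 one-hot array, which is why sparse integer targets are attractive here.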
Here's the U-Net I'm using:
```python
import keras
from keras.layers import (Conv2D, MaxPooling2D, SpatialDropout2D, UpSampling2D,
                          Reshape, concatenate)
from keras.models import Model
from keras.optimizers import Adam

def unet(pretrained_weights = None,input_size = (256,256,1)):
    inputs = keras.engine.input_layer.Input(input_size)
    # Contracting path
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    #drop4 = Dropout(0.5)(conv4)
    drop4 = SpatialDropout2D(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    # Bottleneck
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    #drop5 = Dropout(0.5)(conv5)
    drop5 = SpatialDropout2D(0.5)(conv5)
    # Expanding path with skip connections
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    # Output layer: a single channel with softmax
    conv10 = Conv2D(1, 1, activation = 'softmax')(conv9)
    #conv10 = Flatten()(conv10)
    #conv10 = Dense(65536, activation = 'softmax')(conv10)
    flat10 = Reshape((65536,1))(conv10)
    #conv10 = Conv1D(1, 1, activation='linear')(conv10)
    model = Model(inputs = inputs, outputs = flat10)
    opt = Adam(lr=1e-6,clipvalue=0.01)
    model.compile(optimizer = opt, loss = 'sparse_categorical_crossentropy', metrics = ['sparse_categorical_accuracy'])
    #model.compile(optimizer = Adam(lr = 1e-6), loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])
    #model.compile(optimizer = Adam(lr = 1e-4),
    #model.summary()
    if(pretrained_weights):
        model.load_weights(pretrained_weights)
    return model
```
Note that I needed to flatten the output just to get sparse categorical cross entropy to function properly (it didn't like my 2D matrix for some reason).
And here's an example of a training run (just 1 epoch, because it's the same no matter how many I run):

```python
model = unet()
model.fit(x=x_train, y=y_train, batch_size=1, epochs=1, verbose=1, validation_split=0.2, shuffle=True)
```

```
Train on 2308 samples, validate on 577 samples
Epoch 1/1
2308/2308 [==============================] - 191s 83ms/step - loss: nan - sparse_categorical_accuracy: 0.9672 - val_loss: nan - val_sparse_categorical_accuracy: 0.9667
```
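One way to test the guess that the ~0.97 accuracy just reflects predicting 0 everywhere is to compute the accuracy of that trivial predictor directly from the masks. The helper `majority_class_baseline` below is hypothetical and is shown on toy data; on the real data you would pass `y_train` instead.

```python
import numpy as np

def majority_class_baseline(masks):
    """Accuracy of always predicting the most frequent label in `masks`."""
    counts = np.bincount(np.asarray(masks).ravel())  # labels must be non-negative ints
    return counts.max() / counts.sum()

# Toy masks: mostly background (0) with a small 40x40 labeled square.
masks = np.zeros((3, 256, 256), dtype=np.int64)
masks[:, 100:140, 100:140] = 7
print(round(majority_class_baseline(masks), 4))  # 0.9756
```

If the baseline on the real masks is close to the reported accuracy, the metric says little about whether the model is learning anything.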
Let me know if more information is needed to diagnose the problem. Thanks in advance!
Answer 1:
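The problem is that for multi-class classification, you need to output a vector with one entry per class, representing the confidence in that class. Your labels run from 0 to 30, which is 31 classes, so your final layer should be a 3D tensor of shape (256, 256, 31).

```python
# conv9 is the last feature map of the U-Net in the question.
# One output channel per class; softmax normalises across the channel axis.
conv10 = Conv2D(31, 1, activation = 'softmax')(conv9)
# Flatten the spatial dimensions but keep the class axis last: (65536, 31).
flat10 = Reshape((256*256, 31))(conv10)
model = Model(inputs = inputs, outputs = flat10)
opt = Adam(lr=1e-6,clipvalue=0.01)
model.compile(optimizer = opt, loss = 'sparse_categorical_crossentropy', metrics = ['sparse_categorical_accuracy'])
```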
I'm assuming that your input is a (256, 256, 1) float tensor with values between 0 and 1, and your target is a (256*256) integer tensor.
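A quick way to see the problem with the original `Conv2D(1, 1, activation='softmax')` layer, sketched here in NumPy rather than Keras: softmax normalises across the channel axis, so with a single channel every pixel's output is exactly 1 regardless of the weights, and the network has nothing to learn. The `softmax` helper below is a plain re-implementation for illustration, not the Keras function.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# One output channel, as in Conv2D(1, 1, activation='softmax'):
# softmax over an axis of length 1 is exp(x) / exp(x) = 1 for every pixel.
one_channel = softmax(rng.normal(size=(4, 4, 1)))
print(np.unique(one_channel))  # [1.]

# 31 output channels: a real probability distribution over classes per pixel.
many_channels = softmax(rng.normal(size=(4, 4, 31)))
print(many_channels.sum(axis=-1).round(6))  # all ones
```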
Does that help?
Answer 2:
```python
# One 1x1 convolution channel per class (no activation yet).
conv10 = Conv2D(nclasses, kernel_size=(1, 1))(up9)
out = BatchNormalization()(conv10)
# Flatten the spatial dimensions, keeping the class axis last.
out = Reshape((img_height*img_width, nclasses), input_shape=(img_height, img_width, nclasses))(out)
# Softmax over the class axis for each pixel.
out = Activation('softmax')(out)
model = Model(inputs=[inputs], outputs=[out])
model.compile(optimizer = Adam(lr = 1e-4), loss = 'sparse_categorical_crossentropy', metrics = ['sparse_categorical_accuracy'])
```
```
x_train: (batch_size, 224, 224, 3)  float32  (input images)
y_train: (batch_size, 50176, 1)     uint8    (target labels)
```
The code above seems to work for multi-class segmentation with `nclasses` classes, where the target labels are not one-hot encoded. One-hot encoding creates memory issues if your dataset and/or model is very large.

The last layer has shape (None, 50176, 16) (here nclasses=16, and None corresponds to the batch dimension). The label values range from 0 to nclasses-1.

If you want an image as output, the trick seems to be to take argmax over the class axis (-1) and reshape the result outside the model.
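A minimal NumPy sketch of that post-processing step, assuming predictions come back with shape (batch, img_height*img_width, nclasses) as in the model above. Random probabilities stand in for real `model.predict` output; the batch size of 2 is arbitrary.

```python
import numpy as np

# Sizes matching the answer above: 224x224 images, 16 classes.
img_height, img_width, nclasses = 224, 224, 16
batch_size = 2

# Stand-in for model.predict(x): shape (batch, H*W, nclasses), rows sum to 1.
rng = np.random.default_rng(0)
probs = rng.random((batch_size, img_height * img_width, nclasses))
probs /= probs.sum(axis=-1, keepdims=True)

# Pick the most likely class per pixel, then restore the image layout.
labels_flat = probs.argmax(axis=-1)                                  # (batch, H*W)
label_images = labels_flat.reshape(batch_size, img_height, img_width)

print(label_images.shape)  # (2, 224, 224)
```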
NB: sparse categorical crossentropy seems to have issues in Keras 2.2.2 and above!
Answer 3:
Regarding the OOM error: write a custom function to derive the one-hot encoding instead of using a predefined function like "to_categorical". In my case it takes a quarter of the memory.
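The answer doesn't show the function, so this is only a guess at what such a helper might look like: a NumPy one-hot encoder that writes uint8 instead of float32. A uint8 array is a quarter the size of a float32 one, which would match the 1/4 figure, but that is an assumption. `one_hot_uint8` is a hypothetical name.

```python
import numpy as np

def one_hot_uint8(labels, num_classes):
    """Hypothetical helper: one-hot encode integer labels as uint8.

    labels: integer array of any shape, values in [0, num_classes).
    Returns an array of shape labels.shape + (num_classes,) with dtype uint8.
    """
    labels = np.asarray(labels)
    out = np.zeros(labels.shape + (num_classes,), dtype=np.uint8)
    # Write a 1 at each pixel's class index along the last axis.
    np.put_along_axis(out, labels[..., np.newaxis], 1, axis=-1)
    return out

masks = np.random.default_rng(0).integers(0, 31, size=(2, 256, 256))
encoded = one_hot_uint8(masks, num_classes=31)
print(encoded.shape, encoded.dtype, encoded.nbytes)
```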
Answer 4:
It seems that you can now simply put a softmax activation on the last Conv2D layer, specify categorical_crossentropy as the loss, and train on the images without any reshaping tricks. I've tested this with a dummy dataset and it works well. Try it!
```python
inp = keras.Input(...)
# define your model here
out = keras.layers.Conv2D(classes, (1, 1), activation='softmax')(...)
model = keras.Model(inputs=[inp], outputs=[out], name='unet')
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.fit(tensor4d, tensor4d)
```
You can also compile with sparse_categorical_crossentropy and then train with targets of shape (samples, height, width), where each pixel holds a class label: model.fit(tensor4d, tensor3d)
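The two losses compute the same per-pixel quantity and differ only in how the targets are stored. A NumPy sketch of that equivalence (not the Keras implementation, which also handles numerical details and reduces the per-pixel values to a scalar), using small made-up shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
samples, height, width, classes = 2, 8, 8, 5

# Stand-in for the softmax output of the last Conv2D: (samples, H, W, classes).
logits = rng.normal(size=(samples, height, width, classes))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Sparse targets: one integer label per pixel, shape (samples, H, W).
y_sparse = rng.integers(0, classes, size=(samples, height, width))
# One-hot targets: shape (samples, H, W, classes).
y_one_hot = np.eye(classes)[y_sparse]

# Sparse CE: -log of the probability assigned to the true class.
sparse_ce = -np.log(np.take_along_axis(probs, y_sparse[..., None], axis=-1))[..., 0]
# Categorical CE: -sum over classes of target * log(probability).
categorical_ce = -(y_one_hot * np.log(probs)).sum(axis=-1)

print(np.allclose(sparse_ce, categorical_ce))  # True
```

The sparse form only needs the (samples, H, W) label array, which is what avoids the memory cost of the one-hot tensor.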
PS: I use Keras from tensorflow.keras (TensorFlow 2).
Source: https://stackoverflow.com/questions/54136325/use-of-keras-sparse-categorical-crossentropy-for-pixel-wise-multi-class-classifi