Question
I am using VGG16 with Keras for transfer learning (my new model has 7 classes), and I want to use the built-in decode_predictions method to output my model's predictions. However, when I use the following code:
preds = model.predict(img)
decode_predictions(preds, top=3)[0]
I receive the following error message:
ValueError: decode_predictions expects a batch of predictions (i.e. a 2D array of shape (samples, 1000)). Found array with shape: (1, 7)
Now I wonder why it expects 1000 when I only have 7 classes in my retrained model.
A similar question I found here on Stack Overflow (Keras: ValueError: decode_predictions expects a batch of predictions) suggests setting include_top=True in the model definition to solve this problem:
model = VGG16(weights='imagenet', include_top=True)
I have tried this, but it still does not work; I get the same error as before. Any hint or suggestion on how to solve this issue is highly appreciated.
Answer 1:
I suspect you are using a pre-trained model, for instance ResNet50, and importing decode_predictions like this:
from keras.applications.resnet50 import decode_predictions
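decode_predictions transforms an array of probabilities with shape (num_samples, 1000) into the names of the original ImageNet classes. It only knows those 1000 labels, so it rejects any other number of outputs, whether or not include_top=True was used when building the base model.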
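The error comes from a simple shape test on the prediction array. The sketch below mirrors that validation in plain NumPy using a hypothetical helper, check_imagenet_preds (it is not part of Keras), to show why a (1, 7) array is rejected while a (1, 1000) array passes:

```python
import numpy as np

# decode_predictions only knows the 1000 ImageNet labels
NUM_IMAGENET_CLASSES = 1000


def check_imagenet_preds(preds):
    """Hypothetical helper mirroring the shape check in decode_predictions."""
    preds = np.asarray(preds)
    if preds.ndim != 2 or preds.shape[1] != NUM_IMAGENET_CLASSES:
        raise ValueError(
            "expects a batch of predictions (i.e. a 2D array of shape "
            "(samples, 1000)). Found array with shape: " + str(preds.shape)
        )
    return preds


# A 7-class model produces shape (1, 7), which fails the check.
try:
    check_imagenet_preds(np.zeros((1, 7)))
except ValueError as err:
    print(err)

# The output of an unmodified ImageNet classifier, shape (1, 1000), passes.
check_imagenet_preds(np.zeros((1, 1000)))
```

Retraining the top of the network changes the second dimension, so no argument to the base model can make a 7-unit output acceptable to decode_predictions.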
If you want to use transfer learning to classify among 7 different classes, you need to do it like this:
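from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

# load the convolutional base without the 1000-class ImageNet head
base_model = ResNet50(weights='imagenet', include_top=False)
# add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)
# add a fully-connected layer
x = Dense(1024, activation='relu')(x)
# and a softmax output layer with one unit per class -- let's say we have 7 classes
predictions = Dense(7, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
...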
After fitting the model and computing predictions, you have to map each output index to its class name yourself instead of using the imported decode_predictions.
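A minimal sketch of that manual mapping, assuming preds is the (samples, 7) array returned by model.predict and that you have the class names in the same order as the model's output units. The helper names decode_custom_predictions and class_names_from_indices, and the class names themselves, are hypothetical. If you trained with ImageDataGenerator.flow_from_directory, the returned iterator's class_indices dict (class name to index) can be inverted to build that list. The output loosely mimics decode_predictions: one list of (class_name, score) tuples per sample, best first.

```python
import numpy as np


def class_names_from_indices(class_indices):
    """Turn a {class_name: index} dict into a list ordered by index."""
    names = [None] * len(class_indices)
    for name, idx in class_indices.items():
        names[idx] = name
    return names


def decode_custom_predictions(preds, class_names, top=3):
    """Return, per sample, the `top` (class_name, score) pairs, best first."""
    preds = np.asarray(preds)
    if preds.ndim != 2 or preds.shape[1] != len(class_names):
        raise ValueError(
            f"expected shape (samples, {len(class_names)}), got {preds.shape}"
        )
    results = []
    for row in preds:
        # indices of the `top` largest scores, highest first
        top_idx = np.argsort(row)[::-1][:top]
        results.append([(class_names[i], float(row[i])) for i in top_idx])
    return results


# Hypothetical mapping, e.g. as found in train_generator.class_indices
class_indices = {"cat": 0, "dog": 1, "horse": 2, "sheep": 3,
                 "cow": 4, "bird": 5, "fish": 6}
class_names = class_names_from_indices(class_indices)

# Made-up scores standing in for model.predict(img)
preds = np.array([[0.05, 0.60, 0.15, 0.04, 0.03, 0.10, 0.03]])
print(decode_custom_predictions(preds, class_names, top=3)[0])
```

Keeping the name list in index order is the important part: the i-th softmax unit only means something once you know which folder or label was assigned index i during training.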
Source: https://stackoverflow.com/questions/49259361/valueerror-decode-predictions-expects-a-batch-of-predictions-i-e-a-2d-array