How can I use a pre-trained neural network with grayscale images?

无人及你 2020-12-02 09:35

I have a dataset containing grayscale images and I want to train a state-of-the-art CNN on them. I'd very much like to fine-tune a pre-trained model (like the ones here).

11 Answers
  • 2020-12-02 09:46

    The model's architecture cannot be changed because the weights have been trained for a specific input configuration. Replacing the first layer with your own would pretty much render the rest of the weights useless.

    -- Edit: elaboration suggested by Prune--
    CNNs are built so that as they go deeper, they can extract high-level features derived from the lower-level features that the previous layers extracted. By removing the initial layers of a CNN, you are destroying that hierarchy of features, because the subsequent layers won't receive the features they are supposed to get as input. In your case, the second layer has been trained to expect the features of the first layer. By replacing your first layer with random weights, you are essentially throwing away any training that has been done on the subsequent layers, as they would need to be retrained. I doubt they could retain any of the knowledge learned during the initial training.
    --- end edit ---

    There is an easy way, though, to make your model work with grayscale images: you just need to make the image appear to be RGB. The easiest way to do so is to repeat the image array 3 times on a new dimension. Because you will have the same image over all 3 channels, the performance of the model should be the same as it was on RGB images.

    In numpy this can be easily done like this:

    print(grayscale_batch.shape)  # (64, 224, 224)
    rgb_batch = np.repeat(grayscale_batch[..., np.newaxis], 3, -1)
    print(rgb_batch.shape)  # (64, 224, 224, 3)
    

    The way this works is that it first creates a new dimension (to place the channels) and then it repeats the existing array 3 times on this new dimension.

    I'm also pretty sure that keras' ImageDataGenerator can load grayscale images as RGB.

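    If you go that route, here is a minimal sketch (the directory path, image size, and batch size are placeholders): flow_from_directory with color_mode='rgb' (the default) reads single-channel image files and hands them to the model as 3-channel arrays.

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    # Hypothetical layout: one sub-folder per class containing grayscale image files.
    datagen = ImageDataGenerator(rescale=1. / 255)
    train_gen = datagen.flow_from_directory(
        'path/to/grayscale_images',   # placeholder path
        target_size=(224, 224),
        color_mode='rgb',             # grayscale files are loaded and replicated to 3 channels
        batch_size=64,
        class_mode='categorical')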
  • 2020-12-02 09:49

    Why not convert the grayscale image to an RGB image? TensorFlow provides tf.image.grayscale_to_rgb for this:

    tf.image.grayscale_to_rgb(
        images,
        name=None
    )
    
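    A quick usage sketch (the tensor names and shapes are illustrative); the op expects the last dimension to be 1 and returns a 3-channel tensor:

    import tensorflow as tf

    gray_batch = tf.random.uniform((64, 224, 224, 1))  # grayscale batch with an explicit channel dim
    rgb_batch = tf.image.grayscale_to_rgb(gray_batch)
    print(rgb_batch.shape)  # (64, 224, 224, 3)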
  • 2020-12-02 09:52

    I faced the same problem while working with VGG16 and grayscale images. I solved it as follows:

    Let's say our training images are in train_gray_images, with dimensions [no_of_examples, height, width, 1]. If we pass this directly to the fit function it will raise an error, because fit expects a 3-channel (RGB) data set rather than a grayscale one. So before passing it to fit, do the following:

    Create a dummy RGB image data set (here dummy_RGB_images) with the same shape as the grayscale data set; the only difference is that the number of channels is 3.

    dummy_RGB_images = np.ndarray(shape=(train_gray_images.shape[0], train_gray_images.shape[1], train_gray_images.shape[2], 3), dtype=np.uint8)
    

    Then copy the grayscale data into each of the three channels of dummy_RGB_images (here the dimensions are [no_of_examples, height, width, channel]):

    dummy_RGB_images[:, :, :, 0] = train_gray_images[:, :, :, 0]
    dummy_RGB_images[:, :, :, 1] = train_gray_images[:, :, :, 0]
    dummy_RGB_images[:, :, :, 2] = train_gray_images[:, :, :, 0]
    

    Finally, pass dummy_RGB_images instead of the grayscale data set, like:

    model.fit(dummy_RGB_images,...)
    
  • 2020-12-02 09:52

    Dropping the input layer will not work out, because all of the following layers depend on it.

    What you can do is concatenate 3 copies of the black-and-white image to expand your color dimension:

    img_input = tf.keras.layers.Input(shape=(img_size_target, img_size_target,1))
    img_conc = tf.keras.layers.Concatenate()([img_input, img_input, img_input])    
    
    model = ResNet50(include_top=True, weights='imagenet', input_tensor=img_conc)
    
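    As a usage sketch (assuming img_size_target is 224, which the ImageNet-trained top expects), grayscale batches can then be fed to the model with their single-channel shape:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.applications import ResNet50

    img_size_target = 224  # assumed; include_top=True with ImageNet weights expects 224x224 inputs
    img_input = tf.keras.layers.Input(shape=(img_size_target, img_size_target, 1))
    img_conc = tf.keras.layers.Concatenate()([img_input, img_input, img_input])
    model = ResNet50(include_top=True, weights='imagenet', input_tensor=img_conc)

    # Placeholder grayscale batch; it keeps its single-channel shape because the
    # Concatenate layer replicates it to 3 channels inside the graph.
    gray_batch = np.random.rand(4, img_size_target, img_size_target, 1).astype('float32')
    preds = model.predict(gray_batch)  # (4, 1000) ImageNet class scores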
  • 2020-12-02 09:52

    You can use OpenCV to convert grayscale to RGB:

    cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    
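    A minimal sketch (the file name is a placeholder); IMREAD_GRAYSCALE yields a 2-D (H, W) array and cvtColor copies it into three identical channels:

    import cv2

    gray = cv2.imread('image.png', cv2.IMREAD_GRAYSCALE)
    rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)  # shape (H, W, 3)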
  • 2020-12-02 10:03

    When you add the ResNet to your model, you should specify the input_shape in the ResNet definition, e.g.

    model = ResNet50(include_top=True, weights=None, input_shape=(256, 256, 1))

    Note that a single-channel input_shape only works when training from scratch (weights=None); the pre-trained ImageNet weights require a 3-channel input.
