I have a dataset containing grayscale images and I want to train a state-of-the-art CNN on them. I'd very much like to fine-tune a pre-trained model (like the ones here).
A simple way to do this is to add a convolution layer that maps the 1-channel input to 3 channels in front of the base model, and then feed its output to the base model, like this:
from keras.applications import ResNet50
from keras.layers import Conv2D, Input
from keras.models import Model
IMG_SIZE = 224  # include_top=True expects 224x224 inputs
resnet = ResNet50(weights='imagenet', include_top=True)
input_tensor = Input(shape=(IMG_SIZE, IMG_SIZE, 1))
# Learnable 3x3 conv that turns the single gray channel into 3 channels
x = Conv2D(3, (3, 3), padding='same')(input_tensor)  # x has shape (IMG_SIZE, IMG_SIZE, 3)
out = resnet(x)
model = Model(inputs=input_tensor, outputs=out)
NumPy's depth-stack function, np.dstack((img, img, img)), is a natural way to go.
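A small sketch of the replication, assuming a single image of shape [H, W] (the toy arrays below are made up for illustration). Note that np.dstack stacks along axis 2, so for a batch of shape [N, H, W] it would concatenate along the width axis instead of creating a channel axis; adding a trailing axis and repeating it is safer there.

```python
import numpy as np

# Toy [H, W] grayscale image (made-up values)
img = np.arange(12, dtype=np.float32).reshape(3, 4)

# Single image: [H, W] -> [H, W, 3], every channel a copy of img
rgb = np.dstack((img, img, img))
print(rgb.shape)  # (3, 4, 3)

# Batch [N, H, W]: np.dstack would give [N, H, 3*W], so add a channel axis first
batch = np.stack([img, img + 1.0])                          # [2, 3, 4]
batch_rgb = np.repeat(batch[..., np.newaxis], 3, axis=-1)   # [2, 3, 4, 3]
print(batch_rgb.shape)  # (2, 3, 4, 3)
```

For a channels-first framework such as PyTorch, np.repeat(img[np.newaxis], 3, axis=0) gives [3, H, W] instead.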
Converting grayscale images to RGB, as in the currently accepted answer, is one approach to this problem, but not the most efficient one: it triples the size of the input and the work done by the first convolution. Instead, you can modify the weights of the model's first convolutional layer so that it accepts a single channel. The modified model works out of the box (with reduced accuracy) and can still be fine-tuned. Modifying the weights of the first layer does not render the rest of the weights useless, as others have suggested.
To do this, you'll have to add some code where the pretrained weights are loaded. In your framework of choice, grab the weights of the first convolutional layer and modify them before assigning them to your 1-channel model. The required modification is to sum the weight tensor over the input-channel dimension. This works because convolution is linear: when all three input channels hold the same grayscale image, convolving them with the original kernel gives the same result as convolving the single channel with the kernel summed over its input channels. The layout of the weight tensor varies between frameworks: the PyTorch default is [out_channels, in_channels, kernel_height, kernel_width], while TensorFlow/Keras uses [kernel_height, kernel_width, in_channels, out_channels].
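The only framework-specific part is which axis to sum over. Here is a minimal NumPy sketch; `sum_input_channels` is a hypothetical helper name, and random arrays stand in for real pretrained weights.

```python
import numpy as np

def sum_input_channels(weight, layout="OIHW"):
    """Collapse a first-layer conv kernel to a single input channel (hypothetical helper).

    "OIHW": PyTorch order [out_channels, in_channels, kernel_height, kernel_width]
    "HWIO": TensorFlow/Keras order [kernel_height, kernel_width, in_channels, out_channels]
    """
    in_axis = {"OIHW": 1, "HWIO": 2}[layout]
    # keepdims=True leaves a size-1 input-channel axis, as a 1-channel conv layer expects
    return weight.sum(axis=in_axis, keepdims=True)

rng = np.random.default_rng(0)
w_torch = rng.standard_normal((64, 3, 7, 7))   # stand-in for pretrained conv1 weights
w_tf = np.transpose(w_torch, (2, 3, 1, 0))     # same kernel in [kh, kw, in, out] order

print(sum_input_channels(w_torch, "OIHW").shape)  # (64, 1, 7, 7)
print(sum_input_channels(w_tf, "HWIO").shape)     # (7, 7, 1, 64)
```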
Using PyTorch as an example, in a ResNet50 model from Torchvision (https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), the shape of the weights for conv1 is [64, 3, 7, 7]. Summing over dimension 1 results in a tensor of shape [64, 1, 7, 7]. At the bottom I've included a snippet of code that would work with the ResNet models in Torchvision assuming that an argument (inchans) was added to specify a different number of input channels for the model.
To check that this works, I did three runs of ImageNet validation on ResNet50 with pretrained weights. There is a slight difference in the numbers for runs 2 and 3, but it's minimal and should be irrelevant once fine-tuned.
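from torch.utils import model_zoo
# ResNet, Bottleneck and model_urls come from torchvision/models/resnet.py
# (model_urls only exists in older torchvision releases). ResNet is assumed
# to have been patched to accept an `inchans` argument for its conv1 layer.
from torchvision.models.resnet import ResNet, Bottleneck, model_urls

def _load_pretrained(model, url, inchans=3):
    state_dict = model_zoo.load_url(url)
    if inchans == 1:
        # Sum conv1 weights over the input-channel axis: [64, 3, 7, 7] -> [64, 1, 7, 7]
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif inchans != 3:
        raise ValueError("Invalid number of inchans for pretrained weights")
    model.load_state_dict(state_dict)

def resnet50(pretrained=False, inchans=3):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        inchans (int): Number of input channels (1 or 3 when pretrained)
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], inchans=inchans)
    if pretrained:
        _load_pretrained(model, model_urls['resnet50'], inchans=inchans)
    return model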
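The claim that summing the weights preserves the output can be checked without a framework. This NumPy sketch uses a naive "valid" cross-correlation (what conv layers compute) with PyTorch's [out, in, kh, kw] layout and random stand-in weights; `conv2d` here is a hypothetical helper, not a library function. The identity is exact only when all three input channels are identical; per-channel normalization with different means and standard deviations (as in the usual ImageNet preprocessing) makes the replicated channels differ slightly, which is one possible source of small discrepancies.

```python
import numpy as np

def conv2d(x, w):
    """Naive 'valid' cross-correlation, stride 1, no bias (hypothetical helper).

    x: [in_channels, H, W]; w: [out_channels, in_channels, kh, kw]
    """
    out_c, _, kh, kw = w.shape
    _, h, wd = x.shape
    out = np.zeros((out_c, h - kh + 1, wd - kw + 1))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw]  # [in_channels, kh, kw]
            # Contract over in_channels, kh, kw -> one value per output channel
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out

rng = np.random.default_rng(0)
gray = rng.random((1, 16, 16))             # 1-channel image
rgb = np.repeat(gray, 3, axis=0)           # [3, 16, 16], all channels equal
w3 = rng.standard_normal((64, 3, 7, 7))    # stand-in for pretrained conv1 weights
w1 = w3.sum(axis=1, keepdims=True)         # [64, 1, 7, 7]

y3 = conv2d(rgb, w3)   # original layer on replicated gray image
y1 = conv2d(gray, w1)  # summed-weight layer on the single channel
print(np.abs(y3 - y1).max())  # tiny, floating-point rounding only
```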
If you're already using scikit-image, you can get the desired result with gray2rgb.
from skimage.color import gray2rgb
rgb_img = gray2rgb(gray_img)
I believe you can use a pretrained ResNet with 1-channel grayscale images without repeating the image three times.
What I have done is to replace the first layer (this is PyTorch, not Keras, but the idea should carry over):
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
With the following layer:
(conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Then copy the sum of the original weights over the channel axis into the new layer. For example, the shape of the original weights was:
torch.Size([64, 3, 7, 7])
So I did:
# Requires the pretrained [64, 3, 7, 7] conv1 weights; summing over dim 1 gives [64, 1, 7, 7]
resnet18.conv1.weight.data = resnet18.conv1.weight.data.sum(dim=1, keepdim=True)
Then check that the output of the modified model on the 1-channel image matches the output of the original model on the 3-channel grayscale image:
y_1 = model_resnet_1(input_image_1)
y_3 = model_resnet_3(input_image_3)
print(torch.abs(y_1).sum(), torch.abs(y_3).sum())
(tensor(710.8860, grad_fn=<SumBackward0>),
tensor(710.8861, grad_fn=<SumBackward0>))
input_image_1: 1-channel image
input_image_3: 3-channel image (grayscale, all channels equal)
model_resnet_1: modified model
model_resnet_3: original ResNet model
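Comparing torch.abs(y).sum() checks only one summary number, and two different outputs can share it. An elementwise comparison is stricter; in PyTorch, torch.allclose(y_1, y_3, atol=...) does this. The NumPy sketch below illustrates the point with made-up arrays standing in for the two model outputs.

```python
import numpy as np

# Made-up stand-ins for two model outputs
y_a = np.array([1.0, -2.0, 3.0])
y_b = np.array([-1.0, 2.0, 3.0])

print(np.abs(y_a).sum() == np.abs(y_b).sum())  # True: same summary number
print(np.abs(y_a - y_b).max())                 # 4.0: outputs clearly differ
print(np.allclose(y_a, y_b))                   # False
```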