【pytorch报错解决】expected input to have 3 channels, but got 1 channels instead

大城市里の小女人 提交于 2020-11-11 15:08:51

遇到的问题

数据是png图像的时候,如果用PIL读取图像,获得的是单通道的,不是多通道的。虽然使用opencv读取图片可以获得三通道图像数据,如下:

    def __getitem__(self, idx):
        image_root = self.train_image_file_paths[idx]
        image_name = image_root.split(os.path.sep)[-1]
        image = cv.imread(image_root)

        if self.transform is not None:
            image = self.transform(image)
        label = ohe.encode(image_name.split('_')[0]) 
        return image, label

但是会出现报错:

TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

  File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 77, in <module>
    main(args)
  File "c:/Users/pprp/Desktop/pytorch-captcha-recognition-master/captcha_train.py", line 47, in main
    predict_labels = cnn(images)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\models\resnet.py", line 192, in forward
    x = self.conv1(x)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\conv.py", line 338, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 3 7 7, expected input[64, 60, 160, 3] to have 3 channels, but got 60 channels instead

最终解决方案:

class mydataset(Dataset):
    def __init__(self, folder, transform=None):
        self.train_image_file_paths = [os.path.join(folder, image_file) for image_file in os.listdir(folder)]
        self.transform = transforms.Compose([
                                            transforms.ToTensor(), # 转化为pytorch中的tensor
                                            transforms.Lambda(lambda x: x.repeat(1,1,1)), # 由于图片是单通道的,所以重叠三张图像,获得一个三通道的数据
                                            # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
                                            ]) # 主要改这个地方

    def __len__(self):
        return len(self.train_image_file_paths)

    def __getitem__(self, idx):
        image_root = self.train_image_file_paths[idx]
        image_name = image_root.split(os.path.sep)[-1]
        image = Image.open(image_root)
        if self.transform is not None:
            image = self.transform(image)
        label = ohe.encode(image_name.split('_')[0]) 
        return image, label

pytorch transform 知识点:https://blog.csdn.net/u011995719/article/details/85107009 PIL PNG格式通道问题的解决方法 : https://www.cnblogs.com/wzjbg/p/8516531.html

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!