Output and Broadcast shape mismatch in MNIST, torchvision

跟風遠走 提交于 2020-03-01 06:15:06

问题


I am getting following error when using MNIST dataset in Torchvision

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

Here is my code:

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
images, labels = next(iter(trainloader))

回答1:


The error is due to color vs grayscale on the dataset, the dataset is grayscale.

I fixed it by changing transform to

transform = transforms.Compose([transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])


来源:https://stackoverflow.com/questions/55124407/output-and-broadcast-shape-mismatch-in-mnist-torchvision

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!