Pytorch打怪路(一)pytorch进行CIFAR-10分类(5)测试

三世轮回 提交于 2019-12-02 01:13:54

pytorch进行CIFAR-10分类(5)测试

我的系列博文:

 

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

Pytorch打怪路(一)pytorch进行CIFAR-10分类(2)定义卷积神经网络

Pytorch打怪路(一)pytorch进行CIFAR-10分类(3)定义损失函数和优化器

Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练

Pytorch打怪路(一)pytorch进行CIFAR-10分类(5)测试本文

 

1.直接上代码

代码第一部分

dataiter = iter(testloader)      # 创建一个python迭代器,读入的是我们第一步里面就已经加载好的testloader
images, labels = dataiter.next() # 返回一个batch_size的图片,根据第一步的设置,应该是4张

# print images
imshow(torchvision.utils.make_grid(images))  # 展示这四张图片
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4))) # python字符串格式化 ' '.join表示用空格来连接后面的字符串,参考python的join()方法

 

这一部分代码就是先随机读取4张图片,让我们看看这四张图片是什么并打印出相应的label信息,

因为第一步里面设置了是shuffle了数据的,也就是顺序是打乱的,所以各自出现的图像不一定相同,

代码第二部分

outputs = net(Variable(images))      # 注意这里的images是我们从上面获得的那四张图片,所以首先要转化成variable
_, predicted = torch.max(outputs.data, 1)  
                # 这个 _ , predicted是python的一种常用的写法,表示后面的函数其实会返回两个值
                # 但是我们对第一个值不感兴趣,就写个_在那里,把它赋值给_就好,我们只关心第二个值predicted
                # 比如 _ ,a = 1,2 这中赋值语句在python中是可以通过的,你只关心后面的等式中的第二个位置的值是多少
 
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))  # python的字符串格式化


这里用到了torch.max(  ), 它是属于Tensor的一个方法:

注意到注释中第一句话,是说返回返回输入Tensor中每行的最大值,并转换成指定的dim(维度),

所以我们程序中的 torch.max(outputs.data, 1) ,返回一个tuple (元组)

而这里很明显,这个返回的元组的第一个元素是image data,即是最大的 值,第二个元素是label, 即是最大的值 的 索引!

我们只需要label(最大值的索引),所以就会有 _ , predicted这样的赋值语句,表示忽略第一个返回值,把它赋值给 _, 就是舍弃它的意思;

我在注释中也说明了这是什么意思

这里说一下,这第二个参数1,看清楚上面的说明是 the dimension to reduce! 而不是去这个dimension上面找最大

所以这里dim=1,基于我们的a是 4行 x 4列 这么一个维度,所以指的是 消除列这个维度,这是个什么意思呢?

如果我们把上面的示例代码中,的参数 keepdim=True写上,torch.max(a,1,keepdim=True), 会发现,返回的结果的第一个元素,即表示最大的值的那部分,其实是一个 size为 【4,1】的Tensor,也就是其实它是在 按照每行 来找最大,所以结果是4行,然后因为只找一个最大值,所以是1列,整个size就是 4行 1 列, 然后参数dim=1,相当于调用了 squeeze(1),这个操作,上面的说明也是这么写的,所以最后就得到结果是一个size为4的vector。

你可以自己下去在ipython里面做实验,发现如果dim=0,它其实是在返回每列的最大值,

所以一定不要搞混!这里的dim是指的 the dimension to reduce!并不是在the dimension上去返回最大值。

所以其实我自己写的时候一般更喜欢用 torch.argmax()这个函数更直观更好理解一些

总之在这里你只需要理解这行操作的功能是:返回了最大的索引,即预测出来的类别。 想深入研究可以自己去ipython里面试一下

 

代码第三部分

correct = 0   # 定义预测正确的图片数,初始化为0
total = 0     # 总共参与测试的图片数,也初始化为0
for data in testloader:  # 循环每一个batch
    images, labels = data
    outputs = net(Variable(images))  # 输入网络进行测试
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)          # 更新测试图片的数量
    correct += (predicted == labels).sum() # 更新正确分类的图片的数量

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))          # 最后打印结果

tutorial给的结果是53%


代码第四部分

来测试一下每一类的分类正确率

class_correct = list(0. for i in range(10)) # 定义一个存储每类中测试正确的个数的 列表,初始化为0
class_total = list(0. for i in range(10))   # 定义一个存储每类中测试总数的个数的 列表,初始化为0
for data in testloader:     # 以一个batch为单位进行循环
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):      # 因为每个batch都有4张图片,所以还需要一个4的小循环
        label = labels[i]   # 对各个类的进行各自累加
        class_correct[label] += c[i]
        class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

 


 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!