关于CIFAR-10数据集的处理

社会主义新天地 提交于 2019-12-03 12:24:12

关于CIFAR-10数据集的处理

 

CIFAR-10和CIFAR-100是带有标签的数据集,出自于规模更大的一个数据集,他有八千万张小图片(http://groups.csail.mit.edu/vision/TinyImages/这个是一个大项目,你可以点击那个big map提交自己的标签,可以帮助他们训练让计算机识别物体的模型)。

 

在学习cs231n中接触到CIFAR-10数据集,对于图像类数据首次接触,特将处理过程记录如下

 

CIFAR-10

该数据集共有60000张彩色图像,这些图像是32*32的彩色照片,每个像素点包括RGB三个数值,数值范围0~255,分为10个类,分别是'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck',每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

 

下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

 

数据的下载

(共有3个版本:python、matlab、binary version适用于C语言)

http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

数据的处理

这里以python为例。该数据集文件包含data_batch1……data_batch5和test_batch。他们都是由cPickle库产生的序列化后的对象(关于pickle,移步https://docs.python.org/3/library/pickle.html)这里给出python2和python3的代码,可以打开这样的pkl文件,返回一个字典结构的数据:

import numpy as np
import random
import pickle
import platform
import os

#加载序列文件
def load_pickle(f):
    version=platform.python_version_tuple()#判断python的版本
    if version[0]== '2':
        return pickle.load(f)
    elif version[0]== '3':
        return pickle.load(f,encoding='latin1')
    raise ValueError("invalid python version:{}".format(version))

经上述代码,传入的每个batch文件,返回的是一个字典,该字典包含有:

  • labels

  对应的值是一个长度为10000的列表,每个数字取值范围 0~9,代表当前图片所属类别

  • data

  10000 * 3072 的二维数组,每一行代表一张图片的像素值。(32*32*3=3072) 

 

数据集除了6个batch之外,还有一个文件batches.meta。它包含一个python字典对象,内容有:一个包含10个元素的列表,每一个描述了labels array中每个数字对应类标的名字。比如:label_names[0] == "airplane", label_names[1] == "automobile"

#处理原数据
def load_CIFAR_batch(filename):
    with open(filename,'rb') as f:
        datadict=load_pickle(f)
        X=datadict['data']
        Y=datadict['labels']
        X=X.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
        #reshape()是在不改变矩阵的数值的前提下修改矩阵的形状,transpose()对矩阵进行转置
        Y=np.array(Y)
        return X,Y
        
        
#返回可以直接使用的数据集
def load_CIFAR10(ROOT):
    xs=[]
    ys=[]
    for b in range(1,6):
        f=os.path.join(ROOT,'data_batch_%d'%(b,))#os.path.join()将多个路径组合后返回
        X,Y=load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr=np.concatenate(xs)#这个函数用于将多个数组进行连接
    Ytr=np.concatenate(ys)
    del X,Y
    Xte,Yte=load_CIFAR_batch(os.path.join(ROOT,'test_batch'))
    return Xtr,Ytr,Xte,Yte

测试代码

datasets='cifar-10-batches-py'
X_train,Y_train,X_test,Y_test=load_CIFAR10(datasets)
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', Y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', Y_test.shape)

 经上述处理后,对原始CIFAR-10数据集的处理,返回结果如下:

Training data shape:  (50000, 32, 32, 3)
Training labels shape:  (50000,)
Test data shape:  (10000, 32, 32, 3)
Test labels shape:  (10000,)

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!