关于CIFAR-10数据集的处理
CIFAR-10和CIFAR-100是带有标签的数据集,出自于规模更大的一个数据集,他有八千万张小图片(http://groups.csail.mit.edu/vision/TinyImages/这个是一个大项目,你可以点击那个big map提交自己的标签,可以帮助他们训练让计算机识别物体的模型)。
在学习cs231n中接触到CIFAR-10数据集,对于图像类数据首次接触,特将处理过程记录如下
CIFAR-10
该数据集共有60000张彩色图像,这些图像是32*32的彩色照片,每个像素点包括RGB三个数值,数值范围0~255,分为10个类,分别是'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck',每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
数据的下载
(共有3个版本:python、matlab、binary version适用于C语言)
http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
数据的处理
这里以python为例。该数据集文件包含data_batch1……data_batch5和test_batch。他们都是由cPickle库产生的序列化后的对象(关于pickle,移步https://docs.python.org/3/library/pickle.html)这里给出python2和python3的代码,可以打开这样的pkl文件,返回一个字典结构的数据:
import numpy as np
import random
import pickle
import platform
import os
#加载序列文件
def load_pickle(f):
version=platform.python_version_tuple()#判断python的版本
if version[0]== '2':
return pickle.load(f)
elif version[0]== '3':
return pickle.load(f,encoding='latin1')
raise ValueError("invalid python version:{}".format(version))
经上述代码,传入的每个batch文件,返回的是一个字典,该字典包含有:
- labels
对应的值是一个长度为10000的列表,每个数字取值范围 0~9,代表当前图片所属类别
- data
10000 * 3072 的二维数组,每一行代表一张图片的像素值。(32*32*3=3072)
数据集除了6个batch之外,还有一个文件batches.meta。它包含一个python字典对象,内容有:一个包含10个元素的列表,每一个描述了labels array中每个数字对应类标的名字。比如:label_names[0] == "airplane", label_names[1] == "automobile"
#处理原数据
def load_CIFAR_batch(filename):
with open(filename,'rb') as f:
datadict=load_pickle(f)
X=datadict['data']
Y=datadict['labels']
X=X.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
#reshape()是在不改变矩阵的数值的前提下修改矩阵的形状,transpose()对矩阵进行转置
Y=np.array(Y)
return X,Y
#返回可以直接使用的数据集
def load_CIFAR10(ROOT):
xs=[]
ys=[]
for b in range(1,6):
f=os.path.join(ROOT,'data_batch_%d'%(b,))#os.path.join()将多个路径组合后返回
X,Y=load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr=np.concatenate(xs)#这个函数用于将多个数组进行连接
Ytr=np.concatenate(ys)
del X,Y
Xte,Yte=load_CIFAR_batch(os.path.join(ROOT,'test_batch'))
return Xtr,Ytr,Xte,Yte
测试代码
datasets='cifar-10-batches-py'
X_train,Y_train,X_test,Y_test=load_CIFAR10(datasets)
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', Y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', Y_test.shape)
经上述处理后,对原始CIFAR-10数据集的处理,返回结果如下:
Training data shape: (50000, 32, 32, 3)
Training labels shape: (50000,)
Test data shape: (10000, 32, 32, 3)
Test labels shape: (10000,)
来源:CSDN
作者:SSeazen
链接:https://blog.csdn.net/SSeazen/article/details/85621547