PyTorch data loading mechanism:
The Sampler generates indices, and the images and labels are then fetched from the Dataset according to those indices.
1. torch.utils.data.DataLoader
Purpose: builds an iterable data loader.
dataset: a Dataset object that determines where the data comes from and how it is read
batch_size: batch size
num_workers: number of worker processes for loading data; when resources allow, multi-process loading speeds up data reading
shuffle: whether to shuffle the data at every epoch
drop_last: whether to drop the last batch when the number of samples is not divisible by batch_size
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
epoch: one epoch means every training sample has been fed to the model once
iteration: one iteration means one batch of samples has been fed to the model
batch_size: batch size; it determines how many iterations make up one epoch
For example:
Total samples: 80, batch_size: 8
1 epoch = 10 iterations
Total samples: 87, batch_size: 8
1 epoch = 10 iterations when drop_last=True
1 epoch = 11 iterations when drop_last=False
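A minimal sketch to verify this arithmetic, using a toy TensorDataset of 87 random samples (the data itself is just a placeholder):

import torch
from torch.utils.data import TensorDataset, DataLoader

# placeholder dataset: 87 samples, 10 features each
data = torch.randn(87, 10)
labels = torch.zeros(87, dtype=torch.long)
dataset = TensorDataset(data, labels)

# drop_last=False keeps the incomplete final batch -> 11 iterations per epoch
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=False)
print(len(loader))  # 11

# drop_last=True discards the last 7 samples -> 10 iterations per epoch
loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
print(len(loader))  # 10

# one full pass over the loader is one epoch; each batch is one iteration
for batch_idx, (x, y) in enumerate(loader):
    pass  # x has shape (8, 10)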
2. torch.utils.data.Dataset
Purpose: Dataset is an abstract class; every custom dataset must inherit from it and override __getitem__().
__getitem__: receives an index and returns the corresponding sample
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
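For illustration, a minimal (hypothetical) subclass that serves samples from in-memory lists only needs __getitem__ and __len__:

from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Hypothetical toy Dataset backed by two Python lists."""
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __getitem__(self, index):
        # receives an index, returns one (sample, label) pair
        return self.samples[index], self.labels[index]

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.samples)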
RMB (renminbi) classification example:
Data splitting:
import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':
    random.seed(1)

    dataset_dir = os.path.join("..", "..", "data", "RMB_data")
    split_dir = os.path.join("..", "..", "data", "rmb_split")
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    # split ratios: 80% train, 10% valid, 10% test
    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:  # one sub-directory per class
            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            # index boundaries between train/valid/test
            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)
                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point,
                                                                 valid_point - train_point,
                                                                 img_count - valid_point))
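Assuming RMB_data contains one sub-directory per class (the class names "1" and "100" follow the rmb_label mapping used below), running the script should produce a layout roughly like this:

rmb_split/
    train/
        1/      # about 80% of the "1" images
        100/
    valid/
        1/      # about 10%
        100/
    test/
        1/      # about 10%
        100/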
Creating the Dataset:
import os
import random
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)
rmb_label = {"1": 0, "100": 1}


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for the RMB denomination classification task
        :param data_dir: str, path to the dataset
        :param transform: torch.transform, data preprocessing
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info stores every image path and label; DataLoader reads samples from it by index
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # 0~255

        if self.transform is not None:
            img = self.transform(img)  # apply transforms here, e.g. convert to tensor

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # iterate over classes
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # iterate over images
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info
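A sketch of how this Dataset is typically wired into a DataLoader, assuming RMBDataset from above is in scope and the split directories from the previous script exist; the transform sizes and normalization statistics are illustrative placeholders, not recommended settings:

import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_dir = os.path.join("..", "..", "data", "rmb_split", "train")

# assumed preprocessing pipeline; mean/std values are placeholders
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)

for i, (inputs, labels) in enumerate(train_loader):
    # inputs: (B, 3, 32, 32) float tensor, labels: (B,) int tensor
    pass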
3. transforms
torchvision.transforms: common image processing methods, including
centering, standardization, scaling, cropping, rotation, flipping, padding, noise injection, grayscale conversion, linear transforms, affine transforms, and brightness/saturation/contrast adjustment.
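A sketch showing how several of the listed operations are composed with torchvision.transforms (parameter values are arbitrary examples, not recommended settings):

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(256),                       # scaling
    transforms.RandomCrop(224, padding=16),       # cropping with padding
    transforms.RandomHorizontalFlip(p=0.5),       # flipping
    transforms.RandomRotation(10),                # rotation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # brightness/saturation/contrast
    transforms.RandomGrayscale(p=0.1),            # grayscale conversion
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),              # affine transform
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),       # centering / standardization
])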
Source: oschina
Link: https://my.oschina.net/u/4361197/blog/3361215