环境:
PyTorch 1.1,Python 3.6,Windows 10 / Ubuntu 18.04,GPU
正文
- Pytorch构建ResNet18模型并训练,进行真实图片分类;
- 利用预训练的ResNet18模型进行Fine tune,直接进行图片分类
项目结构如下所示
pokemon目录里存放数据,包含五个子文件夹(每个类别一个),每个子文件夹中存放一定数量的图片,总共1000多张图片;
best.mdl是保存下来的模型,可以直接加载进行分类
resnet.py是自己搭建的ResNet18模型
train_scratch.py利用resnet.py中的ResNet18模型进行图片分类
train_transfer.py利用下载的ResNet18模型进行图片分类
接下来进入正题:
pokemon.py:
import torch
import os
import glob
import random
import csv

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    """Image-classification dataset backed by one sub-folder per class.

    Every sub-directory of ``root`` is treated as a class; class names are
    sorted and numbered from 0. The list of (image path, label) pairs is
    shuffled once and cached in ``<root>/images.csv`` so that the
    train/val/test partition stays the same across runs.

    Args:
        root: Directory containing one sub-folder per class.
        resize: Side length (pixels) of the square image returned by
            ``__getitem__``.
        mode: ``'train'`` selects the first 60% of the cached list,
            ``'val'`` the next 20%, and any other value the last 20%.
    """

    # Per-channel statistics used by both Normalize and denormalize.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        self.mode = mode

        # Map class-folder name -> integer label, e.g. {"bulbasaur": 0, ...}.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv('images.csv')

        # 60% / 20% / 20% split over the (already shuffled) cached order.
        total = len(images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':
            section = slice(None, train_end)
        elif mode == 'val':
            section = slice(train_end, val_end)
        else:
            section = slice(val_end, None)
        self.images = images[section]
        self.labels = labels[section]

        self.transform = self._build_transform()

    def _build_transform(self):
        """Build the PIL-image -> normalized-tensor pipeline once.

        Random rotation is an augmentation, so it is only enabled for the
        training split; validation and test samples are deterministic.
        The pipeline contains no lambdas, which keeps the dataset picklable
        for multi-process DataLoader workers.
        """
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps.extend([
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``<root>/<filename>``.

        If the CSV file does not exist yet it is created first: all
        ``*.png`` / ``*.jpg`` / ``*.jpeg`` files in every class folder are
        collected, shuffled, and written as ``path,label`` rows.

        Args:
            filename: Name of the CSV file inside ``self.root``.

        Returns:
            A pair of equally long lists: image paths (str) and integer
            labels.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                # e.g. 'pokemon\\mewtwo\\00001.png'
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent folder name is the class name.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file:', filename)

        # Always read back from the CSV so the cached order is the single
        # source of truth. newline='' is what the csv module expects.
        images, labels = [], []
        with open(csv_path, newline='') as f:
            for img, label in csv.reader(f):
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)

        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step so an image can be displayed.

        Normalization computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean``.

        Args:
            x_hat: Normalized tensor of shape ``[c, h, w]`` or
                ``[b, c, h, w]`` with 3 channels.

        Returns:
            Tensor of the same shape, dtype and device as ``x_hat``.
        """
        # Create the statistics on x_hat's device/dtype so this also works
        # for GPU or non-float32 tensors; [3] -> [3, 1, 1] for broadcasting.
        mean = torch.tensor(self._MEAN, dtype=x_hat.dtype,
                            device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, dtype=x_hat.dtype,
                           device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``.

        Raises:
            IndexError: If ``idx`` is outside the selected split.
        """
        path, label = self.images[idx], self.labels[idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released right away.
        with Image.open(path) as raw:
            img = raw.convert('RGB')

        return self.transform(img), torch.tensor(label)


def main():
    """Visual smoke test: show one sample and then batches in visdom."""
    import visdom
    import time

    viz = visdom.Visdom()

    # NOTE: torchvision.datasets.ImageFolder(root='pokemon', transform=...)
    # can build a similar dataset directly from the folder layout; the
    # custom Pokemon class is used here instead.

    db = Pokemon('pokemon', 224, 'train')

    x, y = db[0]
    print('sample:', x.shape, y.shape, y)

    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        # Pause so each batch stays visible in the visdom window.
        time.sleep(10)


if __name__ == '__main__':
    main()
注释:Pokemon类的功能是解析数据集:为每个类别文件夹分配一个整数标签,并把所有图片打乱后按60%/20%/20%的比例划分为train、val、test三个集合