Pytorch迁移学习

匿名 (未验证) 提交于 2019-12-02 23:53:01

环境:

PyTorch 1.1,Python 3.6,Windows 10 / Ubuntu 18,GPU

正文

  1. 用PyTorch从零搭建ResNet18模型并训练,对真实图片进行分类;
  2. 利用预训练的ResNet18模型进行微调(fine-tune),直接进行图片分类。

项目结构如下所示

pokemon文件夹存放数据集,其中包含五个子文件夹(每个子文件夹对应一个类别),每个子文件夹中存放一定数量的图片,总共1000多张图片;

best.mdl是保存下来的模型,可以直接加载进行分类

resnet.py是自己搭建的ResNet18模型

train_scratch.py利用resnet.py中的ResNet18模型进行图片分类

train_transfer.py利用下载的ResNet18模型进行图片分类

接下来进入正题:

pokemon.py

import csv
import glob
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class Pokemon(Dataset):
    """Image-classification dataset backed by one sub-folder per class.

    Every directory directly under ``root`` is treated as a class; class ids
    are assigned in sorted directory-name order.  The list of image paths and
    labels is shuffled once and cached in ``<root>/images.csv`` so that the
    train/val/test splits stay stable across runs.

    Args:
        root: Directory that contains the per-class sub-folders.
        resize: Side length (in pixels) of the square image returned by
            ``__getitem__``.
        mode: ``'train'`` selects the first 60% of the cached list, ``'val'``
            the next 20%; any other value selects the last 20% (test).
    """

    # Per-channel statistics used by both Normalize and denormalize.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        # Class-name -> integer label, e.g. {"bulbasaur": 0, ...}.
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # Parallel lists: image path, integer label.
        self.images, self.labels = self.load_csv('images.csv')

        # load_csv guarantees both lists have the same length, so a single
        # pair of cut points selects matching slices from each.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':  # 60%
            split = slice(None, train_end)
        elif mode == 'val':  # 20% = 60% -> 80%
            split = slice(train_end, val_end)
        else:  # 20% = 80% -> 100%
            split = slice(val_end, None)
        self.images = self.images[split]
        self.labels = self.labels[split]

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``<root>/<filename>``.

        If the CSV file does not exist yet, it is created first: all
        ``*.png``, ``*.jpg`` and ``*.jpeg`` files in the class folders are
        collected, shuffled, and written as ``path,label`` rows.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            A tuple ``(images, labels)`` of equally long lists holding the
            image paths (``str``) and their integer labels.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                # e.g. 'pokemon\\mewtwo\\00001.png'
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            # e.g. 1167, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # e.g. 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                print('writen into csv file:', filename)

        # Read the (possibly just written) index back from disk.
        images, labels = [], []
        # newline='' is what the csv module expects for files it parses.
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines instead of failing on unpacking.
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))

        assert len(images) == len(labels)

        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the channel-wise normalization applied in ``__getitem__``.

        Normalization computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean``.

        Args:
            x_hat: Normalized tensor whose channel dimension (size 3) is
                third from the end, e.g. ``[3, h, w]`` or ``[b, 3, h, w]``.

        Returns:
            The denormalized tensor, same shape as ``x_hat``.
        """
        # Build the constants on the input's device so that tensors living on
        # the GPU can be denormalized without a device-mismatch error.
        # Shape [3] -> [3, 1, 1] so they broadcast over height and width.
        mean = torch.tensor(self.IMAGENET_MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self.IMAGENET_STD, device=x_hat.device).view(3, 1, 1)

        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return it as an RGB PIL image."""
        # The context manager closes the file handle right away; convert()
        # returns a fully loaded copy that stays valid afterwards.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``.

        The image is loaded from disk, resized to 1.25x the target size,
        randomly rotated by up to 15 degrees, center-cropped to
        ``resize x resize``, converted to a tensor and normalized.

        Args:
            idx: Index into the selected split, ``0 <= idx < len(self)``.

        Returns:
            A tuple ``(img, label)`` of a normalized image tensor and a
            scalar label tensor.
        """
        # e.g. path = 'pokemon\\bulbasaur\\00000000.png', label = 0
        path, label = self.images[idx], self.labels[idx]

        enlarged = int(self.resize * 1.25)
        tf = transforms.Compose([
            self._load_rgb,  # string path => image data
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN,
                                 std=self.IMAGENET_STD),
        ])

        img = tf(path)
        label = torch.tensor(label)

        return img, label


def main():
    """Show one sample and then successive batches of the train split in visdom."""
    import visdom
    import time

    viz = visdom.Visdom()

    # Alternative that uses torchvision's built-in folder dataset instead of
    # the Pokemon class:
    #
    # import torchvision
    # tf = transforms.Compose([
    #                 transforms.Resize((64,64)),
    #                 transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    #
    # print(db.class_to_idx)
    #
    # for x,y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)

    db = Pokemon('pokemon', 224, 'train')

    # First sample of the split.
    x, y = db[0]
    print('sample:', x.shape, y.shape, y)

    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)

    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        # Pause so each batch stays visible in the visdom window.
        time.sleep(10)


if __name__ == '__main__':
    main()

注释:Pokemon类的功能是对数据集进行解析:以子文件夹名作为类别标签,把打乱后的图片路径和标签写入images.csv,再按60%/20%/20%的比例把图片分成train、val、test三个集合

来源: https://www.cnblogs.com/yqpy/p/11333641.html

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!