Pytorch:实现CNN手写数字识别
原文地址 分类目录——Pytorch 直接上程序,通过注释说明 import torch import torch . nn as nn import torch . utils . data as Data import torchvision # 数据库模块 import os import time torch . manual_seed ( 1 ) # 为pytorch中的随机操作设置一个随机种子,使得每次随机的结果都一样 # 一些超参数(全局参数) EPOCH = 2 # 训练整批数据多少次, 为了节约时间, 我们只训练一次 BATCH_SIZE = 50 # 小批量梯度下降的梯度规格,每次拿一个batch的数据来训练,来优化一波参数 LR = 0.001 # 学习率 if os . path . exists ( './mnist/' ) : # 如果已经存在(下载)了就不用下载了 DOWNLOAD_MNIST = False else : DOWNLOAD_MNIST = True # Mnist 手写数字 train_data = torchvision . datasets . MNIST ( root = './mnist/' , # 保存或者提取位置 train = True , # this is training data transform =