tensorflow自己实现SGD功能

匿名 (未验证) 提交于 2019-12-03 00:32:02

手动实现SGD和调用优化器结果比较

import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #mnist已经作为官方的例子,做好了数据下载,分割,转浮点等一系列工作,源码在tensorflow源码中都可以找到 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # 配置每个 GPU 上占用的内存的比例 # 没有GPU直接sess = tf.Session() gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95) sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))  #每个批次的大小 batch_size = 20 #定义训练轮数据 train_epoch = 1 #定义每n轮输出一次 test_epoch_n = 1  #计算一共有多少批次 n_batch = mnist.train.num_examples // batch_size print("batch_size="+str(batch_size)+"n_batch="+str(n_batch))  #占位符,定义了输入,输出 x = tf.placeholder(tf.float32,[None, 784])  y_ = tf.placeholder(tf.float32,[None, 10])  #权重和偏置,使用0初始化 W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10]))  #权重和偏置,使用0初始化 W2 = tf.Variable(tf.zeros([784,10])) b2 = tf.Variable(tf.zeros([10]))  #这里定义的网络结构 y = tf.matmul(x,W) + b #损失函数是交叉熵 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,logits=y))  lr = 0.01#学习率 gw,gb = tf.gradients(ys=cross_entropy, xs=[W,b]) wt = W - lr * gw bt = b - lr * gb updatew = tf.assign(W,wt) updateb = tf.assign(b,bt)  #这里定义的网络结构 y2 = tf.matmul(x,W2) + b2 #损失函数是交叉熵 cross_entropy2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_,logits=y2)) #训练方法: train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy2)  #初始化sess中所有变量 init = tf.global_variables_initializer()  sess.run(init)  batch_xs, batch_ys = mnist.train.next_batch(batch_size)  #输出手动SGD后的b值 for _ in range(2):     _,_,testsee1,testsee2 = sess.run([updatew,updateb,cross_entropy,b], feed_dict = {x: batch_xs, y_: batch_ys})     print(testsee1)     print(testsee2)  #输出优化器后的b值 for _ in range(2):     _,testsee1,testsee2 = sess.run([train_step,cross_entropy2,b2], feed_dict = {x: batch_xs, y_: batch_ys})     print(testsee1)     print(testsee2)

输出结果:

2.30259 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.] 2.26225 [ -5.00000024e-04   9.99999931e-04  -5.00000082e-04  -5.00000024e-04   -1.02445483e-10   4.99999849e-04  -1.00000005e-03  -6.51925805e-11   -2.79396766e-11   9.99999931e-04]

loss 和b值均一致,说明自己更新网络参数和优化器自动更新一致

此代码网络初始化均为0,mnist也是固定的数据,所以应该必定能复现上面的输出结果。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!