Implementing gradient descent in TensorFlow instead of using the one provided with it

Posted by ≯℡__Kan透↙ on 2019-12-06 03:57:11

Question


I want to use gradient descent with momentum (keep track of previous gradients) while building a classifier in TensorFlow.

So instead of using tensorflow.train.GradientDescentOptimizer, I want to use tensorflow.gradients to calculate the gradients myself, keep track of previous gradients, and update the weights based on all of them.

How do I do this in TensorFlow?


Answer 1:


TensorFlow already has an implementation of gradient descent with momentum (tf.train.MomentumOptimizer).

To answer your general question about implementing your own optimization algorithm: TensorFlow gives you the primitives to calculate the gradients and to update variables using the calculated gradients. Suppose that, in your model, loss designates the loss function and var_list is a Python list of the TensorFlow variables in your model (which you can get by calling tf.trainable_variables, or tf.all_variables, which later releases renamed to tf.global_variables). You can then calculate the gradients with respect to your variables as follows:

grads = tf.gradients(loss, var_list)

For plain gradient descent, you simply subtract the product of the gradient and the learning rate from each variable. The code for that would look as follows:

var_updates = []
for grad, var in zip(grads, var_list):
  var_updates.append(var.assign_sub(learning_rate * grad))
train_op = tf.group(*var_updates)
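
To make the pieces above concrete, here is a minimal end-to-end sketch. The toy loss (w - 3)^2, the learning rate, and the step count are assumptions chosen only for illustration. The tensorflow.compat.v1 import assumes a TensorFlow 2.x install; on TensorFlow 1.x a plain "import tensorflow as tf" without the disable_eager_execution call should play the same role. The control dependency is an extra precaution not present in the snippet above: it makes sure all gradients are computed before any variable changes.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # tf.gradients and tf.Session need graph mode

# Toy model (assumed for illustration): find w that minimizes (w - 3)^2.
w = tf.Variable(0.0, name="w")
loss = tf.square(w - 3.0)

var_list = [w]
learning_rate = 0.1

grads = tf.gradients(loss, var_list)

# Compute every gradient before any variable is modified.
with tf.control_dependencies(grads):
    var_updates = [
        var.assign_sub(learning_rate * grad)
        for grad, var in zip(grads, var_list)
    ]
train_op = tf.group(*var_updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        sess.run(train_op)
    final_w = sess.run(w)

print(final_w)  # expected to approach 3.0
```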

You can train your model by calling sess.run(train_op). Now, you can do all sorts of things before actually updating your variables. For instance, you can keep track of the gradients in a separate set of variables and use them for the momentum algorithm. Or you can clip your gradients before updating the variables. All of these are simple TensorFlow operations, because the gradient tensors are no different from any other tensors that you compute in TensorFlow. Please look at the implementations (Momentum, RMSProp, Adam) of some of the fancier optimization algorithms to understand how you can implement your own.
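
As a rough sketch of the momentum idea described above, the following keeps one extra non-trainable variable per model variable to accumulate past gradients, and clips the gradients first. The update rule used here (velocity = momentum * velocity + grad, then var -= learning_rate * velocity) is one common formulation and is not taken from the question. The toy loss, the hyperparameter values, and the names velocities and clip_norm are assumptions for illustration, as is the tensorflow.compat.v1 import for a TensorFlow 2.x install.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # tf.gradients and tf.Session need graph mode

# Toy model (assumed for illustration): find w that minimizes (w - 3)^2.
w = tf.Variable(0.0, name="w")
loss = tf.square(w - 3.0)

var_list = [w]
learning_rate = 0.1
momentum = 0.9
clip_norm = 5.0

grads = tf.gradients(loss, var_list)
# Optional: rescale the gradients so their joint norm is at most clip_norm.
grads, _ = tf.clip_by_global_norm(grads, clip_norm)

# One accumulator per variable remembers the previous gradients.
velocities = [
    tf.Variable(tf.zeros(var.shape), trainable=False) for var in var_list
]

var_updates = []
with tf.control_dependencies(grads):
    for grad, var, velocity in zip(grads, var_list, velocities):
        # Decayed sum of past gradients plus the current gradient.
        new_velocity = momentum * velocity + grad
        var_updates.append(velocity.assign(new_velocity))
        var_updates.append(var.assign_sub(learning_rate * new_velocity))
train_op = tf.group(*var_updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(400):
        sess.run(train_op)
    final_w = sess.run(w)

print(final_w)  # expected to approach 3.0
```

Both assignments read the same new_velocity tensor, so the variable update and the stored accumulator stay consistent within one training step.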



Source: https://stackoverflow.com/questions/39167070/implementing-gradient-descent-in-tensorflow-instead-of-using-the-one-provided-wi
