How to quantize the values of tf.Variables in TensorFlow

Submitted by 杀马特。学长 韩版系。学妹 on 2019-12-24 02:33:13

Question


I have a training model of the form

Y = w * X + b

where Y and X are the output and input placeholders, and w and b are vectors. I already know that each value of w can only be 0 or 1, while b is an ordinary tf.float32.

How can I quantize the range of the variable w when I define it?

Or, alternatively, can I use two different learning rates? The rate for w would be 1 or -1, and the rate for b the usual 0.0001.


Answer 1:


There is no way to keep the variable discrete while the optimizer is updating it, but you can limit it after each iteration. Here is one way to do the thresholding with tf.where():

import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform, tf.Session)

a = tf.random_uniform(shape=(3, 3))

# Element-wise: 0 where a < 0.5, otherwise 1
b = tf.where(
    tf.less(a, tf.zeros_like(a) + 0.5),
    tf.zeros_like(a),
    tf.ones_like(a)
)

with tf.Session() as sess:
    A, B = sess.run([a, b])
    print(A, '\n')
    print(B)

This converts every element greater than or equal to 0.5 to 1 and everything else to 0:

[[ 0.2068541   0.12682056  0.73839438]
 [ 0.00512838  0.43465161  0.98486936]
 [ 0.32126224  0.29998791  0.31065524]] 

[[ 0.  0.  1.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]
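To show what "limit it after each iteration" does inside a training loop, here is a framework-free sketch in plain Python (no TensorFlow). It assumes a scalar w and b for brevity, made-up data with true w = 1 and b = 0.3, and hypothetical helpers `quantize` and `fit`. Each step is an ordinary gradient-descent update, after which w is snapped to 0 or 1 with the same 0.5 threshold as the tf.where snippet. It also gives w and b separate step sizes (`lr_w`, `lr_b`), the second option raised in the question.

```python
def quantize(value, threshold=0.5):
    # Same rule as the tf.where snippet: below threshold -> 0.0, otherwise 1.0
    return 0.0 if value < threshold else 1.0


def fit(xs, ys, steps=200, lr_w=0.1, lr_b=0.1):
    """Gradient descent on mean squared error for y = w * x + b,
    snapping w to {0, 1} after every update (hypothetical helper)."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of mean((w*x + b - y)**2) with respect to w and b
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        # Ordinary update, with a separate step size for each parameter
        w -= lr_w * grad_w
        b -= lr_b * grad_b
        # Limit w after each iteration: force it back onto {0, 1}
        w = quantize(w)
    return w, b


xs = [1.0, 2.0, 3.0, 4.0]
ys = [x + 0.3 for x in xs]  # made-up data: true w = 1, b = 0.3
w, b = fit(xs, ys)
print(w, round(b, 4))  # 1.0 0.3
```

One caveat the sketch makes visible: because w is overwritten with its snapped value, a single update has to carry it across the 0.5 threshold, or it snaps straight back. With `lr_w=0.01` on the same data the largest step is about 0.165, so `fit` returns w = 0.0 and w never moves.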



Answer 2:


One method I have used to limit variables to a particular range is to add a penalty term (a soft constraint) to my loss. If the variable goes outside the desired range, the loss grows, and the optimizer pushes the variable back toward that range.

For example:

import tensorflow as tf

# Initialize the variable uniformly between 0 and 1
# (self.numOutputs comes from the surrounding model class)
variable = tf.Variable(tf.random_uniform([self.numOutputs], 0, 1))

# A clipped copy of the variable that always lies in [0, 1]
variableClipped = tf.clip_by_value(variable, 0, 1)

# Penalize the squared distance between the variable and its clipped copy.
# The penalty is zero inside [0, 1]; whenever the variable leaves the range,
# the loss increases and the optimizer pushes it back toward the range.
# originalLossEquation is the model's existing loss tensor.
loss = originalLossEquation + tf.reduce_sum((variable - variableClipped) ** 2)
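Because this is a penalty rather than a hard limit, the optimizer balances it against the original loss. A framework-free sketch in plain Python makes this concrete. It assumes a scalar variable v, a stand-in original loss (v - target)**2 whose optimum lies outside [0, 1], and a hypothetical `penalty_weight` multiplier that the code above does not have (the code above corresponds to `penalty_weight = 1`). For target > 1, gradient descent settles at (target + penalty_weight) / (1 + penalty_weight), so v remains slightly outside the range unless the penalty is weighted heavily.

```python
def clip(v, lo=0.0, hi=1.0):
    # Scalar counterpart of tf.clip_by_value
    return max(lo, min(hi, v))


def loss_grad(v, target, penalty_weight):
    # Gradient of the stand-in original loss (v - target) ** 2
    grad_original = 2.0 * (v - target)
    # Gradient of the penalty (v - clip(v)) ** 2: zero inside [0, 1];
    # outside the range clip(v) is constant, so it is 2 * (v - clip(v))
    grad_penalty = 2.0 * (v - clip(v))
    return grad_original + penalty_weight * grad_penalty


def minimize(target, penalty_weight, v=0.5, lr=0.001, steps=20000):
    # Plain gradient descent, starting inside the allowed range
    for _ in range(steps):
        v -= lr * loss_grad(v, target, penalty_weight)
    return v


print(round(minimize(1.5, 1.0), 4), round(minimize(1.5, 100.0), 4))  # 1.25 1.005
```

With the unweighted penalty, v ends at 1.25; weighting it by 100 brings v to about 1.005. When the original loss's optimum already lies inside [0, 1], the penalty is zero there and does not affect the result.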


Source: https://stackoverflow.com/questions/43834533/how-to-quantize-the-values-of-tf-variables-in-tensorflow
