How to assign a value to a tf.Variable in TensorFlow without using tf.assign

猫巷女王i 2021-01-23 12:58

I have a variable that contains the 4x4 identity matrix. I wish to assign some values to this matrix (these values are learned by the model).

When I use tf.assi

1 Answer
  • 2021-01-23 13:35

    Here is an example of how you could do what (I think) you want:

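    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train)
    import numpy as np
    
    with tf.Graph().as_default(), tf.Session() as sess:
        # Values to write into the matrix; in a real model these would be
        # produced by the network instead of being a constant.
        params = [[1.0, 2.0, 3.0]]
        # Target for the example loss: identity with [4, 5, 6] in the last column.
        M_gt = np.eye(4, dtype=np.float32)
        M_gt[0:3, 3] = [4.0, 5.0, 6.0]
    
        # Batch of one 4x4 identity matrix.
        M = tf.Variable(tf.eye(4, batch_shape=[1]), dtype=tf.float32)
        params_t = tf.constant(params, dtype=tf.float32)
    
        shape_m = tf.shape(M)
        batch_size = shape_m[0]
        num_m = shape_m[1]
        num_params = tf.shape(params_t)[1]
    
        # Matrix holding params in the first num_params rows of the last
        # column and zeros everywhere else.
        last_column = tf.concat([tf.tile(tf.transpose(params_t)[tf.newaxis], (batch_size, 1, 1)),
                                 tf.zeros((batch_size, num_m - num_params, 1), dtype=params_t.dtype)], axis=1)
        replace = tf.concat([tf.zeros((batch_size, num_m, num_m - 1), dtype=params_t.dtype), last_column], axis=2)
    
        # Boolean mask that is True exactly where the params should go.
        r = tf.range(num_m)
        ii = r[tf.newaxis, :, tf.newaxis]  # row indices
        jj = r[tf.newaxis, tf.newaxis, :]  # column indices
        mask = tf.tile((ii < num_params) & (tf.equal(jj, num_m - 1)), (batch_size, 1, 1))
        # Take values from `replace` where the mask is set and from M elsewhere;
        # no assignment to M is needed.
        M_replaced = tf.where(mask, replace, M)
    
        loss = tf.nn.l2_loss(M_replaced - M_gt[np.newaxis])
        optimizer = tf.train.AdamOptimizer(0.001)
        train_op = optimizer.minimize(loss)
        # Reuse the session opened by the `with` statement.
        init = tf.global_variables_initializer()
        sess.run(init)
        M_val, M_replaced_val = sess.run([M, M_replaced])
        print('M:')
        print(M_val)
        print('M_replaced:')
        print(M_replaced_val)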
    

    Output:

    M:
    [[[ 1.  0.  0.  0.]
      [ 0.  1.  0.  0.]
      [ 0.  0.  1.  0.]
      [ 0.  0.  0.  1.]]]
    M_replaced:
    [[[ 1.  0.  0.  1.]
      [ 0.  1.  0.  2.]
      [ 0.  0.  1.  3.]
      [ 0.  0.  0.  1.]]]
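    
    To check the indexing without a TensorFlow 1.x setup, the same mask-and-select idea can be sketched in NumPy, with `np.where` standing in for `tf.where` and `np.concatenate` for `tf.concat`. This is an illustration under assumptions, not part of the original answer: `replace_last_column` is a hypothetical helper name, and unlike the answer (which tiles a single row of `params` across the batch) it expects one row of `params` per batch element, i.e. shape `[B, P]` for matrices of shape `[B, N, N]`.

```python
import numpy as np


def replace_last_column(m, params):
    """Hypothetical helper: return a new array equal to `m` (shape [B, N, N])
    except that the first P rows of the last column hold `params` (shape [B, P]).
    `m` itself is not modified."""
    batch_size, num_m, _ = m.shape
    num_params = params.shape[1]

    # Last column: params on top, zeros below -> shape [B, N, 1].
    last_column = np.concatenate(
        [params[:, :, np.newaxis],
         np.zeros((batch_size, num_m - num_params, 1), dtype=m.dtype)],
        axis=1)
    # Candidate matrix with zeros in the first N-1 columns -> shape [B, N, N].
    replace = np.concatenate(
        [np.zeros((batch_size, num_m, num_m - 1), dtype=m.dtype), last_column],
        axis=2)

    # Mask is True only where row < P and column == N-1; its shape [1, N, N]
    # broadcasts over the batch dimension.
    r = np.arange(num_m)
    ii = r[np.newaxis, :, np.newaxis]  # row indices
    jj = r[np.newaxis, np.newaxis, :]  # column indices
    mask = (ii < num_params) & (jj == num_m - 1)

    # Take values from `replace` where the mask is set, from `m` elsewhere.
    return np.where(mask, replace, m)


m = np.eye(4, dtype=np.float32)[np.newaxis]  # batch of one 4x4 identity
params = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
print(replace_last_column(m, params))  # same values as M_replaced above
```

    Because the result is assembled with concatenation and a select rather than an in-place write, the TensorFlow version of this pattern lets gradients flow back into `params` whenever they are produced by the model; the NumPy sketch only demonstrates the indexing.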
    