TensorFlow: Trainable Variable Masking

伪装坚强ぢ 2021-01-15 04:24

I am working on a convolutional neural net that requires some parts of the kernel weights to be untrainable. tf.nn.conv2d(x, W) takes in a trainable variable W as weights. How can I keep only part of W trainable?

1 Answer
  •  被撕碎了的回忆
    2021-01-15 05:10

    Maybe you could keep your trainable weights W1, a boolean mask M indicating which kernel entries are trainable, and a constant (untrainable) weight matrix W2, and build the kernel as

    W = tf.multiply(W1, tf.cast(M, dtype=W1.dtype)) + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype)) 
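
A minimal TensorFlow 2 sketch of this approach, assuming eager execution. The kernel shape, the "centre tap only" mask and the all-ones W2 are made up for illustration. It builds the masked kernel, runs tf.nn.conv2d, takes one plain SGD step on W1, and checks that the frozen entries of the kernel did not change.

```python
import numpy as np
import tensorflow as tf

# Illustrative shapes (assumption): 3x3 kernel, 1 input channel, 2 output channels.
kernel_shape = (3, 3, 1, 2)

# Boolean mask M: True where a kernel entry is trainable.
# For illustration only the centre tap of each filter is trainable.
mask_np = np.zeros(kernel_shape, dtype=bool)
mask_np[1, 1, :, :] = True
M = tf.constant(mask_np)

# W1 holds the trainable values; W2 is a constant holding the frozen values.
W1 = tf.Variable(tf.random.normal(kernel_shape), trainable=True)
W2 = tf.constant(np.ones(kernel_shape, dtype=np.float32))

def masked_kernel():
    # Same combination as in the answer: W1 where M is True, W2 elsewhere.
    return (tf.multiply(W1, tf.cast(M, dtype=W1.dtype))
            + tf.multiply(W2, tf.cast(tf.logical_not(M), dtype=W2.dtype)))

x = tf.random.normal((1, 5, 5, 1))  # NHWC input batch

with tf.GradientTape() as tape:
    y = tf.nn.conv2d(x, masked_kernel(), strides=1, padding="SAME")
    loss = tf.reduce_sum(y)

# The gradient w.r.t. W1 is zero wherever M is False.
grad = tape.gradient(loss, W1)

# One plain SGD step on W1 (learning rate 0.1 is arbitrary).
before = masked_kernel().numpy()
W1.assign_sub(0.1 * grad)
after = masked_kernel().numpy()

# Frozen positions still come from W2; only the masked-in entries can move.
print(np.array_equal(before[~mask_np], after[~mask_np]))  # prints True
```

Because the masked-out entries of W1 are multiplied by zero, they never reach the kernel. So even if an optimizer with momentum or weight decay nudges those entries of W1, the frozen part of the kernel passed to tf.nn.conv2d always equals W2.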
    
