TensorFlow: How can I assign pre-trained NumPy weights to subsections of a graph?

孤街浪徒 2021-02-20 02:00

This is a simple thing which I just couldn't figure out how to do.

I converted a pre-trained VGG Caffe model to TensorFlow using the GitHub code from https://github.com

2 Answers
  • 2021-02-20 02:26

    I suggest taking a detailed look at network.py in https://github.com/ethereon/caffe-tensorflow, especially the load() function. It will help you understand what happens when you call net.load(weight_path, session).
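
As a rough illustration of what a loader of this kind does, here is a minimal sketch. It is not the repo's actual code. It assumes TF1-style graph mode (through tensorflow.compat.v1), a nested dict of the form {layer: {param: array}}, and variables created under a variable scope named after each layer. The function name load_weights, the sample dict, and the "fc8" entry are hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API

tf.disable_eager_execution()


def load_weights(data_dict, session, ignore_missing=False):
    """Assign data_dict[layer][param] to the existing variable 'layer/param'."""
    for layer_name, params in data_dict.items():
        # reuse=True looks up existing variables instead of creating new ones.
        with tf.variable_scope(layer_name, reuse=True):
            for param_name, data in params.items():
                try:
                    var = tf.get_variable(param_name)
                    session.run(var.assign(data))
                except ValueError:
                    # Raised when the variable is missing (or the shapes clash).
                    if not ignore_missing:
                        raise


# Hypothetical converted weights; "fc8" has no counterpart in the graph below.
data_dict = {
    "conv1_1": {
        "weights": np.full((3, 3, 3, 64), 0.5, dtype=np.float32),
        "biases": np.zeros(64, dtype=np.float32),
    },
    "fc8": {"weights": np.ones((10, 10), dtype=np.float32)},
}

graph = tf.Graph()
with graph.as_default():
    # Only the subsection of the network that should receive weights.
    with tf.variable_scope("conv1_1"):
        weights = tf.get_variable("weights", shape=[3, 3, 3, 64])
        biases = tf.get_variable("biases", shape=[64])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(data_dict, sess, ignore_missing=True)
        loaded_weights, loaded_biases = sess.run([weights, biases])
```

With ignore_missing=True, layers that exist in the dict but not in the graph are skipped, which is one way to fill only part of a graph with pre-trained values.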

    FYI, a TensorFlow variable can be assigned the value of a NumPy array with var.assign(np_array), which must be run in a session. Here is the solution to your question:

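    with tf.Session() as sess:
      # weight_variable, bias_variable and conv2d are helpers defined elsewhere (not shown here).
      W_conv1_b = weight_variable([3,3,3,64])
      # Overwrite the new variable with the converted conv1_1 weights.
      sess.run(W_conv1_b.assign(net.layers['conv1_1'].weights))
      b_conv1_b = bias_variable([64])
      # Same for the biases.
      sess.run(b_conv1_b.assign(net.layers['conv1_1'].biases))
      # Build the layer on top of the pre-trained variables.
      h_conv1_b = tf.nn.relu(conv2d(im_batch, W_conv1_b) + b_conv1_b)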
    

    Please keep the following points in mind:

    1. var.assign(data), where 'data' is a NumPy array and 'var' is a TensorFlow variable, must be run in the same session in which you go on to run the network, whether for inference or training.
    2. By default, 'var' must have the same shape as 'data'. If you can obtain 'data' before creating 'var', I suggest creating 'var' with a matching shape, e.g. var = tf.Variable(tf.zeros(data.shape)); note that tf.Variable needs an initial value, not just a shape. Otherwise, create 'var' with validate_shape=False, e.g. var = tf.Variable(initial_value, validate_shape=False), which leaves the variable's shape flexible. Detailed explanations can be found in the TensorFlow API docs.
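
The two points above can be sketched as follows. This is a minimal example under stated assumptions: TF1-style graph mode through tensorflow.compat.v1, and a random array standing in for the converted weights (the names pretrained and W_conv1 are hypothetical).

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API

tf.disable_eager_execution()

# Hypothetical stand-in for a converted weight array such as conv1_1's kernel.
pretrained = np.random.rand(3, 3, 3, 64).astype(np.float32)

graph = tf.Graph()
with graph.as_default():
    # The data is available first, so the variable is built with its shape.
    w = tf.Variable(tf.zeros(pretrained.shape), name="W_conv1")
    assign_w = w.assign(pretrained)

    # An array of a different shape is rejected when the assign op is built.
    try:
        w.assign(np.zeros((5, 5, 3, 64), dtype=np.float32))
        mismatch_rejected = False
    except ValueError:
        mismatch_rejected = True

# Assign and read back in the same session that would then run the network.
with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(assign_w)
    loaded = sess.run(w)
```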

    I extended the same repo, caffe-tensorflow, to support Theano, so that I can load models converted from Caffe in Theano. I am therefore reasonably familiar with this repo's code. Feel free to contact me if you have any further questions.

  • 2021-02-20 02:34

    You can read the variable values from the first network with the eval method of tf.Variable, then load those values into the variables of the second network with the load method (also a method of tf.Variable).
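
A minimal sketch of that eval/load round trip, assuming TF1-style graph mode through tensorflow.compat.v1 and two toy single-variable "networks" (the names src and dst are hypothetical):

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph/session API

tf.disable_eager_execution()

# First network: holds the values to copy.
graph_a = tf.Graph()
with graph_a.as_default():
    src = tf.Variable(tf.ones([2, 3]), name="src")
    init_a = tf.global_variables_initializer()

with tf.Session(graph=graph_a) as sess_a:
    sess_a.run(init_a)
    values = src.eval(session=sess_a)  # NumPy array with the current value

# Second network: receives the values.
graph_b = tf.Graph()
with graph_b.as_default():
    dst = tf.Variable(tf.zeros([2, 3]), name="dst")
    init_b = tf.global_variables_initializer()

with tf.Session(graph=graph_b) as sess_b:
    sess_b.run(init_b)
    dst.load(values, session=sess_b)  # writes the array into dst
    copied = sess_b.run(dst)
```

Later TensorFlow releases mark load as deprecated in favour of assign, so on a recent version sess_b.run(dst.assign(values)) is the more future-proof form.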
