Retraining a CNN without a high-level API

Submitted by 家住魔仙堡 on 2019-12-11 17:55:35

Question


Summary: I am trying to retrain a simple CNN for MNIST without using a high-level API. I have already succeeded by retraining the entire network; my current goal is to retrain only the last one or two fully connected layers.

Work so far: Say I have a CNN with the following structure

  • Convolutional Layer
  • ReLU
  • Pooling Layer
  • Convolutional Layer
  • ReLU
  • Pooling Layer
  • Fully Connected Layer
  • ReLU
  • Dropout Layer
  • Fully Connected Layer to 10 output classes

My goal is to retrain either the last Fully Connected Layer or the last two Fully Connected Layers.

An example of a Convolutional layer:

W_conv1 = tf.get_variable("W", [5, 5, 1, 32],
      initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0 / 784)))
b_conv1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[32]))
z = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
z += b_conv1            # add the bias once
h_conv1 = tf.nn.relu(z)

An example of a Fully Connected Layer:

input_size = 7 * 7 * 64
W_fc1 = tf.get_variable("W", [input_size, 1024], initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0/input_size)))
b_fc1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[1024]))
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

My assumption: When performing backpropagation on the new dataset, I simply make sure that the weights W and biases b (from W*x + b) stay fixed in all layers except the fully connected ones.

A first thought on how to do this: save W and b, perform a backpropagation step, and then overwrite the new W and b with the old values in the layers I don't want changed (sketched below).
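For concreteness, a minimal sketch of this "save, step, restore" idea. The names frozen_vars (the convolutional-layer variables), train_step (an optimizer update over all layers), sess, and the placeholders/batches are assumptions for illustration and do not appear in the code above:

old_values = sess.run(frozen_vars)                          # snapshot the frozen weights
sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})   # backprop through every layer
for var, old in zip(frozen_vars, old_values):
    var.load(old, session=sess)                             # undo the update on frozen layers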

My thoughts on this first approach:

  • This is computationally intensive and wastes memory. The whole advantage of retraining only the last layer(s) is not having to touch the others.
  • Backpropagation might behave differently if it is not applied to all layers?

My question:

  • How do I properly retrain particular layers of a neural network without using high-level APIs? Both conceptual and coding answers are welcome.

P.S. I am fully aware of how this can be done with high-level APIs (example: https://towardsdatascience.com/how-to-train-your-model-dramatically-faster-9ad063f0f718). I just don't want neural networks to be magic; I want to know what actually happens.


Answer 1:


The minimize method of TensorFlow optimizers takes an optional var_list argument for choosing which variables to train, e.g.:

optimizer_step = tf.train.MomentumOptimizer(learning_rate, momentum, name='MomentumOptimizer').minimize(loss, var_list=training_variables)

You can get the variables for the layers you want to train by calling tf.trainable_variables() before and after creating those layers:

vars1 = tf.trainable_variables()  # all trainable variables created so far

# FC Layer
input_size = 7 * 7 * 64
W_fc1 = tf.get_variable("W", [input_size, 1024], initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2.0/input_size)))
b_fc1 = tf.get_variable("b", initializer=tf.constant(0.1, shape=[1024]))
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

vars2 = tf.trainable_variables()  # now also contains W_fc1 and b_fc1

training_variables = list(set(vars2) - set(vars1))  # only the new layer's variables
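Putting the two pieces together, a hedged usage sketch; loss, learning_rate, momentum, the placeholders x and y_, and the batches iterator over the new dataset are assumed to exist and are not part of the original answer:

# Train only the newly selected variables; all other layers stay frozen.
optimizer_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(
    loss, var_list=training_variables)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # In practice you would restore the pretrained weights here,
    # e.g. with tf.train.Saver, instead of keeping the fresh initialization.
    for batch_x, batch_y in batches:  # assumed iterator over the new dataset
        sess.run(optimizer_step, feed_dict={x: batch_x, y_: batch_y})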

Edit: actually, using tf.trainable_variables is probably overkill in this case, since you have W_fc1 and b_fc1 directly. This would be useful for example if you had used tf.layers.dense to create a dense layer, where you would not have the variables explicitly.
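For instance, a hedged sketch of that case; the layer/scope name 'fc1' is made up for illustration:

# Dense layer whose variables live under the variable scope 'fc1'
h_fc1 = tf.layers.dense(h_pool2_flat, 1024, activation=tf.nn.relu, name='fc1')

# Collect only the trainable variables created inside that scope
training_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc1')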



Source: https://stackoverflow.com/questions/54303730/retraining-a-cnn-without-a-high-level-api
