Why does my model work with `tf.GradientTape()` but fail when using `keras.models.Model.fit()`

白昼怎懂夜的黑 提交于 2020-03-23 12:03:53

问题


After much effort, I managed to build a tensorflow 2 implementation of an existing pytorch style-transfer project. Then I wanted to get all the nice extra features that are available through Keras standard learning, e.g. model.fit().

But the same model fails when learning through model.fit(). The model seems to learn the content features, but is unable to learn the style features. This is the code for the model in question:

def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
  """Build a frozen VGG19 feature extractor for the requested layers.

  see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19

  Args:
    content_layers: names of the VGG19 layers whose outputs are used as content features
    style_layers: names of the VGG19 layers whose outputs are used as style features
    input_shape: input shape handed to `tf.keras.applications.VGG19`

  Returns:
    function(x, preprocess=True):
      Args:
        x: image tensor, hwc RGB, domain=(0.,255.)
           NOTE(review): the model is called on `x` directly, so a leading
           batch axis is presumably expected -- confirm against callers
        preprocess: when True, `vgg19.preprocess_input` is applied to `x` first
      Returns:
        a tuple ([content_features], [style_features])

  usage:
    (content_features, style_features) = vgg_layers19(content_layers, style_layers)(x_train)
  """
  vgg_preprocess = tf.keras.applications.vgg19.preprocess_input

  # ImageNet VGG19 without the classifier head; weights are frozen.
  vgg19 = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
  vgg19.trainable = False

  # Content outputs come first, style outputs follow; _get_features relies on this order.
  content_outputs = [vgg19.get_layer(layer_name).output for layer_name in content_layers]
  style_outputs = [vgg19.get_layer(layer_name).output for layer_name in style_layers]

  model = Model(
    inputs=vgg19.input,
    outputs=content_outputs + style_outputs,
    name="vgg_layers",
  )
  model.trainable = False

  def _get_features(x, preprocess=True):
    """Return (content_features, style_features) for `x`.

    Args:
      x: expecting tensor, domain=255. hwcRGB
      preprocess: apply the VGG19 preprocessing function before the forward pass
    """
    if preprocess and callable(vgg_preprocess):
      x = vgg_preprocess(x)
    features = model(x)  # invoked like a tf.keras Layer
    split_at = len(content_layers)
    content_part = features[:split_at]
    style_part = features[split_at:]
    return (content_part, style_part)

  return _get_features



class VGG_Features():
  """Get content features and style gram matrices from a VGG feature model.

  `loss_model` is called as `loss_model(batch, preprocess=True)` and must
  return a `(content_features, style_features)` pair.
  """

  def __init__(self, loss_model, style_image=None, target_style_gram=None, batch_size=None):
    """
    Args:
      loss_model: callable returning (content_features, style_features)
      style_image: optional style image of shape (256,256,3); when given, the
        target style gram matrices are computed from it
      target_style_gram: optional precomputed style gram matrices; when given,
        they take precedence over the ones derived from `style_image`
      batch_size: number of copies of `style_image` stacked into the style
        batch; None falls back to the module-level `_batch_size`
    """
    self.loss_model = loss_model
    self.batch_size = batch_size
    # Always define the attribute so readers never hit an AttributeError when
    # neither `style_image` nor `target_style_gram` was supplied.
    self.target_style_gram = None
    if style_image is not None:
      assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
      self.style_image = style_image
      self.target_style_gram = VGG_Features.get_style_gram(
        self.loss_model, self.style_image, batch_size=batch_size)
    if target_style_gram is not None:
      self.target_style_gram = target_style_gram

  @staticmethod
  def get_style_gram(vgg_features_model, style_image, batch_size=None):
    """Compute the list of style gram matrices for `style_image`.

    Args:
      vgg_features_model: callable returning (content_features, style_features)
      style_image: single style image; a batch axis is added and the image is
        repeated `batch_size` times along it
      batch_size: number of repeats; None keeps the previous behaviour of
        reading the module-level `_batch_size`

    Returns:
      list with one `fnstf_utils.gram(...)` result per style feature
    """
    if batch_size is None:
      batch_size = _batch_size  # module-level default, defined outside this class
    style_batch = tf.repeat(style_image[tf.newaxis,...], repeats=batch_size, axis=0)

    (_, style_features) = vgg_features_model(style_batch, preprocess=True)  # hwcRGB
    return [fnstf_utils.gram(value) for value in style_features]

  def __call__(self, input_batch):
    """Return `(content_feature, style_gram_1, ..., style_gram_n)` for `input_batch`.

    Only the first content feature is kept; every style feature is converted
    with `fnstf_utils.gram`.
    """
    content_features, style_features = self.loss_model(input_batch, preprocess=True)
    style_gram = tuple(fnstf_utils.gram(value) for value in style_features)
    return (content_features[0],) + style_gram  # tuple = tuple + tuple




class TransformerNetwork_VGG(tf.keras.Model):
  """Transformer network followed by VGG feature extraction.

  `call()` runs the transformer on the input batch and returns the VGG
  content feature and style gram matrices of the generated image.
  """

  def __init__(self, transformer=transformer, vgg_features=vgg_features):
    """
    Args:
      transformer: image transformation model
        (author's note: tf.keras.models.Model, input/output shape (None, 256,256,3))
      vgg_features: NOTE(review): accepted but not used -- a new VGG_Features
        instance is always built below; confirm whether it should be honoured.
    """
    super(TransformerNetwork_VGG, self).__init__()
    self.transformer = transformer
    # type: tf.keras.models.Model
    # input_shapes:  (None, 256,256,3)
    # output_shapes: (None, 256,256,3)

    style_model = {
       'content_layers':['block5_conv2'],
       'style_layers': ['block1_conv1',
                  'block2_conv1',
                  'block3_conv1',
                  'block4_conv1',
                  'block5_conv1']
    }
    vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )

    # Exposed as `self.vgg`: `call()` and external training code read that
    # attribute. `batch_size=` is no longer forwarded because
    # VGG_Features.__init__ is not guaranteed to accept it; the style batch
    # size then comes from VGG_Features' own default.
    # NOTE(review): `style_image` is read from the enclosing module scope.
    self.vgg = VGG_Features(vgg_model, style_image=style_image)
    # Kept as an alias for code that used the previous attribute name.
    self.vgg_features = self.vgg

    # input_shapes:  (None, 256,256,3)
    # output_shapes: [(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    #                [ content_loss,        style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]

  def call(self, inputs):
    """Return the VGG content feature and style grams of the transformed input.

    Args:
      inputs: image batch (author's note: shape=(None, 256,256,3))

    Returns:
      tuple(content1, style1, style2, style3, style4, style5)
    """
    # shape=(None, 256,256,3)
    generated_image = self.transformer(inputs)

    # shape=[(None, 16, 16, 512),  (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
    vgg_feature_losses = self.vgg(generated_image)

    return vgg_feature_losses

Style Image

# Equal weighting: six entries, one per model output listed above
# (one content feature followed by five style grams).
FEATURE_WEIGHTS = [1.0] * 6

GradientTape learning

With the tf.GradientTape() loop, I'm manually handling the multiple outputs, e.g. tuple of 6 tensors, from TransformerNetwork_VGG(x_train). This method learns correctly.

  @tf.function()
  def train_step(x_train, y_true, loss_weights=None, log_freq=10):
    """Run one optimisation step of the transformer with a manual GradientTape.

    Args:
      x_train: batch of content images fed to the model
      y_true: NOTE(review): not used -- targets are recomputed from `x_train`
      loss_weights: NOTE(review): not used -- the losses read the outer-scope
        name `weights`; confirm which one is intended
      log_freq: NOTE(review): not used in this step

    NOTE(review): `TransformerNetwork_VGG` is used here as a model *instance*
    (it is called and its `.vgg` / `.transformer` attributes are read);
    `optimizer`, `weights` and `get_MEAN_mse_loss` come from the outer scope.
    """
    with tf.GradientTape() as tape:
      # Only the forward pass and the loss need to be recorded on the tape.
      y_pred = TransformerNetwork_VGG(x_train)
      generated_content_features = y_pred[:1]
      generated_style_gram = y_pred[1:]

      # Content target: VGG features of the untransformed input batch.
      target_features = TransformerNetwork_VGG.vgg(x_train)
      target_content_features = target_features[:1]
      # Style target: gram matrices precomputed from the style image.
      target_style_gram = TransformerNetwork_VGG.vgg.target_style_gram

      content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features, weights)
      style_loss = tuple(get_MEAN_mse_loss(x,y)*w for x,y,w in zip(target_style_gram, generated_style_gram, weights))

      total_loss = content_loss + tf.reduce_sum(style_loss)

    # Gradients are taken after leaving the tape context so the backward pass
    # itself is not recorded. Only the transformer weights are updated.
    TransformerNetwork = TransformerNetwork_VGG.transformer
    grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
    optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))
# GradientTape epoch=5: 
# losses:             [   6078.71         70.23  4495.13 13817.65 88217.99    48.36]

model.fit() learning

With tf.keras.models.Model.fit(), the multiple outputs, e.g. tuple of 6 tensors, are fed to the loss function individually as loss(y_pred, y_true) and then multiplied by the correct weight on reduction. This method does learn to approximate the content_image, but does not learn to minimize the style losses! I cannot figure out why.

  # Train through the built-in Keras loop instead of the manual tape loop.
  # The dataset is repeated NUM_EPOCHS times and each epoch is capped at
  # NUM_BATCHES steps via `steps_per_epoch`.
  # NOTE(review): `TransformerNetwork_VGG` is used as a model instance here;
  # the compile() call (loss functions / loss weights) is not part of this
  # snippet -- confirm how the six outputs are mapped to targets and losses.
  history = TransformerNetwork_VGG.fit(
    x=train_dataset.repeat(NUM_EPOCHS),
    epochs=NUM_EPOCHS,
    steps_per_epoch=NUM_BATCHES,
    callbacks=callbacks,
  )
# model.fit() epoch=5: 
# losses:             [  4661.08       219.95   6959.01   4897.39 209201.16     84.68]

50 epochs, with boosted style_weights, FEATURE_WEIGHTS= [ 0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79] # boost style loss x100

step=50, losses=[269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]

check mse losses * weights

I tested my model with losses and weights fixed as follows:

* FEATURE_WEIGHTS = SEQ = [1.,2.,3.,4.,5.,6.,]
* MSELoss(y_true, y_pred) == tf.ones() of equal shape

and confirmed that model.fit() is handling multiple output losses * weights correctly.

I've checked everything I can think of, but I cannot figure out how to make the model learn correctly with model.fit(). What am I missing??

The full notebook is available here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb

来源:https://stackoverflow.com/questions/60545104/why-does-my-model-work-with-tf-gradienttape-but-fail-when-using-keras-model

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!