After much effort, I managed to build a TensorFlow 2
implementation of an existing PyTorch
style-transfer project. Then I wanted to get all the nice extra features that are available through Keras standard learning, e.g. model.fit()
But the same model fails when learning through model.fit()
. The model seems to learn the content features, but is unable to learn style features. This is the diagram of the model in question:
def vgg_layers19(content_layers, style_layers, input_shape=(256,256,3)):
    """ creates a VGG model that returns output values for the given layers

    see: https://keras.io/applications/#extract-features-from-an-arbitrary-intermediate-layer-with-vgg19

    Args:
        content_layers: names of the VGG19 layers whose outputs are returned
            as content features.
        style_layers: names of the VGG19 layers whose outputs are returned
            as style features.
        input_shape: input shape handed to `tf.keras.applications.VGG19`.

    Returns:
        function(x, preprocess=True):
            x: image tuple/ndarray h,w,c(RGB), domain=(0.,255.)
            returns a tuple of lists, ([content_features], [style_features])

    Usage:
        (content_features, style_features) = vgg_layers19(content_layers, style_layers)(x_train)
    """
    preprocessingFn = tf.keras.applications.vgg19.preprocess_input
    base_model = tf.keras.applications.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
    base_model.trainable = False

    content_features = [base_model.get_layer(name).output for name in content_layers]
    style_features = [base_model.get_layer(name).output for name in style_layers]
    # Content outputs come first, so the closure below can split on len(content_layers).
    output_features = content_features + style_features

    model = Model(inputs=base_model.input, outputs=output_features, name="vgg_layers")
    model.trainable = False

    def _get_features(x, preprocess=True):
        """Return ([content_features], [style_features]) for `x`.

        x: expecting tensor, domain=255. hwcRGB
        """
        if preprocess and callable(preprocessingFn):
            x = preprocessingFn(x)
        output = model(x)  # call as tf.keras.Layer()
        # NOTE(review): with a single requested layer Keras may hand back a bare
        # tensor; wrap it so the slices below split layers, not the batch axis.
        if not isinstance(output, (list, tuple)):
            output = [output]
        return (output[:len(content_layers)], output[len(content_layers):])

    return _get_features
class VGG_Features():
    """ get content and style features from VGG model """

    def __init__(self, loss_model, style_image=None, target_style_gram=None, batch_size=None):
        """Store the feature extractor and, optionally, the target style grams.

        Args:
            loss_model: callable `loss_model(batch, preprocess=True)` returning
                `(content_features, style_features)`.
            style_image: optional style image of shape (256,256,3); when given,
                its gram matrices become `target_style_gram`.
            target_style_gram: optional precomputed grams; when given, they
                override the ones derived from `style_image`.
            batch_size: optional number of times the style image is repeated
                when computing its grams; `None` falls back to the
                module-level `_batch_size`.
        """
        self.loss_model = loss_model
        self.batch_size = batch_size
        # Always define both attributes so later reads never raise AttributeError.
        self.style_image = None
        self.target_style_gram = None
        if style_image is not None:
            assert style_image.shape == (256,256,3), "ERROR: loss_model expecting input_shape=(256,256,3), got {}".format(style_image.shape)
            self.style_image = style_image
            self.target_style_gram = VGG_Features.get_style_gram(
                self.loss_model, self.style_image, batch_size=batch_size)
        if target_style_gram is not None:
            self.target_style_gram = target_style_gram

    @staticmethod
    def get_style_gram(vgg_features_model, style_image, batch_size=None):
        """Return the list of gram matrices of `style_image`'s style features.

        The image is repeated `batch_size` times along a new leading axis
        before being passed to `vgg_features_model`.
        """
        if batch_size is None:
            # Preserve the original behaviour of reading the module-level value.
            batch_size = _batch_size
        style_batch = tf.repeat(style_image[tf.newaxis, ...], repeats=batch_size, axis=0)
        (_, style_features) = vgg_features_model(style_batch, preprocess=True)  # hwcRGB
        target_style_gram = [fnstf_utils.gram(value) for value in style_features]  # list
        return target_style_gram

    def __call__(self, input_batch):
        """Return `(content_features[0], gram_1, ..., gram_n)` for `input_batch`."""
        content_features, style_features = self.loss_model(input_batch, preprocess=True)
        style_gram = tuple(fnstf_utils.gram(value) for value in style_features)  # tuple(<generator>)
        return (content_features[0],) + style_gram  # tuple = tuple + tuple
class TransformerNetwork_VGG(tf.keras.Model):
# Keras model that runs the transformer on the input batch and returns the
# VGG-derived feature tensors of the generated image as its outputs.
def __init__(self, transformer=transformer, vgg_features=vgg_features):
# NOTE(review): both defaults are read from module-level names when the class
# is defined; `vgg_features` is not used in the lines visible here — confirm.
super(TransformerNetwork_VGG, self).__init__()
self.transformer = transformer
# type: tf.keras.models.Model
# input_shapes: (None, 256,256,3)
# output_shapes: (None, 256,256,3)
# NOTE(review): the `style_model` literal below is cut short in this excerpt;
# the following line reads both its 'content_layers' and 'style_layers' keys.
style_model = {
'style_layers': ['block1_conv1',
vgg_model = vgg_layers19( style_model['content_layers'], style_model['style_layers'] )
# NOTE(review): `style_image` and `batch_size` are not parameters of __init__;
# presumably module-level values — verify.
self.vgg_features = VGG_Features(vgg_model, style_image=style_image, batch_size=batch_size)
# input_shapes: (None, 256,256,3)
# output_shapes: [(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
# [ content_loss, style_loss_1, style_loss_2, style_loss_3, style_loss_4, style_loss_5 ]
def call(self, inputs):
# Forward pass: transformer first, then feature extraction on its output.
x = inputs # shape=(None, 256,256,3)
# shape=(None, 256,256,3)
generated_image = self.transformer(x)
# shape=[(None, 16, 16, 512), (None, 64, 64), (None, 128, 128), (None, 256, 256), (None, 512, 512), (None, 512, 512)]
# NOTE(review): __init__ stores the extractor as `self.vgg_features`, but this
# line calls `self.vgg` — confirm the attribute name.
vgg_feature_losses = self.vgg(generated_image)
return vgg_feature_losses # tuple(content1, style1, style2, style3, style4, style5)
Style Image
# One loss weight per model output; presumably ordered
# (content, style1, style2, style3, style4, style5) — confirm.
FEATURE_WEIGHTS = [1.0] * 6
With the tf.GradientTape()
loop, I'm manually handling the multiple outputs, e.g. tuple of 6 tensors, from TransformerNetwork_VGG(x_train)
. This method learns correctly.
def train_step(x_train, y_true, loss_weights=None, log_freq=10):
    """Run one manual optimisation step with `tf.GradientTape`.

    The content loss and the per-layer style losses are summed into a single
    scalar, and only the transformer's trainable weights are updated.

    Args:
        x_train: input batch fed to `TransformerNetwork_VGG`.
        y_true: overwritten below with the VGG features of `x_train`, so the
            value passed in is ignored.
        loss_weights: not read in the lines shown here.
        log_freq: not read in the lines shown here.

    NOTE(review): `weights` is not a parameter of this function; presumably a
    module-level value (or meant to be `loss_weights`) — confirm.
    NOTE(review): `TransformerNetwork_VGG` is called here like a model
    instance, not the class — confirm the binding.
    """
    with tf.GradientTape() as tape:
        y_pred = TransformerNetwork_VGG(x_train)
        generated_content_features = y_pred[:1]
        generated_style_gram = y_pred[1:]

        # Content target is the feature tensor of the untransformed input.
        y_true = TransformerNetwork_VGG.vgg(x_train)
        target_content_features = y_true[:1]
        # Style target is the gram list stored on the feature extractor.
        target_style_gram = TransformerNetwork_VGG.vgg.target_style_gram

        content_loss = get_MEAN_mse_loss(target_content_features, generated_content_features, weights)
        style_loss = tuple(get_MEAN_mse_loss(x,y)*w for x,y,w in zip(target_style_gram, generated_style_gram, weights))
        # Fixed: the original line read `content_loss + = ...`, a syntax error.
        total_loss = content_loss + tf.reduce_sum(style_loss)

    TransformerNetwork = TransformerNetwork_VGG.transformer
    grads = tape.gradient(total_loss, TransformerNetwork.trainable_weights)
    optimizer.apply_gradients(zip(grads, TransformerNetwork.trainable_weights))
# GradientTape epoch=5:
# losses: [ 6078.71 70.23 4495.13 13817.65 88217.99 48.36]
With tf.keras.models.Model.fit()
, the multiple outputs, e.g. tuple of 6 tensors, are fed to the loss function individually as loss(y_pred, y_true)
and then multiplied by the correct weight on reduction
. This method does learn to approximate the content_image, but does not learn to minimize the style losses! I cannot figure out why.
history = TransformerNetwork_VGG.fit(
# model.fit() epoch=5:
# losses: [ 4661.08 219.95 6959.01 4897.39 209201.16 84.68]
50 epochs, with boosted style_weights, FEATURE_WEIGHTS= [ 0.1854, 1605.23, 25.08, 8.16, 1.28, 2330.79] # boost style loss x100
step=50, losses=[269899.45 337.5 69617.7 38424.96 9192.36 85903.44 66423.51]
check mse
losses * weights
I tested my model with losses and weights fixed as follows
* FEATURE_WEIGHTS = SEQ = [1.,2.,3.,4.,5.,6.,]
* MSELoss(y_true, y_pred) == tf.ones() of equal shape
and confirmed that model.fit()
is handling multiple output losses * weights correctly
I've checked everything I can think of, but I cannot figure out how to make the model learn correctly with model.fit()
. What am I missing??
The full notebook is available here: https://colab.research.google.com/github/mixuala/fast_neural_style_pytorch/blob/master/notebook/%5BSO%5D_FastStyleTransfer.ipynb