问题
Trying combine GAN generator and critic to train both as VAE. Base code is here. Code modified to create encoder on top of critic:
def _build_critic(self):
#### THE critic
critic_input = Input(shape=self.input_dim, name='critic_input')
x = critic_input
for i in range(self.n_layers_critic):
x = Conv2D(
filters = self.critic_conv_filters[i]
, kernel_size = self.critic_conv_kernel_size[i]
, strides = self.critic_conv_strides[i]
, padding = 'same'
, name = 'critic_conv_' + str(i)
, kernel_initializer = self.weight_init
)(x)
if self.critic_batch_norm_momentum and i > 0:
x = BatchNormalization(momentum = self.critic_batch_norm_momentum)(x)
x = self.get_activation(self.critic_activation)(x)
if self.critic_dropout_rate:
x = Dropout(rate = self.critic_dropout_rate)(x)
x = Flatten()(x)
for i in range(self.n_dense_layers_critic):
x = Dense(self.critic_dense_layers[i] , kernel_initializer = self.weight_init)(x)
x = self.get_activation(self.critic_activation)(x)
if self.critic_droupout_dense:
x = Dropout(rate = self.critic_droupout_dense[i])(x)
critic_output = Dense(1, activation=None, kernel_initializer = self.weight_init)(x)
#building encooder on top of critic
self.mu = Dense(self.z_dim, name='mu')(x)
self.log_var = Dense(self.z_dim, name='log_var')(x)
self.encoder_mu_log_var = Model(critic_input, (self.mu, self.log_var))
def sampling(args):
mu, log_var = args
epsilon = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)
return mu + K.exp(log_var / 2) * epsilon
encoder_output = Lambda(sampling, name='encoder_output')([self.mu, self.log_var])
self.encoder = Model(critic_input, encoder_output)
self.critic = Model(critic_input, critic_output)
code added to def _build_adversarial(self): to assemble VAE:
#construct VAE
# For vae all layers are trainable
self.set_trainable(self.encoder, True)
self.set_trainable(self.generator, True)
vae_input = Input(shape=self.input_dim,)
decoder_input = self.encoder(vae_input)
vae_output = self.generator(decoder_input)
self.vae = Model(vae_input, vae_output)
### COMPILATION
def vae_r_loss(y_true, y_pred):
r_loss = K.mean(K.square(y_true - y_pred), axis = [1,2,3])
return self.r_loss_factor * r_loss
def vae_kl_loss(y_true, y_pred):
kl_loss = -0.5 * K.sum(1 + self.log_var - K.square(self.mu) - K.exp(self.log_var), axis = 1)
return kl_loss
def vae_loss(y_true, y_pred):
r_loss = vae_r_loss(y_true, y_pred)
kl_loss = vae_kl_loss(y_true, y_pred)
return r_loss + kl_loss
optimizer = Adam(lr=self.critic_learning_rate)
self.vae.compile(optimizer=optimizer, loss = vae_loss, metrics = [vae_r_loss, vae_kl_loss])
Compiled model:
Model: "model_21" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_11 (InputLayer) (None, 64, 64, 3) 0 _________________________________________________________________ model_16 (Model) (None, 100) 13323592 _________________________________________________________________ model_18 (Model) (None, 64, 64, 3) 5168003 ================================================================= Total params: 18,491,595 Trainable params: 18,474,315 Non-trainable params: 17,280
This code throws an error:
def train_vae(self, x_train, batch_size, epoches = 1):
imgs = next(x_train)[0]
if imgs.shape[0] != batch_size:
imgs = next(x_train)[0]
#return self.vae.train_on_batch(imgs, imgs)
return self.vae.fit(x_train, epochs = 1)
Tried both train_on_batch and simple fit, error:
File "<ipython-input-15-5c41e1111b5c>", line 8, in <module>
, using_generator = True
File "D:\dev\python\GAN\models\WGANGPMOD.py", line 410, in train
vae_loss = self.train_vae(x_train, batch_size, n_vae)
File "D:\dev\python\GAN\models\WGANGPMOD.py", line 389, in train_vae
, epochs = 1
File "D:\programs\conda\envs\ml\lib\site-packages\keras\engine\training.py", line 1147, in fit
initial_epoch=initial_epoch)
File "D:\programs\conda\envs\ml\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "D:\programs\conda\envs\ml\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
initial_epoch=initial_epoch)
File "D:\programs\conda\envs\ml\lib\site-packages\keras\engine\training_generator.py", line 220, in fit_generator
reset_metrics=False)
File "D:\programs\conda\envs\ml\lib\site-packages\keras\engine\training.py", line 1514, in train_on_batch
outputs = self.train_function(ins)
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\keras\backend.py", line 3727, in __call__
outputs = self._graph_fn(*converted_inputs)
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\eager\function.py", line 1551, in __call__
return self._call_impl(args, kwargs)
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\eager\function.py", line 1591, in _call_impl
return self._call_flat(args, self.captured_inputs, cancellation_manager)
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\eager\function.py", line 1692, in _call_flat
ctx, args, cancellation_manager=cancellation_manager))
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\eager\function.py", line 545, in call
ctx=ctx)
File "D:\programs\conda\envs\ml\lib\site-packages\tensorflow_core\python\eager\execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
InvalidArgumentError: You must feed a value for placeholder tensor 'critic_input_2' with dtype float and shape [?,64,64,3]
[[node critic_input_2 (defined at D:\programs\conda\envs\ml\lib\site-packages\keras\backend\tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_45035]
My steps to fix according after googling many sources, including stackowerflow:
- Clear name for critic_input, since it’s used in two models only slide changed error text to “You must feed a value for placeholder tensor 'input_12' with dtype float and shape [?,64,64,3]”
- Clear default graph with keras backend before combined model creation – no effect on error
- Switching off batch normalization
Have no clue how to fix it.
来源:https://stackoverflow.com/questions/63871470/error-after-combination-of-two-keras-models-into-vae-you-must-feed-a-value-for