Dimension mismatch in Keras during model.fit

萝らか妹 提交于 2020-01-10 04:38:33

问题


I put together a VAE using dense (fully connected) layers in Keras. During model.fit I get a dimension mismatch, but I'm not sure what is causing it. Below is my code:

from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K
import keras

import numpy as np
import matplotlib.pyplot as plt
import argparse
import os

# Load MNIST, flatten every (square) image into a single row of
# original_dim pixels, and rescale the 0-255 intensities to float32 in [0, 1].
(x_train, y_train), (x_test, y_test) = mnist.load_data()

image_size = x_train.shape[1]  # assumes square images (height == width)
original_dim = image_size ** 2
x_train = x_train.reshape(-1, original_dim).astype('float32') / 255
x_test = x_test.reshape(-1, original_dim).astype('float32') / 255

# network parameters
input_shape = (original_dim,)
intermediate_dim, batch_size, latent_dim, epochs = 512, 128, 2, 50


# Encoder: original_dim-wide input -> ReLU hidden layer -> the mean and
# log-sigma of a latent_dim-dimensional Gaussian.
# NOTE(review): batch_shape pins the batch dimension to batch_size (128).
# When the number of training rows is not a multiple of 128, model.fit's last
# batch is smaller (96 rows in the reported error) and the graph rejects it.
# Input(shape=input_shape) would leave the batch dimension unconstrained.
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_sigma = Dense(latent_dim)(h)

def sampling(args):
    """Draw one latent sample per row with the reparameterisation trick.

    Computes ``z = z_mean + exp(z_log_sigma) * epsilon`` with
    ``epsilon ~ N(0, 1)``, so gradients flow through ``z_mean`` and
    ``z_log_sigma``.

    Args:
        args: pair ``(z_mean, z_log_sigma)`` of tensors of shape
            (batch, latent_dim); ``z_log_sigma`` is log(sigma).

    Returns:
        A tensor shaped like ``z_mean``.
    """
    z_mean, z_log_sigma = args
    # Read the batch size from the tensor itself instead of the global
    # batch_size: a hard-coded 128 breaks on any smaller (e.g. final) batch.
    batch = K.shape(z_mean)[0]
    epsilon = K.random_normal(shape=(batch, latent_dim))
    return z_mean + K.exp(z_log_sigma) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_sigma])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_sigma])

# Decoder layers are instantiated once so the same weights are shared by the
# end-to-end `vae` model below and by the standalone `generator` model.
decoder_h = Dense(intermediate_dim, activation='relu')
# Sigmoid keeps reconstructions in [0, 1], matching the /255-scaled inputs.
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

print('X Decoded Mean shape: ', x_decoded_mean.shape)

# end-to-end autoencoder
vae = Model(x, x_decoded_mean)

# encoder, from inputs to latent space
encoder = Model(x, z_mean)

# generator, from latent space to reconstructed inputs
# (its Input uses shape=, so its batch dimension is not fixed)
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

def vae_loss(x, x_decoded_mean):
    """Per-sample negative ELBO: reconstruction loss plus KL divergence.

    Args:
        x: target batch (the original inputs), values in [0, 1].
        x_decoded_mean: decoder output for the same batch.

    Returns:
        One loss value per sample; Keras averages it over the batch.
    """
    # keras.metrics.binary_crossentropy averages over the last (pixel) axis;
    # scaling by original_dim turns it into the per-image sum, so the
    # reconstruction term is not under-weighted relative to the KL term.
    xent_loss = original_dim * keras.metrics.binary_crossentropy(x, x_decoded_mean)
    # z_log_sigma is log(sigma): sampling() scales epsilon by exp(z_log_sigma).
    # Hence log(sigma^2) = 2 * z_log_sigma and sigma^2 = exp(2 * z_log_sigma).
    # KL(N(mu, sigma^2) || N(0, 1)) is summed over latent dimensions.
    kl_loss = -0.5 * K.sum(
        1 + 2 * z_log_sigma - K.square(z_mean) - K.exp(2 * z_log_sigma),
        axis=-1)
    return xent_loss + kl_loss

# Compile with the custom VAE loss. The autoencoder reconstructs its input,
# so x_train / x_test serve as both inputs and targets below.
vae.compile(optimizer='rmsprop', loss=vae_loss)


# NOTE(review): under Python 2.7 (the interpreter in the traceback) these
# parenthesised print calls print a tuple rather than two separate values.
print('X train shape: ', x_train.shape)
print('X test shape: ', x_test.shape)

# NOTE(review): fit() slices the data into batch_size chunks; when the row
# count is not a multiple of batch_size the final chunk is smaller, and the
# fixed batch_shape on `x` makes that chunk fail with "Incompatible shapes".
vae.fit(x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test)) 

Here is the stack trace that I see when model.fit is called.

File "/home/asattar/workspace/projects/keras-examples/blogautoencoder/VariationalAutoEncoder.py", line 81, in <module>
    validation_data=(x_test, x_test))
  File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training.py", line 1047, in fit
    validation_steps=validation_steps)
  File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/engine/training_arrays.py", line 195, in fit_loop
    outs = fit_function(ins_batch)
  File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2897, in __call__
    return self._call(inputs)
  File "/usr/local/lib/python2.7/dist-packages/Keras-2.2.4-py2.7.egg/keras/backend/tensorflow_backend.py", line 2855, in _call
    fetched = self._callable_fn(*array_vals)
  File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1439, in __call__
    run_metadata_ptr)
  File "/home/asattar/.local/lib/python2.7/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [128,784] vs. [96,784]
     [[{{node training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/BroadcastGradientArgs}} = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@train...ad/Reshape"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape, training/RMSprop/gradients/loss/dense_5_loss/logistic_loss/mul_grad/Shape_1)]]

Please note the "Incompatible shapes: [128,784] vs. [96,784]" error near the end of the stack trace.


回答1:


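According to "Keras: What if the size of data is not divisible by batch_size?", it is better to use model.fit_generator rather than model.fit here. A custom generator can yield only full batches, so every batch matches the fixed batch_shape of the model's input.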

To use model.fit_generator, you need to define your own generator object. Here is an example:

from keras.utils import Sequence
import math

class Generator(Sequence):
    """Keras ``Sequence`` that yields only full ``batch_size`` batches.

    The last ``len(x_set) % batch_size`` positions of the (shuffled) index
    order are skipped each epoch, so every batch has exactly ``batch_size``
    rows. This matches a model built with a fixed ``batch_shape``. The
    order is reshuffled after every epoch; the first epoch uses the
    original order.

    Args:
        x_set: input array, indexed along axis 0.
        y_set: target array aligned with ``x_set`` along axis 0.
        batch_size: number of samples per batch.
    """

    def __init__(self, x_set, y_set, batch_size=256):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.indices = np.arange(self.x.shape[0])

    def __len__(self):
        # Floor division stays an int on both Python 2 and 3. math.floor()
        # returns a float on Python 2, and len() rejects a float __len__.
        return self.x.shape[0] // self.batch_size

    def __getitem__(self, idx):
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[inds], self.y[inds]

    def on_epoch_end(self):
        # Keras calls this hook after each epoch; reshuffle the sample order.
        np.random.shuffle(self.indices)

# Autoencoder setup: each target batch is the input batch itself.
train_datagen, test_datagen = (
    Generator(x_train, x_train, batch_size),
    Generator(x_test, x_test, batch_size),
)

# Step counts use floor division, so only whole batches are consumed.
vae.fit_generator(
    train_datagen,
    steps_per_epoch=len(x_train) // batch_size,
    validation_data=test_datagen,
    validation_steps=len(x_test) // batch_size,
    epochs=epochs)

Code adapted from How to shuffle after each epoch using a custom generator?.




回答2:


I just tried to reproduce this and found that when you define

x = Input(batch_shape=(batch_size, original_dim))  # batch dimension fixed at batch_size

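you fix the batch size at 128 (batch_size). MNIST has 60,000 training images, which is not a multiple of 128. The last training batch therefore holds only 96 samples, hence [96,784] vs. [128,784] in the error; it fails during training, not validation. Change it to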

x = Input(shape=input_shape)  # batch dimension left open

# sampling() must also stop using the global batch_size. Otherwise epsilon
# stays (128, latent_dim) and the smaller final batch fails the same way.
def sampling(args):
    z_mean, z_log_sigma = args
    batch = K.shape(z_mean)[0]  # actual size of the current batch
    epsilon = K.random_normal(shape=(batch, latent_dim))
    return z_mean + K.exp(z_log_sigma) * epsilon

and you should be all set.



来源:https://stackoverflow.com/questions/54524124/dimension-mismatch-in-keras-during-model-fit

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!