Keras train partial model issue (about GAN model)


First, I would advise you to switch to the Functional API. These kinds of mixed models are more easily handled with Functional models.

I have no idea why your solution didn't work, to be honest; it seems like when you link the D model to a new input, it gets kind of "corrupted" and stays tied to that input. The way I found around that problem is to define the layers once and use them for both the Discriminator and the GAN models. Here is the code:

import numpy as np
from keras.layers import *
import keras.models as km
import keras.optimizers as ko
from keras.datasets import mnist

batch_size = 16
lr = 0.0001

def noise_gen(batch_size, z_dim):
    noise = np.zeros((batch_size, z_dim), dtype=np.float32)
    for i in range(batch_size):
        noise[i, :] = np.random.uniform(-1, 1, z_dim)
    return noise
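# (An equivalent, more idiomatic one-liner would be:
#  np.random.uniform(-1, 1, (batch_size, z_dim)).astype(np.float32))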

# Changes the trainable attribute of every layer in `model`
# to the boolean argument "trainable"
def make_trainable(model, trainable):
    model.trainable = trainable
    for l in model.layers:
        l.trainable = trainable
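# (Note: setting model.trainable alone does not reliably reach the individual
# layers in this version of Keras, hence the explicit loop over model.layers.)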

# --------------------Generator Model--------------------

g_input = Input(shape=(100,))

g_hidden = Dense(1024, activation='relu')(g_input)
g_hidden = Dense(7*7*128, activation='relu')(g_hidden)
g_hidden = BatchNormalization()(g_hidden)
g_hidden = Reshape((7,7,128))(g_hidden)

g_hidden = Deconvolution2D(64,5,5, (None, 14, 14, 64), subsample=(2,2),
        border_mode='same', activation='relu')(g_hidden)
g_hidden = BatchNormalization()(g_hidden)
g_output = Deconvolution2D(1,5,5, (None, 28, 28, 1), subsample=(2,2),
        border_mode='same')(g_hidden)

G = km.Model(input=g_input,output=g_output)
G.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))
G.summary()
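# Optional sanity check: with the output_shape arguments given to the two
# Deconvolution2D layers, G should map (batch, 100) noise to (batch, 28, 28, 1):
# assert G.predict(noise_gen(2, 100)).shape == (2, 28, 28, 1)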

# --------------------Discriminator Model--------------------

d_input = Input(shape=(28,28,1))

d_l1 = Convolution2D(64,5,5, subsample=(2,2))
d_hidden_1 = d_l1(d_input)
d_l2 = LeakyReLU(alpha=0.2)
d_hidden_2 = d_l2(d_hidden_1)

d_l3 = Convolution2D(128,5,5, subsample=(2,2))
d_hidden_3 = d_l3(d_hidden_2)
d_l4 = BatchNormalization()
d_hidden_4 = d_l4(d_hidden_3)
d_l5 = LeakyReLU(alpha=0.2)
d_hidden_5 = d_l5(d_hidden_4)

d_l6 = Flatten()
d_hidden_6 = d_l6(d_hidden_5)
d_l7 = Dense(1, activation='sigmoid')
d_output = d_l7(d_hidden_6)

D = km.Model(input=d_input,output=d_output)
D.compile(loss='binary_crossentropy',optimizer=ko.SGD(lr=lr,momentum=0.9, nesterov=True))
D.summary()
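# Shape check: with 5x5 kernels, stride 2 and the default 'valid' border mode,
# the feature maps go (28, 28, 1) -> (12, 12, 64) -> (4, 4, 128), which is
# flattened to 2048 features before the final Dense(1) sigmoid output.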

# --------------------GAN Model--------------------
make_trainable(D,False)

gan_input = Input(shape=(100,))
gan_hidden = G(gan_input)
gan_hidden = d_l1(gan_hidden)
gan_hidden = d_l2(gan_hidden)
gan_hidden = d_l3(gan_hidden)
gan_hidden = d_l4(gan_hidden)
gan_hidden = d_l5(gan_hidden)
gan_hidden = d_l6(gan_hidden)
gan_output = d_l7(gan_hidden)

GAN = km.Model(input=gan_input,output=gan_output)
GAN.compile(loss='binary_crossentropy',optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))
GAN.summary()
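# Note: Keras reads the trainable flags at compile time, so D is frozen inside
# GAN (compiled just after make_trainable(D, False)), while D itself, compiled
# earlier while still trainable, keeps updating its weights in D.train_on_batch.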

# --------------------Main Code--------------------
(X, _), _ = mnist.load_data()
X = X / 255.
X = X[:, :, :, np.newaxis]

X_batch = X[0:batch_size, :]
Z1_batch = noise_gen(batch_size, 100)
Z2_batch = noise_gen(batch_size, 100)

print(type(X_batch),X_batch.shape)
print(type(Z1_batch),Z1_batch.shape)

fake_batch = G.predict(Z1_batch)
real_batch = X_batch
print('--------------------Fake Image Generated!--------------------')

combined_X_batch = np.concatenate((real_batch, fake_batch))
combined_y_batch = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))
print('real_batch={}, fake_batch={}'.format(real_batch.shape, fake_batch.shape))
print(type(combined_X_batch),combined_X_batch.dtype,combined_X_batch.shape)
print(type(combined_y_batch),combined_y_batch.dtype,combined_y_batch.shape)
make_trainable(D,True)
d_loss = D.train_on_batch(combined_X_batch, combined_y_batch)
print('--------------------Discriminator trained!--------------------')
print(d_loss)

make_trainable(D,False)
g_loss = GAN.train_on_batch(Z2_batch, np.ones((batch_size, 1)))
print('--------------------GAN trained!--------------------')
print(g_loss)
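
In a real run you would repeat these two train_on_batch calls in a loop. Here is a minimal sketch using the models defined above (the step count and the random batch sampling are my own choices, not part of the original code):

n_steps = 1000  # hypothetical number of training steps
for step in range(n_steps):
    # Sample a real batch and generate a fake one.
    idx = np.random.randint(0, X.shape[0], size=batch_size)
    real_batch = X[idx]
    fake_batch = G.predict(noise_gen(batch_size, 100))

    # Train D on the combined batch (real -> 1, fake -> 0).
    make_trainable(D, True)
    d_loss = D.train_on_batch(
        np.concatenate((real_batch, fake_batch)),
        np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1)))))

    # Train G through the frozen GAN, labelling the fakes as real to fool D.
    make_trainable(D, False)
    g_loss = GAN.train_on_batch(noise_gen(batch_size, 100), np.ones((batch_size, 1)))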

Does that help?

After striving with this for quite a long time, I finally figured out that it is the Discriminator's BatchNormalization layer that causes the problem (presumably because it carries internal state that does not play well with the Discriminator being nested inside the GAN model).

If you just comment out the model.add(kl.BatchNormalization()) line in the Discriminator, it will work fine.

However, as @NassimBen showed, the functional API does not have this problem.

import numpy as np
import keras.layers as kl
import keras.models as km
import keras.optimizers as ko
from keras.datasets import mnist

batch_size = 16
lr = 0.0001

def noise_gen(batch_size, z_dim):
    noise = np.zeros((batch_size, z_dim), dtype=np.float32)
    for i in range(batch_size):
        noise[i, :] = np.random.uniform(-1, 1, z_dim)
    return noise

# --------------------Generator Model--------------------

model = km.Sequential()

model.add(kl.Dense(input_dim=100, output_dim=1024))
model.add(kl.Activation('relu'))

model.add(kl.Dense(7*7*128))
model.add(kl.BatchNormalization())
model.add(kl.Activation('relu'))
model.add(kl.Reshape((7, 7, 128), input_shape=(7*7*128,)))

model.add(kl.Deconvolution2D(64, 5, 5, (None, 14, 14, 64), subsample=(2, 2),
                             input_shape=(7, 7, 128), border_mode='same'))
model.add(kl.BatchNormalization())
model.add(kl.Activation('relu'))

model.add(kl.Deconvolution2D(1, 5, 5, (None, 28, 28, 1), subsample=(2, 2),
                             input_shape=(14, 14, 64), border_mode='same'))

G = model
G.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------Discriminator Model--------------------

model = km.Sequential()

model.add(kl.Convolution2D(64, 5, 5, subsample=(2, 2), input_shape=(28, 28, 1)))
model.add(kl.LeakyReLU(alpha=0.2))

model.add(kl.Convolution2D(128, 5, 5, subsample=(2, 2)))
# model.add(kl.BatchNormalization())
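# ^ this is the culprit layer described above; with it commented out, the
#   Sequential version trains fine.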
model.add(kl.LeakyReLU(alpha=0.2))

model.add(kl.Flatten())
model.add(kl.Dense(1))
model.add(kl.Activation('sigmoid'))

D = model
D.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------GAN Model--------------------

model = km.Sequential()
model.add(G)
D.trainable = False  # Is this necessary?
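# (It seems to be: the trainable flag is read when compile() is called, so it
#  must be set before GAN.compile() below for D to be frozen inside the GAN.)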
model.add(D)
GAN = model
GAN.compile(loss='binary_crossentropy', optimizer=ko.SGD(lr=lr, momentum=0.9, nesterov=True))

# --------------------Main Code--------------------
(X, _), _ = mnist.load_data()
X = X / 255.
X = X[:, :, :, np.newaxis]

X_batch = X[0:batch_size, :]
Z1_batch = noise_gen(batch_size, 100)
Z2_batch = noise_gen(batch_size, 100)

fake_batch = G.predict(Z1_batch)
real_batch = X_batch
print('--------------------Fake Image Generated!--------------------')

combined_X_batch = np.concatenate((real_batch, fake_batch))
combined_y_batch = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))))
print('real_batch={}, fake_batch={}'.format(real_batch.shape, fake_batch.shape))

D.trainable = True
d_loss = D.train_on_batch(combined_X_batch, combined_y_batch)
print('--------------------Discriminator trained!--------------------')
print(d_loss)

D.trainable = False
g_loss = GAN.train_on_batch(Z2_batch, np.ones((batch_size, 1)))
print('--------------------GAN trained!--------------------')
print(g_loss)