I have a 1D CNN autoencoder which has a dense central layer. I would like to train this autoencoder and save its model. I would also like to save the decoder part, with this goal
You'll need to: (1) save the weights of the AE (autoencoder); (2) load the weights file; (3) deserialize the file and assign only those weights that are compatible with the new model (the decoder).
`.save` does include the weights, but with an extra deserialization step that's spared by using `.save_weights` instead. Also, `.save` stores the optimizer state and model architecture; the latter is irrelevant for your new decoder. `load_weights` by default attempts to assign *all* saved weights, which won't work here. The code below accomplishes (3) (and remedies (2)) as follows:
- Gets `file_layer_names` (list), the layer names stored in the weights file
- Gets `model_layer_names` (list), the layer names of the target model
- Iterates over `file_layer_names` as `name`; if `name` is in `model_layer_names`, appends the loaded weights with that name to `weight_values_to_load`
- Assigns `weight_values_to_load` to the model using `K.batch_set_value`
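To make the matching step concrete, here's a toy sketch of the name-filtering logic with plain lists standing in for the HDF5 file contents (the layer names below are hypothetical; no Keras or h5py needed):

```python
# Names as stored in the saved AE weights file (hypothetical example)
file_layer_names = ['conv1d_1', 'decod_conv1d_1', 'decod_dense1']
# Names of layers in the new (decoder) model
model_layer_names = ['decod_conv1d_1', 'decod_dense1']

# Only names present in both are loaded; the rest are skipped
matched = [n for n in file_layer_names if n in model_layer_names]
skipped = [n for n in file_layer_names if n not in model_layer_names]

print(matched)  # ['decod_conv1d_1', 'decod_dense1']
print(skipped)  # ['conv1d_1']
```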
Note that this requires you to name every layer in both the AE and decoder models and make the names match. It's possible to rewrite this code to brute-force assign weights sequentially in a `try`-`except` loop, but that's both inefficient and bug-prone.
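If you're unsure whether the names line up, you can inspect the `layer_names` attribute of the weights file directly with `h5py`. A minimal sketch: the demo below fabricates a tiny stand-in file (the layer names are hypothetical) so it runs without Keras, but the read side is the same for a real `save_weights` output:

```python
import h5py
import numpy as np

# Fabricate a stand-in weights file carrying the `layer_names`
# attribute that Keras writes (names here are hypothetical)
with h5py.File('demo_weights.h5', 'w') as f:
    f.attrs['layer_names'] = np.array([b'decod_conv1d_1', b'decod_upsampling1d_1'])

# Inspect which layer names the loader will see
with h5py.File('demo_weights.h5', 'r') as f:
    names = [n.decode('utf8') for n in f.attrs['layer_names']]

print(names)  # ['decod_conv1d_1', 'decod_upsampling1d_1']
```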
Usage:
## AE code omitted; use code as in question, but name all DECODER layers as below
autoencoder.save_weights('autoencoder_weights.h5')
## DECODER (independent)
decoder_input = Input(batch_shape=K.int_shape(x))
y = Conv1D(32, 3, activation='tanh', padding='valid', name='decod_conv1d_1')(decoder_input)
y = UpSampling1D(2, name='decod_upsampling1d_1')(y)
y = Conv1D(256, 3, activation='tanh', padding='valid', name='decod_conv1d_2')(y)
y = UpSampling1D(2, name='decod_upsampling1d_2')(y)
y = Flatten(name='decod_flatten')(y)
y = Dense(501, name='decod_dense1')(y)
decoded = Reshape((501,1), name='decod_reshape')(y)
decoder = Model(decoder_input, decoded)
decoder.save_weights('decoder_weights.h5')
load_weights(decoder, 'autoencoder_weights.h5')
Function:
import h5py
import keras.backend as K

def load_weights(model, filepath):
    with h5py.File(filepath, mode='r') as f:
        # Layer names as stored in the weights file vs. in the target model
        file_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
        model_layer_names = [layer.name for layer in model.layers]

        weight_values_to_load = []
        for name in file_layer_names:
            if name not in model_layer_names:
                print(name, "is ignored; skipping")
                continue
            g = f[name]
            weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
            weight_values = []
            if len(weight_names) != 0:
                # [:] materializes each dataset as an array before the file closes
                weight_values = [g[weight_name][:] for weight_name in weight_names]
            try:
                layer = model.get_layer(name=name)
            except ValueError:
                layer = None
            if layer is not None:
                symbolic_weights = (layer.trainable_weights +
                                    layer.non_trainable_weights)
                if len(symbolic_weights) != len(weight_values):
                    print('Model & file weights shapes mismatch')
                else:
                    weight_values_to_load += zip(symbolic_weights, weight_values)

        # Assign all matched weights in a single backend call
        K.batch_set_value(weight_values_to_load)