How to load trained autoencoder weights for decoder?

迷失自我 2021-01-24 10:44

I have a 1D CNN autoencoder with a dense central layer. I would like to train this autoencoder and save its model. I would also like to save the decoder part, with this goal: load the trained decoder later and use it independently of the encoder.

1 Answer
  • 2021-01-24 11:07

    You'll need to: (1) save the weights of the AE (autoencoder); (2) load the weights file; (3) deserialize the file and assign only those weights that are compatible with the new model (the decoder).

    • (1): .save does include the weights, but adds an extra deserialization step that you can skip by using .save_weights instead. Also, .save stores the optimizer state and the model architecture, the latter of which is irrelevant for your new decoder
    • (2): load_weights by default attempts to assign all saved weights to the model, which won't work here because the decoder contains only a subset of the AE's layers (see the sketch after this list)
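
    As an aside, Keras' load_weights also accepts a by_name flag that performs a similar name-based partial assignment for HDF5 weight files; a minimal sketch, assuming the decoder's layer names match those saved with the AE:

    # Built-in alternative: name-based loading (HDF5 weight files only).
    # Silently skips layers whose names aren't found in the model.
    decoder.load_weights('autoencoder_weights.h5', by_name=True)

    The manual function below does the same matching but logs what it skips, which makes debugging name mismatches easier.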

    The code below accomplishes (3) (and remedies (2)) as follows:

    1. Load all weights from the saved file
    2. Retrieve the layer names stored in the file and keep them in file_layer_names (list)
    3. Retrieve the current model's layer names and keep them in model_layer_names (list)
    4. Iterate over file_layer_names as name; if name is in model_layer_names, pair that layer's symbolic weights with the loaded values and append the pairs to weight_values_to_load
    5. Assign the weights in weight_values_to_load to the model via K.batch_set_value

    Note that this requires you to name every layer in both AE and decoder models and make them match. It's possible to rewrite this code to brute-force assign sequentially in a try-except loop, but that's both inefficient and bug-prone.
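
    For instance, inside the AE model the decoder half would reuse the exact names given to the standalone decoder; a minimal sketch of the naming ('encoded' is a hypothetical name for the AE's bottleneck tensor):

    ## Inside the AE model: names must match the standalone decoder below
    x = Conv1D(32, 3, activation='tanh', padding='valid', name='decod_conv1d_1')(encoded)
    x = UpSampling1D(2, name='decod_upsampling1d_1')(x)
    # ... and so on, one matching name per decoder layer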


    Usage:

    from keras.layers import Input, Conv1D, UpSampling1D, Flatten, Dense, Reshape
    from keras.models import Model
    import keras.backend as K
    
    ## AE (encoder + decoder) code omitted; use the code from the question,
    ## but name all DECODER layers as below
    autoencoder.save_weights('autoencoder_weights.h5')
    
    ## DECODER (independent)
    decoder_input = Input(batch_shape=K.int_shape(x))  # x = AE's central (bottleneck) tensor
    y = Conv1D(32, 3, activation='tanh', padding='valid', name='decod_conv1d_1')(decoder_input)
    y = UpSampling1D(2, name='decod_upsampling1d_1')(y)
    y = Conv1D(256, 3, activation='tanh', padding='valid', name='decod_conv1d_2')(y)
    y = UpSampling1D(2, name='decod_upsampling1d_2')(y)
    y = Flatten(name='decod_flatten')(y)
    y = Dense(501, name='decod_dense1')(y)
    decoded = Reshape((501, 1), name='decod_reshape')(y)
    
    decoder = Model(decoder_input, decoded)
    decoder.save_weights('decoder_weights.h5')
    
    load_weights(decoder, 'autoencoder_weights.h5')  # function defined below
    

    Function:

    import h5py
    import numpy as np
    import keras.backend as K
    
    def load_weights(model, filepath):
        with h5py.File(filepath, mode='r') as f:
            # Layer names stored in the file vs. layer names in the target model
            file_layer_names = [n.decode('utf8') for n in f.attrs['layer_names']]
            model_layer_names = [layer.name for layer in model.layers]
    
            weight_values_to_load = []
            for name in file_layer_names:
                if name not in model_layer_names:
                    print(name, "is not in the target model; skipping")
                    continue
                g = f[name]
                weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
    
                # Materialize the HDF5 datasets as numpy arrays
                weight_values = [np.asarray(g[weight_name])
                                 for weight_name in weight_names]
                try:
                    layer = model.get_layer(name=name)
                except ValueError:
                    layer = None
                if layer is not None:
                    symbolic_weights = (layer.trainable_weights +
                                        layer.non_trainable_weights)
                    if len(symbolic_weights) != len(weight_values):
                        print(name, "weight count mismatch between model"
                              " and file; skipping")
                    else:
                        weight_values_to_load += zip(symbolic_weights,
                                                     weight_values)
    
            # Assign all collected (variable, value) pairs in one batch
            K.batch_set_value(weight_values_to_load)
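
    To sanity-check the transfer, compare a shared layer's weights across the two models after running the function; a minimal check, assuming the autoencoder and decoder models from the usage snippet above:

    import numpy as np
    
    # After load_weights(decoder, 'autoencoder_weights.h5'), the kernel of
    # any same-named layer should be identical in both models
    w_ae  = autoencoder.get_layer('decod_conv1d_1').get_weights()[0]
    w_dec = decoder.get_layer('decod_conv1d_1').get_weights()[0]
    assert np.allclose(w_ae, w_dec)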
    