How to replace (or insert) intermediate layer in Keras model?

温柔的废话 2021-01-30 14:12

I have a trained Keras model and I would like:

1) to replace the Conv2D layer with the same layer but without bias.

2) to add a BatchNormalization layer before the first Activation.

4 Answers
  •  攒了一身酷
    2021-01-30 14:19

    This was how I did it:

    import keras
    from keras.models import Model
    
    def make_list(X):
        if isinstance(X, list):
            return X
        return [X]
    
    def list_no_list(X):
        if len(X) == 1:
            return X[0]
        return X
    
    def replace_layer(model, replace_layer_subname, replacement_fn, **kwargs):
        """
        args:
            model :: keras.models.Model instance
            replace_layer_subname :: str -- if str in layer name, replace it
            replacement_fn :: fn to call to replace all instances
                > fn must return a layer (or callable) that accepts the replaced layer's input and produces output of a compatible shape
        returns:
            new model with replaced layers
        quick examples:
            want to just remove all layers with 'batch_norm' in the name:
                > new_model = replace_layer(model, 'batch_norm', lambda **kwargs : (lambda u:u))
            want to replace all Conv1D(N, m, padding='same') with an LSTM (let's say all have 'conv1d' in the name)
                > new_model = replace_layer(model, 'conv1d', lambda old_layer, **kwargs: LSTM(units=old_layer.filters, return_sequences=True))
        """
        model_inputs = []
        model_outputs = []
        tsr_dict = {}
    
        model_output_names = [out.name for out in make_list(model.output)]
    
        for i, layer in enumerate(model.layers):
            ### Loop if layer is used multiple times
            for j in range(len(layer._inbound_nodes)):
    
                ### check layer inp/outp
                inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
                outp_names = [out.name for out in make_list(layer.get_output_at(j))]
    
                ### setup model inputs
                if 'input' in layer.name:
                    for inpt_tsr in make_list(layer.get_output_at(j)):
                        model_inputs.append(inpt_tsr)
                        tsr_dict[inpt_tsr.name] = inpt_tsr
                    continue
    
                ### setup layer inputs
                inpt = list_no_list([tsr_dict[name] for name in inpt_names])
    
                ### remake layer 
                if replace_layer_subname in layer.name:
                    print('replacing '+layer.name)
                    x = replacement_fn(old_layer=layer, **kwargs)(inpt)
                else:
                    x = layer(inpt)
    
                ### store the new output tensors in the dict
                for name, out_tsr in zip(outp_names, make_list(x)):
    
                    ### check if is an output
                    if name in model_output_names:
                        model_outputs.append(out_tsr)
                    tsr_dict[name] = out_tsr
    
        return Model(model_inputs, model_outputs)
    

    I have a custom layer (taken from someone online) called BatchNormalizationFreeze, so an example of usage is this:

     new_model = replace_layer(model, 'batch_normal', lambda **kwargs: BatchNormalizationFreeze())
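
    For the first part of the original question (swapping a Conv2D for the same layer without bias), a replacement function along these lines should work. This is an untested sketch: conv_no_bias is just an illustrative name and only the most common hyper-parameters are copied over.

    from keras.layers import Conv2D

    def conv_no_bias(old_layer, **kwargs):
        # sketch: rebuild the layer with the old hyper-parameters (only the
        # common ones are copied here) but with use_bias=False
        return Conv2D(filters=old_layer.filters,
                      kernel_size=old_layer.kernel_size,
                      strides=old_layer.strides,
                      padding=old_layer.padding,
                      activation=old_layer.activation,
                      use_bias=False,
                      name=old_layer.name + '_nobias')

    new_model = replace_layer(model, 'conv2d', conv_no_bias)

    Note that the replacement Conv2D starts with freshly initialised weights; if you want to keep the trained kernel, copy it across with set_weights afterwards.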
    

    If you're going to replace multiple layers at once, just swap the replacement function for a small pseudo-model that applies them all in one go, as in the sketch below.
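
    For instance, a sketch of such a pseudo-model that covers both parts of the original question at once, a bias-free Conv2D followed by BatchNormalization and then the activation, could look like this (conv_bn_block and block are illustrative names of mine, and the new layers start untrained):

    from keras.layers import Conv2D, BatchNormalization, Activation

    def conv_bn_block(old_layer, **kwargs):
        # sketch: return a small callable that chains freshly created layers:
        # Conv2D without bias -> BatchNormalization -> the old activation
        def block(inputs):
            x = Conv2D(filters=old_layer.filters,
                       kernel_size=old_layer.kernel_size,
                       strides=old_layer.strides,
                       padding=old_layer.padding,
                       use_bias=False)(inputs)
            x = BatchNormalization()(x)
            return Activation(old_layer.activation)(x)
        return block

    new_model = replace_layer(model, 'conv2d', conv_bn_block)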
