How to replace (or insert) intermediate layer in Keras model?

温柔的废话 2021-01-30 14:12

I have a trained Keras model and I would like:

1) to replace a Conv2D layer with the same layer but without bias.

2) to add a BatchNormalization layer before the first Activation.

4 answers
  • 2021-01-30 14:19

    This is how I did it:

    from keras.models import Model
    
    def make_list(X):
        if isinstance(X, list):
            return X
        return [X]
    
    def list_no_list(X):
        if len(X) == 1:
            return X[0]
        return X
    
    def replace_layer(model, replace_layer_subname, replacement_fn, **kwargs):
        """
        args:
            model :: keras.models.Model instance
            replace_layer_subname :: str -- if str in layer name, replace it
            replacement_fn :: fn to call to replace all instances
                > the layer returned by fn must accept the replaced layer's input and produce a compatible output shape
        returns:
            new model with replaced layers
        quick examples:
            want to just remove all layers with 'batch_norm' in the name:
                > new_model = replace_layer(model, 'batch_norm', lambda **kwargs : (lambda u:u))
            want to replace all Conv1D(N, m, padding='same') with an LSTM (let's say all have 'conv1d' in name)
                > new_model = replace_layer(model, 'conv1d', lambda old_layer, **kwargs: LSTM(units=old_layer.filters, return_sequences=True))
        """
        model_inputs = []
        model_outputs = []
        tsr_dict = {}
    
        model_output_names = [out.name for out in make_list(model.output)]
    
        for i, layer in enumerate(model.layers):
            ### Loop if layer is used multiple times
            for j in range(len(layer._inbound_nodes)):
    
                ### check layer inp/outp
                inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
                outp_names = [out.name for out in make_list(layer.get_output_at(j))]
    
                ### setup model inputs
                if 'input' in layer.name:
                    for inpt_tsr in make_list(layer.get_output_at(j)):
                        model_inputs.append(inpt_tsr)
                        tsr_dict[inpt_tsr.name] = inpt_tsr
                    continue
    
                ### setup layer inputs
                inpt = list_no_list([tsr_dict[name] for name in inpt_names])
    
                ### remake layer 
                if replace_layer_subname in layer.name:
                    print('replacing '+layer.name)
                    x = replacement_fn(old_layer=layer, **kwargs)(inpt)
                else:
                    x = layer(inpt)
    
            ### store the new output tensors in the dict
                for name, out_tsr in zip(outp_names, make_list(x)):
    
                    ### check if is an output
                    if name in model_output_names:
                        model_outputs.append(out_tsr)
                    tsr_dict[name] = out_tsr
    
        return Model(model_inputs, model_outputs)
    

    I have a custom layer (taken from someone online) called BatchNormalizationFreeze, so an example of usage is this:

     new_model = replace_layer(model, 'batch_normal', lambda **kwargs: BatchNormalizationFreeze())
    

    If you're going to replace multiple layers, just swap the replacement function for a pseudo-model that applies them all at once, as sketched below.

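    For example, a replacement function that applies several layers at once could look like this (a minimal sketch; the BatchNormalization-plus-ReLU combination is just an assumption for illustration):

    from keras.layers import Activation, BatchNormalization

    # Sketch of a "pseudo model" replacement: the callable returned by the
    # replacement function applies several layers in sequence wherever a
    # single layer is being replaced.
    def bn_relu_replacement(old_layer, **kwargs):
        def apply(inpt):
            x = BatchNormalization()(inpt)
            return Activation('relu')(x)
        return apply

    # new_model = replace_layer(model, 'batch_norm', bn_relu_replacement)
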
  • 2021-01-30 14:32

    Unfortunately, replacing a layer is no small feat for models that do not follow the sequential pattern. For sequential models it is fine to just do x = layer(x), substituting new_layer where you see fit, as in the previous answer. However, for models without a classic sequential pattern (say, a simple concatenation of two columns) you have to actually parse the graph and apply your new layer (or layers) in the right places; see the sketch below. Hope this is not too discouraging. Happy graph parsing and reconstructing :)

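    For concreteness, here is a minimal sketch of the branching case mentioned above (all layer names and sizes are made up for illustration):

    from keras.layers import Concatenate, Dense, Input
    from keras.models import Model

    # Two columns read from the same input and are merged; rebuilding such
    # a model means wiring each branch explicitly rather than chaining
    # x = layer(x) in a single loop.
    inp = Input(shape=(16,))
    a = Dense(8, name='col_a')(inp)           # first column
    b = Dense(8, name='col_b')(inp)           # second column
    merged = Concatenate(name='merge')([a, b])
    out = Dense(1, name='head')(merged)
    model = Model(inputs=inp, outputs=out)

    # To replace 'col_b', you must rebuild the graph and reconnect 'merge'
    # to the new tensor; a naive per-layer loop would miss the branching.
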
  • 2021-01-30 14:36

    The following function lets you insert a new layer before or after each layer in the original model whose name matches a regular expression, or replace it entirely. It also works for non-sequential models such as DenseNet or ResNet.

    import re
    from keras.models import Model
    
    def insert_layer_nonseq(model, layer_regex, insert_layer_factory,
                            insert_layer_name=None, position='after'):
    
        # Auxiliary dictionary to describe the network graph
        network_dict = {'input_layers_of': {}, 'new_output_tensor_of': {}}
    
        # Set the input layers of each layer
        for layer in model.layers:
            for node in layer._outbound_nodes:
                layer_name = node.outbound_layer.name
                if layer_name not in network_dict['input_layers_of']:
                    network_dict['input_layers_of'].update(
                            {layer_name: [layer.name]})
                else:
                    network_dict['input_layers_of'][layer_name].append(layer.name)
    
        # Set the output tensor of the input layer
        network_dict['new_output_tensor_of'].update(
                {model.layers[0].name: model.input})
    
        # Iterate over all layers after the input
        model_outputs = []
        for layer in model.layers[1:]:
    
            # Determine input tensors
            layer_input = [network_dict['new_output_tensor_of'][layer_aux] 
                    for layer_aux in network_dict['input_layers_of'][layer.name]]
            if len(layer_input) == 1:
                layer_input = layer_input[0]
    
            # Insert layer if name matches the regular expression
            if re.match(layer_regex, layer.name):
                if position == 'replace':
                    x = layer_input
                elif position == 'after':
                    x = layer(layer_input)
            elif position == 'before':
                x = layer_input  # the new layer is applied first, then the original layer
                else:
                    raise ValueError('position must be: before, after or replace')
    
                new_layer = insert_layer_factory()
                if insert_layer_name:
                    new_layer.name = insert_layer_name
                else:
                    new_layer.name = '{}_{}'.format(layer.name, 
                                                    new_layer.name)
                x = new_layer(x)
                print('New layer: {} Old layer: {} Type: {}'.format(new_layer.name,
                                                                layer.name, position))
                if position == 'before':
                    x = layer(x)
            else:
                x = layer(layer_input)
    
            # Set new output tensor (the original one, or the one of the inserted
            # layer)
            network_dict['new_output_tensor_of'].update({layer.name: x})
    
            # Save tensor in output list if it is output in initial model
            if layer.name in model.output_names:
                model_outputs.append(x)
    
        return Model(inputs=model.inputs, outputs=model_outputs)
    
    

    The difference with respect to the simpler, purely sequential case is that before iterating over the layers to find the key layer, you first parse the graph and store the input layers of each layer in an auxiliary dictionary. Then, as you iterate over the layers, you also store the new output tensor of each layer, which is used to determine each layer's inputs when building the new model.

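    For intuition, after the first pass over a small two-branch model, the auxiliary dictionary might look like this (hypothetical layer names):

    # Hypothetical contents of network_dict['input_layers_of'] for a model
    # where two conv branches ('conv_a', 'conv_b') both read from 'input_1'
    # and are merged by a 'concat' layer:
    network_dict_example = {
        'conv_a': ['input_1'],
        'conv_b': ['input_1'],
        'concat': ['conv_a', 'conv_b'],
        'dense_out': ['concat'],
    }
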
    A use case would be the following, where a Dropout layer is inserted after each activation layer of ResNet50:

    from keras.applications.resnet50 import ResNet50
    from keras.layers import Dropout
    from keras.models import load_model
    
    model = ResNet50()
    def dropout_layer_factory():
        return Dropout(rate=0.2, name='dropout')
    model = insert_layer_nonseq(model, '.*activation.*', dropout_layer_factory)
    
    # Save and reload the model to rebuild a clean graph and avoid problems with the patched model
    model.save('temp.h5')
    model = load_model('temp.h5')
    
    model.summary()
    
  • 2021-01-30 14:42

    You can use the following functions:

    def replace_intermediate_layer_in_keras(model, layer_id, new_layer):
        from keras.models import Model
    
        layers = [l for l in model.layers]
    
        x = layers[0].output
        for i in range(1, len(layers)):
            if i == layer_id:
                x = new_layer(x)
            else:
                x = layers[i](x)
    
        new_model = Model(inputs=layers[0].input, outputs=x)
        return new_model
    
    def insert_intermediate_layer_in_keras(model, layer_id, new_layer):
        from keras.models import Model
    
        layers = [l for l in model.layers]
    
        x = layers[0].output
        for i in range(1, len(layers)):
            if i == layer_id:
                x = new_layer(x)
            x = layers[i](x)
    
        new_model = Model(inputs=layers[0].input, outputs=x)
        return new_model
    

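    The example below calls keras_simple_model(), which is not defined in this excerpt; a minimal stand-in might look like this (a hypothetical sketch, not the original helper):

    from keras.layers import Activation, Conv2D, Dense, Flatten, Input
    from keras.models import Model

    def keras_simple_model():
        # Hypothetical small CNN, just enough to make the example runnable.
        inp = Input(shape=(28, 28, 1))
        x = Conv2D(4, (3, 3), padding='same', name='conv1')(inp)
        x = Activation('relu', name='act1')(x)
        x = Conv2D(4, (3, 3), padding='same', name='conv2')(x)
        x = Activation('relu', name='act2')(x)
        x = Flatten()(x)
        out = Dense(10, activation='softmax')(x)
        return Model(inputs=inp, outputs=out)
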
    Example:

    if __name__ == '__main__':
        from keras.layers import Conv2D, BatchNormalization
        model = keras_simple_model()
        print(model.summary())
        model = replace_intermediate_layer_in_keras(
            model, 3,
            Conv2D(4, (3, 3), activation=None, padding='same', name='conv2_repl', use_bias=False))
        print(model.summary())
        model = insert_intermediate_layer_in_keras(model, 4, BatchNormalization())
        print(model.summary())
    

    There are some limitations on replacements due to layer shapes: the new layer's output shape must be compatible with what the following layers expect, or the rebuilt model will fail to connect.
