I have a trained Keras model and I would like to:
1) replace each Conv2D layer with an identical one that has no bias;
2) add a BatchNormalization layer before the first Activation layer.
This is how I did it:
import keras
from keras.models import Model
from tqdm import tqdm
from keras import backend as K
def make_list(X):
    """Return X unchanged if it is already a list, else wrap it as [X].

    Only real `list` instances are passed through; any other value
    (including tuples and None) becomes the single element of a new list.
    """
    if isinstance(X, list):
        return X
    return [X]
def list_no_list(X):
    """Unwrap a one-element sequence; return any other sequence unchanged.

    Inverse of wrapping a single value in a list: [x] -> x, while empty or
    multi-element sequences are returned as they are.
    """
    if len(X) == 1:
        return X[0]
    return X
def replace_layer(model, replace_layer_subname, replacement_fn,
                  **kwargs):
    """
    Rebuild `model` layer by layer, swapping out every layer whose name
    contains `replace_layer_subname`.

    args:
        model :: keras.models.Model instance
        replace_layer_subname :: str -- if str in layer name, replace it
        replacement_fn :: fn to call to replace all instances
            > called as replacement_fn(old_layer=layer, **kwargs); it must
              return a callable (e.g. a layer) that is then applied to the
              input tensor(s) of the replaced layer
            > the callable's output must have the shape that the layers
              downstream of the replaced layer expect
        **kwargs :: forwarded unchanged to replacement_fn
    returns:
        new model with replaced layers
    quick examples:
        want to just remove all layers with 'batch_norm' in the name:
        > new_model = replace_layer(model, 'batch_norm', lambda **kwargs: (lambda u: u))
        want to replace all Conv1D(N, m, padding='same') with an LSTM (lets say all have 'conv1d' in name)
        > new_model = replace_layer(model, 'conv1d', lambda old_layer, **kwargs: LSTM(units=old_layer.filters, return_sequences=True))
    """
    # Seed the tensor lookup with the model's own inputs so their order is
    # preserved and no layer has to be recognised as an input by its name.
    model_inputs = make_list(model.input)
    tsr_dict = {tsr.name: tsr for tsr in model_inputs}
    model_output_names = [out.name for out in make_list(model.output)]

    for layer in model.layers:
        # Input layers only expose the tensors registered above. Detect them
        # by type: a substring test on the name would also catch ordinary
        # layers that merely have 'input' in their name.
        if isinstance(layer, keras.layers.InputLayer):
            continue

        # Built lazily and reused for every node of this layer, so a layer
        # that is applied several times is replaced by ONE new layer and
        # weight sharing between its call sites is kept.
        replacement = None

        ### Loop if layer is used multiple times
        for j in range(len(layer._inbound_nodes)):
            ### check layer inp/outp
            inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
            outp_names = [out.name for out in make_list(layer.get_output_at(j))]

            ### setup layer inputs (tensors of the NEW graph, keyed by old name)
            inpt = list_no_list([tsr_dict[name] for name in inpt_names])

            ### remake layer
            if replace_layer_subname in layer.name:
                print('replacing '+layer.name)
                if replacement is None:
                    replacement = replacement_fn(old_layer=layer, **kwargs)
                x = replacement(inpt)
            else:
                x = layer(inpt)

            ### register the new tensors under the old output names
            for name, out_tsr in zip(outp_names, make_list(x)):
                tsr_dict[name] = out_tsr

    # Look the outputs up in the order of model.output so the new model's
    # outputs line up with the original's (collecting them during the layer
    # walk could reorder or duplicate them).
    model_outputs = [tsr_dict[name] for name in model_output_names]
    return Model(model_inputs, model_outputs)
I have a custom layer (taken from someone online) called BatchNormalizationFreeze, so an example of usage is this:
# replacement_fn must return the new layer itself; replace_layer applies it
# to the input tensor(s) of every layer whose name contains 'batch_normal'.
new_model = replace_layer(model, 'batch_normal', lambda **kwargs: BatchNormalizationFreeze())
If you're going to replace one layer with several layers, make the replacement function return a pseudo-model (a callable) that applies them all at once.