I am trying to save a Keras model in an H5 file. The Keras model has a custom layer. When I try to restore the model, I get the following error
Correction number 1 is to use `custom_objects` while loading the saved model, i.e., replace the code

```python
new_model = tf.keras.models.load_model('model.h5')
```

with

```python
new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})
```
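Since we used a custom layer to build the model before saving it, we should pass that class through `custom_objects` while loading it. The H5 file stores the layer's class name and configuration (plus the weights), not the Python class itself, so Keras has no way to find `CustomLayer` on its own.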
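To see why the class has to be handed over explicitly, here is a simplified, hypothetical model of the name lookup done at load time. It is plain Python, not the actual Keras code: `BUILTIN_LAYERS` and `resolve_layer_class` are made-up names, and `Dense`/`CustomLayer` are empty stand-in classes. A built-in name resolves on its own, while the custom name only resolves when it is supplied through a `custom_objects`-style dictionary.

```python
# Simplified, hypothetical model of how a loader maps a stored class name
# back to a Python class. This is NOT the actual Keras implementation.


class Dense:
    """Stand-in for a built-in layer class."""


class CustomLayer:
    """Stand-in for the user's custom layer class."""


# Hypothetical table of classes the loader knows about out of the box.
BUILTIN_LAYERS = {"Dense": Dense}


def resolve_layer_class(class_name, custom_objects=None):
    """Return the class registered under the name stored in the saved file."""
    # Entries from custom_objects take precedence over built-ins on a name clash.
    lookup = {**BUILTIN_LAYERS, **(custom_objects or {})}
    if class_name not in lookup:
        raise ValueError(f"Unknown layer: {class_name}")
    return lookup[class_name]


# A built-in name resolves without any extra help.
print(resolve_layer_class("Dense").__name__)  # Dense

# A custom name fails unless the class is supplied explicitly.
try:
    resolve_layer_class("CustomLayer")
except ValueError as err:
    print(err)  # Unknown layer: CustomLayer

print(resolve_layer_class("CustomLayer", {"CustomLayer": CustomLayer}).__name__)  # CustomLayer
```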
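Correction number 2 is to add `**kwargs` to the `__init__` method of the custom layer and forward it to the base class. When the model is loaded, Keras rebuilds the layer from the dictionary returned by `get_config()`, which also contains base-layer keys such as `name`, `trainable` and `dtype`, so `__init__` has to accept them. Call `super().__init__` only once, with both `name` and `**kwargs`; calling it a second time re-initializes the layer and replaces the name set by the first call:

```python
def __init__(self, k, name=None, **kwargs):
    # One call to the base Layer, forwarding name and any extra config keys.
    super(CustomLayer, self).__init__(name=name, **kwargs)
    self.k = k
```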
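A rough way to see the `**kwargs` problem without TensorFlow: the sketch below uses a hypothetical `BaseLayer` (not a Keras class) that only mimics the round trip in which `from_config` calls `cls(**config)` with everything `get_config()` returned. `LayerWithoutKwargs` and `LayerWithKwargs` are made-up names for illustration.

```python
# Hypothetical stand-in for a Keras base layer; it only mimics the
# get_config()/from_config() round trip, nothing else.
class BaseLayer:
    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        # Base config: keys that every layer carries.
        return {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the whole config as keyword arguments.
        return cls(**config)


class LayerWithoutKwargs(BaseLayer):
    def __init__(self, k, name=None):  # no **kwargs
        super().__init__(name=name)
        self.k = k

    def get_config(self):
        config = super().get_config()
        config.update({"k": self.k})
        return config


class LayerWithKwargs(BaseLayer):
    def __init__(self, k, name=None, **kwargs):
        # A single super() call forwarding name and the extra base keys.
        super().__init__(name=name, **kwargs)
        self.k = k

    def get_config(self):
        config = super().get_config()
        config.update({"k": self.k})
        return config


# Without **kwargs, the extra base keys (trainable, dtype) are rejected.
config = LayerWithoutKwargs(10, name="custom_layer").get_config()
try:
    LayerWithoutKwargs.from_config(config)
except TypeError as err:
    print("Without **kwargs:", err)

# With **kwargs, the layer is rebuilt with its original name and k.
restored = LayerWithKwargs.from_config(LayerWithKwargs(10, name="custom_layer").get_config())
print(restored.name, restored.k)  # custom_layer 10
```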
Complete working code is shown below:
```python
import tensorflow as tf


class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, k, name=None, **kwargs):
        # Forward name and any extra config keys to the base Layer in one call.
        super(CustomLayer, self).__init__(name=name, **kwargs)
        self.k = k

    def get_config(self):
        # Add k to the base config so the layer can be rebuilt when loading.
        config = super(CustomLayer, self).get_config()
        config.update({"k": self.k})
        return config

    def call(self, inputs):
        return tf.multiply(inputs, 2)


model = tf.keras.models.Sequential([
    tf.keras.Input(name='input_layer', shape=(10,)),
    CustomLayer(10, name='custom_layer'),
    tf.keras.layers.Dense(1, activation='sigmoid', name='output_layer')
])
tf.keras.models.save_model(model, 'model.h5')

new_model = tf.keras.models.load_model('model.h5', custom_objects={'CustomLayer': CustomLayer})
# summary() prints the table itself and returns None, so it is not wrapped in print().
new_model.summary()
```
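The `get_config` override above is what puts `k` into the saved file: the H5 format stores the model architecture as a JSON-encoded config, and each layer is rebuilt from that config alone. The following plain-Python sketch (no TensorFlow; `ForgetfulLayer`, `CompleteLayer` and `save_and_reload` are hypothetical names) simulates that JSON round trip and shows that leaving `k` out of `get_config` makes the rebuild fail, because `k` is a required argument of `__init__`.

```python
import json


class ForgetfulLayer:
    """Hypothetical layer whose get_config() omits the required argument k."""

    def __init__(self, k, name=None, **kwargs):
        # kwargs is accepted (as in correction 2) but unused in this sketch.
        self.k = k
        self.name = name

    def get_config(self):
        return {"name": self.name}  # k is missing

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class CompleteLayer(ForgetfulLayer):
    """Same layer, but get_config() also records k."""

    def get_config(self):
        config = super().get_config()
        config.update({"k": self.k})
        return config


def save_and_reload(layer):
    # Simulate saving and loading: the config passes through a JSON string.
    stored = json.dumps(layer.get_config())
    return type(layer).from_config(json.loads(stored))


print(save_and_reload(CompleteLayer(10, name="custom_layer")).k)  # 10

try:
    save_and_reload(ForgetfulLayer(10, name="custom_layer"))
except TypeError as err:
    print("Rebuild failed:", err)
```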
Output of the above code is shown below:
```
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
custom_layer_1 (CustomLayer) (None, 10)                0
_________________________________________________________________
output_layer (Dense)         (None, 1)                 11
=================================================================
Total params: 11
Trainable params: 11
Non-trainable params: 0
```
Hope this helps. Happy Learning!