How to load the Keras model with custom layers from .h5 file correctly?

Happy的楠姐 2021-01-11 19:49

I built a Keras model with a custom layer, and it was saved to a .h5 file by the ModelCheckpoint callback. When I tried to load this model after

3 Answers
  • 2021-01-11 19:55

    If you don't have time to retrain the model as Matias Valdenegro's answer suggests, you can give pool_size a default value in MyMeanPooling, as in the following code. The default must match the pool_size used when the model was trained. The model can then be loaded.

    from tensorflow.keras.layers import Layer

    class MyMeanPooling(Layer):
        # The default must match the pool_size used when the model was trained
        def __init__(self, pool_size=2, axis=1, **kwargs):
            # Call the base constructor first; it resets attributes such as supports_masking
            super(MyMeanPooling, self).__init__(**kwargs)
            self.supports_masking = True
            self.pool_size = pool_size
            self.axis = axis
            self.y_shape = None
            self.y_mask = None
    

    ref: https://www.jianshu.com/p/e97112c34e43
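    Why a default value helps can be shown without Keras at all. When a model is loaded, Keras rebuilds each custom layer from the config stored in the file, and the default Layer.from_config essentially calls cls(**config). If get_config() was never overridden, that config has no pool_size entry. The sketch below imitates this in plain Python; FakeLayer, PoolingNoDefault and PoolingWithDefault are hypothetical stand-ins, not Keras classes.

```python
class FakeLayer:
    """Hypothetical stand-in for a Keras Layer base class (not real Keras)."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # The base class only knows about its own constructor arguments.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Imitates Keras's default from_config: rebuild via cls(**config).
        return cls(**config)


class PoolingNoDefault(FakeLayer):
    # pool_size is required, and get_config() is not overridden.
    def __init__(self, pool_size, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.axis = axis


class PoolingWithDefault(FakeLayer):
    # Same layer, but pool_size defaults to the value used in training.
    def __init__(self, pool_size=2, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.axis = axis


# What the file holds when get_config() was never overridden:
# only the base-class keys, no pool_size or axis.
saved_config = PoolingNoDefault(pool_size=2, name="my_mean_pooling").get_config()

load_error = None
try:
    PoolingNoDefault.from_config(saved_config)
except TypeError as err:
    load_error = err
    print("load failed:", err)

restored = PoolingWithDefault.from_config(saved_config)
print("restored pool_size:", restored.pool_size)
```

    The second line printed is "restored pool_size: 2"; the exact wording of the TypeError depends on the Python version.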

  • 2021-01-11 20:09

    Actually I don't think you can load this model.

    The most likely issue is that you did not implement the get_config() method in your layer. This method returns a dictionary of configuration values that should be saved:

    def get_config(self):
        config = {'pool_size': self.pool_size,
                  'axis': self.axis}
        base_config = super(MyMeanPooling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    

    After adding this method you have to retrain the model and save it again: the file saved earlier does not contain this layer's configuration, which is why it cannot be loaded.
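    A minimal sketch of such a layer with get_config(), assuming TensorFlow 2.x and tf.keras. The call() body is a hypothetical placeholder (the question's actual pooling code is not shown), and axis is stored only so that it round-trips through the config.

```python
import tensorflow as tf


class MyMeanPooling(tf.keras.layers.Layer):
    """Hypothetical, simplified version of the question's layer."""

    def __init__(self, pool_size, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.axis = axis

    def call(self, inputs):
        # Placeholder pooling logic: average non-overlapping windows of
        # pool_size steps along the time axis of a (batch, steps, features) input.
        return tf.nn.avg_pool1d(inputs, ksize=self.pool_size,
                                strides=self.pool_size, padding="VALID")

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add
        # the constructor arguments needed to rebuild this layer.
        config = super().get_config()
        config.update({"pool_size": self.pool_size, "axis": self.axis})
        return config


layer = MyMeanPooling(pool_size=3, name="pooling")
config = layer.get_config()
clone = MyMeanPooling.from_config(config)
print(config["pool_size"], clone.pool_size, clone.name)
```

    Rebuilding the layer with from_config(get_config()) is roughly what load_model does for each layer of a file saved after this change.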

  • 2021-01-11 20:18

    From LiamHe's comment of Sep 27, 2017 on the following issue: https://github.com/keras-team/keras/issues/4871.

    I ran into the same problem today: TypeError: __init__() missing 1 required positional argument. Here is how I solved it (Keras 2.0.2):

    1. Give the layer's positional arguments default values.
    2. Override the layer's get_config method with something like:

           def get_config(self):
               config = super().get_config()
               # Assumes __init__ stored the argument as self._pool_size
               config['pool_size'] = self._pool_size
               return config

    3. Add the layer class to custom_objects when loading the model.
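    The three steps can be combined into an end-to-end sketch, assuming TensorFlow 2.x with tf.keras and h5py installed. The call() body is hypothetical, and the model is saved to a temporary .h5 file with model.save() instead of a ModelCheckpoint callback to keep the example short.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MyMeanPooling(tf.keras.layers.Layer):
    # Step 1: give the constructor arguments default values.
    def __init__(self, pool_size=2, axis=1, **kwargs):
        super().__init__(**kwargs)
        self.pool_size = pool_size
        self.axis = axis

    def call(self, inputs):
        # Hypothetical pooling logic (not the question's original code).
        return tf.nn.avg_pool1d(inputs, ksize=self.pool_size,
                                strides=self.pool_size, padding="VALID")

    # Step 2: serialize the constructor arguments.
    def get_config(self):
        config = super().get_config()
        config.update({"pool_size": self.pool_size, "axis": self.axis})
        return config


inputs = tf.keras.Input(shape=(8, 3))
outputs = MyMeanPooling(pool_size=2)(inputs)
model = tf.keras.Model(inputs, outputs)

# The .h5 suffix selects the HDF5 format (h5py must be installed).
path = os.path.join(tempfile.mkdtemp(), "model.h5")
model.save(path)

# Step 3: map the layer's class name to the class when loading.
restored = tf.keras.models.load_model(
    path, custom_objects={"MyMeanPooling": MyMeanPooling})

x = np.random.rand(2, 8, 3).astype("float32")
same = np.allclose(model(x), restored(x))
print(same)
```

    The key in custom_objects is the class name recorded in the file, which for an unregistered layer is simply its Python class name.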