NotImplementedError: Layers with arguments in `__init__` must override `get_config`

醉酒成梦 2020-11-30 09:46

I'm trying to save my TensorFlow model using model.save(), but I am getting this error.

The model summary is provided here: Model Summary

1 Answer
  • 2020-11-30 10:37

    It's not a bug, it's a feature.

    This error tells you that TF refuses to save your model because it would not be able to load it back.
    Specifically, it would not be able to reinstantiate your custom Layer classes, encoder and decoder.
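    Roughly, the check works like this: the base get_config compares the named parameters of your __init__ with the keys it knows how to serialize, and raises if some are not covered and get_config was not overridden. Below is a simplified pure-Python imitation of that idea, assuming the behaviour of the TF 2.x base Layer.get_config. ToyLayer and Encoder are hypothetical stand-ins, not real Keras classes, and the real check differs in detail.

```python
import inspect


class ToyLayer:
    """Hypothetical stand-in for tf.keras.layers.Layer (not the real class)."""

    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        # Keys this base class can serialize on its own.
        config = {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}
        # Named __init__ parameters (excluding self, *args, **kwargs)
        # that the base config does not cover.
        skip = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        params = inspect.signature(type(self).__init__).parameters
        missing = [
            p.name for p in params.values()
            if p.name != "self" and p.kind not in skip and p.name not in config
        ]
        # Refuse to produce a config that could not rebuild the layer,
        # unless a subclass has overridden get_config.
        if missing and type(self).get_config is ToyLayer.get_config:
            raise NotImplementedError(
                "Layers with arguments in `__init__` must override `get_config`"
            )
        return config


class Encoder(ToyLayer):
    """Hypothetical custom layer with extra constructor arguments."""

    def __init__(self, vocab_size, units, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.units = units


print(ToyLayer(name="base").get_config())  # fine: no extra arguments

try:
    Encoder(vocab_size=8000, units=512).get_config()
except NotImplementedError as err:
    print("saving would fail:", err)
```

    The point of failing early is that a model saved without these arguments could never be rebuilt later.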

    To solve this, override the get_config method of both classes so that the returned config includes the new arguments you've added to __init__.

    A layer config is a Python dictionary (serializable) containing the configuration of a layer. The same layer can be reinstantiated later (without its trained weights) from this configuration.
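    To make "serializable" and "without its trained weights" concrete, here is a minimal pure-Python sketch. ToyDense is a hypothetical class, not tf.keras.layers.Dense: its config survives a JSON round trip and rebuilds an equivalent object, but the learned weights are not part of it.

```python
import json


class ToyDense:
    """Hypothetical layer stand-in (not tf.keras.layers.Dense)."""

    def __init__(self, units, activation="linear", name=None):
        self.units = units
        self.activation = activation
        self.name = name
        self.weights = None  # trained weights live here, NOT in the config

    def get_config(self):
        # Only the constructor arguments go into the config.
        return {"units": self.units, "activation": self.activation, "name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = ToyDense(64, activation="relu", name="dense_1")
layer.weights = [[0.1, 0.2]]  # pretend these were learned

saved = json.dumps(layer.get_config())  # plain, serializable data
clone = ToyDense.from_config(json.loads(saved))

print(saved)  # {"units": 64, "activation": "relu", "name": "dense_1"}
print(clone.units, clone.activation, clone.weights)  # 64 relu None
```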


    For example, if your encoder class looks something like this:

    class encoder(tf.keras.layers.Layer):
    
        def __init__(
            self,
            vocab_size, num_layers, units, d_model, num_heads, dropout,
            **kwargs,
        ):
            super().__init__(**kwargs)
            self.vocab_size = vocab_size
            self.num_layers = num_layers
            self.units = units
            self.d_model = d_model
            self.num_heads = num_heads
            self.dropout = dropout
    
        # Other methods etc.
    

    then you only need to override this method:

        def get_config(self):
    
            config = super().get_config().copy()
            config.update({
                'vocab_size': self.vocab_size,
                'num_layers': self.num_layers,
                'units': self.units,
                'd_model': self.d_model,
                'num_heads': self.num_heads,
                'dropout': self.dropout,
            })
            return config
    

    Once you do this for both classes, you will be able to save the model.

    This works because, when the model is loaded, TF can now reinstantiate each layer from its config.


    Layer.from_config's source code may give a better sense of how it works:

    @classmethod
    def from_config(cls, config):
      return cls(**config)
    
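    The sketch below mirrors that from_config in plain Python, using hypothetical BaseLayer, Encoder and ConfigurableEncoder classes modeled on the encoder above (they are not Keras classes). Because this toy has no pre-save check, a missing get_config override only surfaces at reload time, as a TypeError from cls(**config); that late failure is what the NotImplementedError in TF protects you from.

```python
class BaseLayer:
    """Hypothetical stand-in for the Keras base Layer (not the real class)."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Same idea as Layer.from_config: feed the config back into __init__.
        return cls(**config)


class Encoder(BaseLayer):
    """Hypothetical encoder WITHOUT a get_config override."""

    def __init__(self, vocab_size, num_layers, units, d_model, num_heads,
                 dropout, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.units = units
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = dropout


class ConfigurableEncoder(Encoder):
    """Same hypothetical encoder WITH the override from the answer."""

    def get_config(self):
        config = super().get_config()
        config.update({
            "vocab_size": self.vocab_size,
            "num_layers": self.num_layers,
            "units": self.units,
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "dropout": self.dropout,
        })
        return config


args = dict(vocab_size=8000, num_layers=2, units=512, d_model=256,
            num_heads=8, dropout=0.1, name="encoder")

# Without the override, the config lacks the constructor arguments...
try:
    Encoder.from_config(Encoder(**args).get_config())
except TypeError as err:
    print("reload fails:", err)

# ...with it, the layer is rebuilt with identical hyperparameters.
original = ConfigurableEncoder(**args)
clone = ConfigurableEncoder.from_config(original.get_config())
print(clone.get_config() == original.get_config())  # True
```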