Getting TypeError: can't pickle _thread.RLock objects

前端 未结 1 542
我寻月下人不归
我寻月下人不归 2021-01-22 05:07

I have read a number of similar questions; most of them mention that you shouldn't try to serialize an unserializable object. I am not able to understand the issue. I am able to sav

相关标签:
1条回答
  • 2021-01-22 05:39

    I was able to replicate your issue in TF 2.3.0 using Google Colab:

    import pickle
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    
    model = Sequential()  # minimal single-layer model used to reproduce the error
    model.add(Dense(1, input_dim=42, activation='sigmoid'))
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    
    with open('model.pkl', 'wb') as f:
        pickle.dump(model, f)  # fails here: TypeError: can't pickle _thread.RLock objects
    

    Output:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-1-afb2bf58a891> in <module>()
          8 
          9 with open('model.pkl', 'wb') as f:
    ---> 10     pickle.dump(model, f)
    
    TypeError: can't pickle _thread.RLock objects
    

    @adriangb proposed a hotfix for this issue on GitHub; for more details, please refer to this.

    import pickle
    
    from tensorflow.keras.models import Sequential, Model
    from tensorflow.keras.layers import Dense
    from tensorflow.python.keras.layers import deserialize, serialize
    from tensorflow.python.keras.saving import saving_utils
    
    
    def unpack(model, training_config, weights):
        """Rebuild a Keras model from the pieces produced by ``__reduce__``.

        Args:
            model: Serialized model config, as accepted by ``deserialize``.
            training_config: Training configuration used to re-compile the
                model, or ``None`` to leave the model uncompiled.
            weights: Weight values passed to ``set_weights``.

        Returns:
            The restored model with its weights set.
        """
        rebuilt = deserialize(model)
        if training_config is not None:
            # Translate the stored training config back into compile() kwargs.
            compile_kwargs = saving_utils.compile_args_from_training_config(
                training_config
            )
            rebuilt.compile(**compile_kwargs)
        rebuilt.set_weights(weights)
        return rebuilt
    
    # Hotfix function
    def make_keras_picklable():
        """Patch ``Model`` so that its instances can be pickled.

        Installs a ``__reduce__`` on ``Model`` that tells pickle to rebuild the
        model through ``unpack`` from its serialized config, its training
        config (if any) and its weights.
        """

        def __reduce__(self):
            # Gather the three pieces unpack() needs to rebuild this model.
            metadata = saving_utils.model_metadata(self)
            serialized = serialize(self)
            current_weights = self.get_weights()
            return (
                unpack,
                (serialized, metadata.get("training_config"), current_weights),
            )

        Model.__reduce__ = __reduce__
    
    # Apply the pickling hotfix before any model is pickled
    make_keras_picklable()

    # Build and compile the same single-layer model as in the failing example
    model = Sequential([Dense(1, input_dim=42, activation='sigmoid')])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Pickle the model to disk
    with open('model.pkl', 'wb') as f:
        pickle.dump(model, f)
    
    0 讨论(0)
提交回复
热议问题