Can I use a dictionary in a custom Keras model?

Submitted by 允我心安 on 2021-02-08 09:51:41

Question


I recently read a paper about UNet++, and I want to implement this structure with TensorFlow 2.0 as a custom (subclassed) Keras model. Because the structure is so complicated, I decided to manage the Keras layers in a dictionary. Training went well, but an error occurred while saving the model. Here is minimal code that reproduces the error:

```python
import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight


class DicModel(tf.keras.Model):
    def __init__(self):
        super(DicModel, self).__init__(name='SequenceEECNN')
        # Sub-layers are kept in a dict with integer keys
        self.c = {}
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization()]
        )
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c[0](images)
        x = self.c[1](x)
        return x


# load_data() is the asker's own data loader (not shown)
X_train, y_train = load_data()
X_test, y_test = load_data()

# Balanced class weights (the result is not used below; fit() gets hard-coded weights)
class_weight.compute_class_weight('balanced', classes=np.unique(y_train), y=np.ravel(y_train))

model = DicModel()
model_name = 'test'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs/' + model_name + '/')
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, mode='min')

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])

results = model.fit(X_train, y_train, batch_size=4, epochs=5, validation_data=(X_test, y_test),
                    callbacks=[tensorboard_callback, early_stop_callback],
                    class_weight=[0.2, 2.0, 100.0])

# Training works; saving the weights raises the ValueError below
model.save_weights('model/' + model_name, save_format='tf')
```

The error message is:

```
Traceback (most recent call last):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/learn_tf2/test_model.py", line 61, in <module>
    model.save_weights('model/'+model_name,save_format='tf')
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1328, in save_weights
    self._trackable_saver.save(filepath, session=session)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1106, in save
    file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1046, in _save_cached_when_graph_building
    object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1014, in _gather_saveables
    feed_additions) = self._graph_view.serialize_object_graph()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 379, in serialize_object_graph
    trackable_objects, path_to_root = self._breadth_first_traversal()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 199, in _breadth_first_traversal
    for name, dependency in self.list_dependencies(current_trackable):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 159, in list_dependencies
    return obj._checkpoint_dependencies
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 690, in __getattribute__
    return object.__getattribute__(self, name)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 732, in _checkpoint_dependencies
    "ignored." % (self,))
ValueError: Unable to save the object {0: <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fb5c6c36588>, 1: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb5c6c36630>} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.
```

`tf.contrib.checkpoint.NoDependency` seems to have been removed in TensorFlow 2.0 (https://medium.com/tensorflow/whats-coming-in-tensorflow-2-0-d3663832e9b8). How can I fix this issue? Or should I just give up on using a dictionary in a custom Keras model? Thank you for your time and help!


Answer 1:


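Use string keys. As the error message says, TensorFlow cannot checkpoint a dictionary in which a non-string key maps to a trackable object such as a layer, so integer keys like `0` and `1` fail when the weights are saved, even though training works.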
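A minimal sketch of the string-key fix applied to the question's model, assuming TensorFlow 2.x with its bundled `tf.keras`. The only change to the model is that the keys are `'0'` and `'1'` instead of integers. The `save_checkpoint` helper and the dummy input shape are hypothetical, added only for illustration; the helper uses `tf.train.Checkpoint`, which walks the same tracked object graph that `model.save_weights(..., save_format='tf')` failed on in the traceback.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class DicModel(tf.keras.Model):
    def __init__(self):
        super().__init__(name='SequenceEECNN')
        # String keys: each entry of the auto-wrapped dict can be named in
        # the checkpoint, so saving no longer raises the ValueError.
        self.c = {}
        self.c['0'] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c['1'] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c['0'](images)
        return self.c['1'](x)


def save_checkpoint(model, directory):
    """Hypothetical helper: write the model's tracked variables to disk."""
    # Returns the checkpoint path prefix (files: <prefix>.index, <prefix>.data-*)
    return tf.train.Checkpoint(model=model).write(os.path.join(directory, 'ckpt'))


# Build the model on a dummy batch (assumed shape: 1 image, 8x8 pixels, 1 channel)
model = DicModel()
dummy = np.random.rand(1, 8, 8, 1).astype('float32')
output = model(dummy)
path = save_checkpoint(model, tempfile.mkdtemp())
```

For a UNet++-style grid of nodes, the same idea applies: build the key with a format string such as `f'{i}_{j}'` instead of using an integer or a tuple, since a tuple is not a string key either.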




Answer 2:


The exception message was incorrect in TensorFlow 2.0 (it points to `tf.contrib.checkpoint.NoDependency`, but `tf.contrib` no longer exists in 2.x); this was fixed in 2.2.

You can avoid the problem by wrapping the `c` attribute like this:

```python
from tensorflow.python.training.tracking.data_structures import NoDependency

# In DicModel.__init__, instead of self.c = {}
self.c = NoDependency({})
```

For more details, check this issue.
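A sketch of the `NoDependency` approach in the context of the question's model, assuming TensorFlow 2.x with `tf.keras`. `NoDependency` lives in a private TensorFlow module whose path has moved between releases, so the sketch first tries `tensorflow.python.trackable.data_structures` (an assumed location for newer releases) and falls back to the path given in the answer. Note what the original error message says: a dictionary wrapped this way is "automatically un-wrapped and subsequently ignored" by checkpointing. Layers reachable only through `self.c` may therefore not be tracked or saved, so it is worth inspecting `model.trainable_variables` (or reloading a checkpoint) before relying on this approach.

```python
import numpy as np
import tensorflow as tf

try:
    # Assumed location in newer TensorFlow releases (private module)
    from tensorflow.python.trackable.data_structures import NoDependency
except ImportError:
    # Location used in the answer (TensorFlow 2.0-era private module)
    from tensorflow.python.training.tracking.data_structures import NoDependency


class DicModel(tf.keras.Model):
    def __init__(self):
        super().__init__(name='SequenceEECNN')
        # NoDependency is unwrapped on assignment: self.c becomes a plain dict
        # that checkpoint tracking ignores, so integer keys no longer fail.
        self.c = NoDependency({})
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c[0](images)
        return self.c[1](x)


model = DicModel()
dummy = np.random.rand(1, 8, 8, 1).astype('float32')
output = model(dummy)

# Check whether the dict's layers are still tracked before trusting save_weights()
print(len(model.trainable_variables))
```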



Source: https://stackoverflow.com/questions/57517992/can-i-use-dictionary-in-keras-customized-model
