Question
I recently read a paper about UNet++, and I want to implement this architecture with TensorFlow 2.0 and a custom Keras model. Because the structure is so complicated, I decided to manage the Keras layers with a dictionary. Training works fine, but an error occurs when saving the model. Here is a minimal example that reproduces the error:
import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight


class DicModel(tf.keras.Model):
    def __init__(self):
        super(DicModel, self).__init__(name='SequenceEECNN')
        # Sub-layers are kept in a dict with integer keys
        self.c = {}
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
            tf.keras.layers.BatchNormalization()]
        )
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation='softmax', padding='same')

    def call(self, images):
        x = self.c[0](images)
        x = self.c[1](x)
        return x


# load_data() is my own data-loading function (not shown)
X_train, y_train = load_data()
X_test, y_test = load_data()
class_weight.compute_class_weight('balanced', classes=np.unique(y_train), y=np.ravel(y_train))
model = DicModel()
model_name = 'test'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='logs/' + model_name + '/')
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, mode='min')
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
results = model.fit(X_train, y_train, batch_size=4, epochs=5, validation_data=(X_test, y_test),
                    callbacks=[tensorboard_callback, early_stop_callback],
                    class_weight=[0.2, 2.0, 100.0])
# Training works; this call raises the ValueError below
model.save_weights('model/' + model_name, save_format='tf')
The error message is:
Traceback (most recent call last):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/learn_tf2/test_model.py", line 61, in <module>
    model.save_weights('model/'+model_name,save_format='tf')
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 1328, in save_weights
    self._trackable_saver.save(filepath, session=session)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1106, in save
    file_prefix=file_prefix_tensor, object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1046, in _save_cached_when_graph_building
    object_graph_tensor=object_graph_tensor)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/util.py", line 1014, in _gather_saveables
    feed_additions) = self._graph_view.serialize_object_graph()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 379, in serialize_object_graph
    trackable_objects, path_to_root = self._breadth_first_traversal()
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 199, in _breadth_first_traversal
    for name, dependency in self.list_dependencies(current_trackable):
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/graph_view.py", line 159, in list_dependencies
    return obj._checkpoint_dependencies
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 690, in __getattribute__
    return object.__getattribute__(self, name)
  File "/media/xrzhang/Data/ZHS/Research/CNN-TF2/venv/lib/python3.6/site-packages/tensorflow/python/training/tracking/data_structures.py", line 732, in _checkpoint_dependencies
    "ignored." % (self,))
ValueError: Unable to save the object {0: <tensorflow.python.keras.engine.sequential.Sequential object at 0x7fb5c6c36588>, 1: <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x7fb5c6c36630>} (a dictionary wrapper constructed automatically on attribute assignment). The wrapped dictionary contains a non-string key which maps to a trackable object or mutable data structure.

If you don't need this dictionary checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency object; it will be automatically un-wrapped and subsequently ignored.
tf.contrib.checkpoint.NoDependency seems to have been removed in TensorFlow 2.0 (https://medium.com/tensorflow/whats-coming-in-tensorflow-2-0-d3663832e9b8). How can I fix this issue? Or should I just give up on using dictionaries in a custom Keras model? Thank you for your time and help!
Answer 1:
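Use string keys. The dictionary wrapper that TensorFlow creates automatically on attribute assignment can only be checkpointed if every key that maps to a trackable object (such as a layer) is a string, which is exactly what the error message says. For example, use self.c['0'] and self.c['1'] instead of self.c[0] and self.c[1].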
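A minimal sketch of this fix, assuming TensorFlow 2.x with tf.keras: it is the question's model with the keys changed to the strings "0" and "1" (StrKeyModel is a hypothetical name). Instead of save_weights it uses tf.train.Checkpoint directly, which relies on the same object-based checkpoint machinery that save_weights(..., save_format='tf') calls into in the traceback. A checkpoint written from one instance is restored into a fresh instance and the outputs are compared.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class StrKeyModel(tf.keras.Model):
    """The question's DicModel, but the layer dictionary uses string keys."""

    def __init__(self):
        super().__init__(name="SequenceEECNN")
        self.c = {}
        # String keys let the automatically wrapped dict be checkpointed.
        self.c["0"] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c["1"] = tf.keras.layers.Conv2D(3, 3, activation="softmax", padding="same")

    def call(self, images):
        x = self.c["0"](images)
        return self.c["1"](x)


# Build one model, save it, and restore the checkpoint into a fresh instance.
x = tf.random.uniform((2, 8, 8, 3))
src = StrKeyModel()
src(x)  # calling the model creates its variables
ckpt_path = tf.train.Checkpoint(model=src).save(os.path.join(tempfile.mkdtemp(), "ckpt"))

dst = StrKeyModel()
dst(x)  # create variables first so the restored values are assigned right away
tf.train.Checkpoint(model=dst).restore(ckpt_path)

# After the restore, both models should produce the same output.
print(np.allclose(src(x).numpy(), dst(x).numpy()))
```

With integer keys, the same Checkpoint.save call is expected to raise the ValueError shown in the question, since it walks the same object graph.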
Answer 2:
The exception message was incorrect in TensorFlow 2.0 (it refers to tf.contrib.checkpoint.NoDependency, but tf.contrib no longer exists in 2.0) and has been fixed in 2.2.
You can avoid the problem by wrapping the c attribute like this:
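# NoDependency is an internal TensorFlow class; the wrapped dict is un-wrapped
# on assignment and ignored when the model is checkpointed
from tensorflow.python.training.tracking.data_structures import NoDependency

self.c = NoDependency({})  # inside DicModel.__init__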
For more details, see this issue.
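A sketch of this workaround, assuming TensorFlow 2.x with tf.keras (Keras 2). The try/except import is an assumption about where the internal module lives in newer releases, and NoDepModel is a hypothetical name. It also checks a side effect worth verifying on your own TF version: as far as I understand the tracking code, NoDependency is unwrapped to a plain dict on assignment, so layers added to self.c afterwards are not tracked by the model at all. Their variables are then missing from model.trainable_weights (so fit() would not train them) and from the saved checkpoint.

```python
import os
import tempfile

import tensorflow as tf

# NoDependency is internal; its module path is an assumption that differs
# between releases (trackable/ in newer 2.x, training/tracking/ in older ones).
try:
    from tensorflow.python.trackable.data_structures import NoDependency
except ImportError:
    from tensorflow.python.training.tracking.data_structures import NoDependency


class NoDepModel(tf.keras.Model):
    """The question's DicModel with the integer-keyed dict wrapped in NoDependency."""

    def __init__(self):
        super().__init__(name="SequenceEECNN")
        # The wrapper is stripped on assignment: self.c becomes a plain, untracked dict.
        self.c = NoDependency({})
        self.c[0] = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
            tf.keras.layers.BatchNormalization(),
        ])
        self.c[1] = tf.keras.layers.Conv2D(3, 3, activation="softmax", padding="same")

    def call(self, images):
        x = self.c[0](images)
        return self.c[1](x)


model = NoDepModel()
model(tf.zeros((1, 8, 8, 3)))  # builds the inner layers

# Saving no longer raises the ValueError from the question...
tf.train.Checkpoint(model=model).save(os.path.join(tempfile.mkdtemp(), "ckpt"))

# ...but the outer model does not see the layers stored in self.c.
print(len(model.trainable_weights))       # weights visible through the model
print(len(model.c[0].trainable_weights))  # weights owned by the inner Sequential
```

If the first count is zero on your version, the string-key approach from Answer 1 is the safer choice when the weights must be trained and saved.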
Source: https://stackoverflow.com/questions/57517992/can-i-use-dictionary-in-keras-customized-model