After you train a model in TensorFlow:
Following @Vishnuvardhan Janapati's answer, here is another way to save and reload a model with custom layers/metrics/losses under TensorFlow 2.0.0.
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import get_custom_objects

# custom loss (for example)
def custom_loss(y_true, y_pred):
    return tf.reduce_mean(y_true - y_pred)

get_custom_objects().update({'custom_loss': custom_loss})

# custom layer (for example)
class CustomLayer(Layer):
    def __init__(self, ...):
        ...
        # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})
This way, once you have executed this code and saved your model with tf.keras.models.save_model, model.save, or the ModelCheckpoint callback, you can reload it without needing to pass the exact custom objects, as simply as

new_model = tf.keras.models.load_model("./model.h5")
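If you prefer not to register the objects globally, Keras also accepts a custom_objects mapping at load time; a minimal sketch (the file name is just an example):

# alternative: pass the custom objects explicitly when loading
new_model = tf.keras.models.load_model(
    "./model.h5",
    custom_objects={"custom_loss": custom_loss, "CustomLayer": CustomLayer},
)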
Use tf.train.Saver to save a model. Remember, you need to specify the var_list if you want to reduce the model size. The var_list can be tf.trainable_variables or tf.global_variables.
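A minimal sketch of what that looks like in TF 1.x graph mode (the checkpoint path is just an example):

import tensorflow as tf

# saving only the trainable variables keeps the checkpoint smaller
saver = tf.train.Saver(var_list=tf.trainable_variables())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... train the model ...
    saver.save(sess, "/tmp/small_model.ckpt")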
For TensorFlow 2.0, it's very simple.

import tensorflow as tf

# save the model
model.save("model_name")

# load it back
model = tf.keras.models.load_model("model_name")
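For completeness, here is a self-contained round trip under that API; the tiny Dense model and the random data are placeholders, only there to have something to save:

import numpy as np
import tensorflow as tf

# a throwaway model, just for illustration
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.rand(32, 4), np.random.rand(32, 1), epochs=1, verbose=0)

model.save("model_name")  # with no .h5 extension this creates a SavedModel directory
restored = tf.keras.models.load_model("model_name")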
For TensorFlow 2.0, it is as simple as

# Save the model
model.save('path_to_my_model.h5')
To restore:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
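One way to sanity-check the round trip (a sketch; x stands in for whatever input batch you have at hand):

import numpy as np

# the reloaded model should reproduce the original predictions
# x: any input batch of the right shape (assumed to exist)
np.testing.assert_allclose(model.predict(x), new_model.predict(x), rtol=1e-6)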
In most cases, saving and restoring from disk using a tf.train.Saver is your best option:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model
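If you checkpoint periodically during training, the same Saver can number the checkpoint files and cap how many are kept on disk; a minimal runnable sketch (the counter variable is only there to have something to save):

import tensorflow as tf

x = tf.Variable(0.0, name="x")
increment = tf.assign_add(x, 1.0)

# keep only the 5 most recent checkpoints, suffixed with the step number
saver = tf.train.Saver(max_to_keep=5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(10):
        sess.run(increment)
        saver.save(sess, "/tmp/my_great_model", global_step=step)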
You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can use to restore the model's state:
saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model
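Since the restored graph was not built in your script, you typically fetch the tensors you need by name; a sketch, where the tensor names "X:0" and "output:0" and the variable some_input are assumptions standing in for your own graph's names and data:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")
graph = tf.get_default_graph()

# these names are placeholders; use the names defined in your own graph
X = graph.get_tensor_by_name("X:0")
output = graph.get_tensor_by_name("output:0")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    result = sess.run(output, feed_dict={X: some_input})  # some_input: your data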
However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save checkpoints every time the model improves during training (as measured on the validation set), then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:
... # build your model
# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]
with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)
A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.
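A tiny self-contained illustration of that mechanism, just to make it concrete (TF 1.x graph mode; the variable v is a placeholder example):

import tensorflow as tf

v = tf.Variable(1.0, name="v")
assign_op = tf.get_default_graph().get_operation_by_name("v/Assign")
init_value = assign_op.inputs[1]  # the tensor that holds v's initial value

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # feed 42.0 in place of the initial value, reusing the existing Assign op
    sess.run(assign_op, feed_dict={init_value: 42.0})
    print(sess.run(v))  # 42.0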
You can save the variables in the network using
saver = tf.train.Saver()
saver.save(sess, 'path of save/fileName.ckpt')
To restore the network for reuse later or in another script, use:
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/'))
sess.run(....)
Important points:

- sess must be built from the same graph structure in the first run and the later runs (coherent structure), otherwise the restored variables will not match.
- tf.train.latest_checkpoint takes the folder containing the saved checkpoint files, not an individual file path; saver.restore is then given the checkpoint path it returns.
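Put together, a hedged sketch of the full flow (the path /tmp/my_net/ and the single variable w are placeholders; in practice the two halves live in separate scripts that build the same graph):

import os
import tensorflow as tf

os.makedirs("/tmp/my_net", exist_ok=True)

# --- save ---
w = tf.Variable(tf.zeros([2, 2]), name="w")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... train ...
    saver.save(sess, "/tmp/my_net/model.ckpt")

# --- restore (normally in another script that rebuilds the same graph) ---
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint("/tmp/my_net/"))
    print(sess.run(w))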