After you train a model in Tensorflow:
According to the new Tensorflow version, tf.train.Checkpoint
is the preferable way of saving and restoring a model:
Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver, which writes and reads variable.name-based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.
Here is an example:
import tensorflow as tf
import os
# Turn on eager execution.
tf.enable_eager_execution()

# Directory that holds the checkpoints, and the file prefix used when saving.
checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# NOTE: `optimizer`, `model` and `num_training_steps` are not defined in this
# snippet; the surrounding program has to provide them.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_directory)
status = checkpoint.restore(latest_checkpoint_path)

for _ in range(num_training_steps):
    optimizer.minimize(...)  # Variables will be restored on creation.

status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
More information and example here.
Adapted from the docs
# -------------------------
# ----- Toy Context -----
# -------------------------
import tensorflow as tf
class Net(tf.keras.Model):
    """Toy linear model made of a single 5-unit dense layer."""

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        """Apply the dense layer to `x` and return its output."""
        dense_output = self.l1(x)
        return dense_output
def toy_dataset():
    """Build the toy dataset: repeated without limit and batched by two.

    Returns:
        A `tf.data.Dataset` whose elements are dicts with keys "x" and "y".
    """
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    examples = tf.data.Dataset.from_tensor_slices({"x": inputs, "y": labels})
    return examples.repeat().batch(2)
def train_step(net, example, optimizer):
    """Run one optimization step of `net` on `example` with `optimizer`.

    Args:
        net: Callable model exposing `trainable_variables`.
        example: Mapping with the model input under "x" and the target under "y".
        optimizer: Object whose `apply_gradients` consumes
            (gradient, variable) pairs.

    Returns:
        The loss of this step: the mean absolute difference between the
        model output and the target.
    """
    # Only the forward pass and the loss are recorded on the tape.
    with tf.GradientTape() as tape:
        prediction = net(example["x"])
        loss = tf.reduce_mean(tf.abs(prediction - example["y"]))
    params = net.trainable_variables
    grads = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return loss
# ----------------------------
# ----- Create Objects -----
# ----------------------------
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)

# Objects tracked by the checkpoint: a step counter, the optimizer, the model
# and the dataset iterator.
tracked_objects = {
    "step": tf.Variable(1),
    "optimizer": opt,
    "net": net,
    "iterator": iterator,
}
ckpt = tf.train.Checkpoint(**tracked_objects)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)
# ----------------------------
# ----- Train and Save -----
# ----------------------------
# Restore from the latest checkpoint known to the manager (if there is one).
latest_checkpoint = manager.latest_checkpoint
ckpt.restore(latest_checkpoint)
print(
    "Restored from {}".format(latest_checkpoint)
    if latest_checkpoint
    else "Initializing from scratch."
)

for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 != 0:
        continue
    # Every 10th step: save a checkpoint and report the current loss.
    save_path = manager.save()
    print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
    print("loss {:1.2f}".format(loss.numpy()))
# ---------------------
# ----- Restore -----
# ---------------------
# In another script, re-initialize objects
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)

# The checkpoint is built with the same keyword names as at save time.
tracked_objects = {
    "step": tf.Variable(1),
    "optimizer": opt,
    "net": net,
    "iterator": iterator,
}
ckpt = tf.train.Checkpoint(**tracked_objects)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Re-use the manager code above ^
latest_checkpoint = manager.latest_checkpoint
ckpt.restore(latest_checkpoint)
print(
    "Restored from {}".format(latest_checkpoint)
    if latest_checkpoint
    else "Initializing from scratch."
)

for _ in range(50):
    example = next(iterator)
    # Continue training or evaluate etc.
Exhaustive and useful tutorial on saved_model -> https://www.tensorflow.org/guide/saved_model
Keras detailed guide to save models -> https://www.tensorflow.org/guide/keras/save_and_serialize
Checkpoints capture the exact value of all parameters (tf.Variable objects) used by a model. Checkpoints do not contain any description of the computation defined by the model and thus are typically only useful when source code that will use the saved parameter values is available.
The SavedModel format on the other hand includes a serialized description of the computation defined by the model in addition to the parameter values (checkpoint). Models in this format are independent of the source code that created the model. They are thus suitable for deployment via TensorFlow Serving, TensorFlow Lite, TensorFlow.js, or programs in other programming languages (the C, C++, Java, Go, Rust, C# etc. TensorFlow APIs).
(Highlights are my own)
From the docs:
# Create two zero-initialized variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

# Ops that increment v1 and decrement v2 by one.
inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Op that initializes the variables.
init_op = tf.global_variables_initializer()

# Ops to save and restore all the variables.
saver = tf.train.Saver()

# Later: launch the model, initialize the variables, do some work, and save
# the variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    for update in (inc_v1, dec_v2):
        update.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
tf.reset_default_graph()

# Re-create the variables; no initializer is given because the values are
# loaded from the checkpoint below.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Ops to save and restore all the variables.
saver = tf.train.Saver()

# Later: launch the model, restore the variables from disk, and do some work
# with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    v1_value = sess.run(v1)
    v2_value = sess.run(v2)
    print("v1 : %s" % v1_value)
    print("v2 : %s" % v2_value)
simple_save
Many good answers; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.
Python 3 ; Tensorflow 1.14
import tensorflow as tf
from tensorflow.saved_model import tag_constants
with tf.Graph().as_default(), tf.Session() as sess:
    ...  # (elided in the original snippet)

    # Saving
    # NOTE: the placeholders and `model_output` are not defined in this
    # snippet; they come from the elided part above.
    inputs = {
        "batch_size_placeholder": batch_size_placeholder,
        "features_placeholder": features_placeholder,
        "labels_placeholder": labels_placeholder,
    }
    outputs = {"prediction": model_output}
    tf.saved_model.simple_save(sess, 'path/to/your/location/', inputs, outputs)
Restoring:
# Load the SavedModel into a fresh graph. The same `graph` object is used as
# the default graph while loading and for every tensor lookup afterwards
# (the original mixed `graph` with an undefined `restored_graph`).
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )

        # Look up the input placeholders and the model output by tensor name.
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = graph.get_tensor_by_name('dense/BiasAdd:0')

        # NOTE: `some_value`, `some_other_value` and `another_value` are
        # stand-ins for real input data; they are not defined in this snippet.
        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
Original blog post
The following code generates random data for the sake of the demonstration.
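1. We start by creating the placeholders, which will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.
2. The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
3. The loss is a softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, then the model will be saved in a folder called simple/ in your current working directory.
4. In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
5. Lastly, we run inference for both batches in the dataset and check that the saved and the restored model yield the same values.
Code: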
import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
def model(graph, input_tensor):
    """Build a bidirectional GRU(10) network topped by a dense classifier.

    Args:
        graph (tf.Graph): Graph in which the layers are created
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: output of the final 5-unit dense layer (no activation)
    """
    gru_cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        # The same cell object is passed for both directions; the final RNN
        # states returned by the call are not used.
        rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=gru_cell,
            cell_bw=gru_cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None,
        )
        # Join forward/backward outputs along axis 2, then average over axis 1.
        merged = tf.concat(rnn_outputs, 2)
        pooled = tf.reduce_mean(merged, axis=1)
        return tf.layers.dense(pooled, 5, activation=None)
def get_opt_op(graph, logits, labels_tensor):
    """Build the training operation from the model's logits and the labels.

    Args:
        graph (tf.Graph): Graph in which the ops are created
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a step of the Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            xent = tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels_tensor, name='xent'
            )
            loss = tf.reduce_mean(xent, name="mean-xent")
        with tf.variable_scope('optimizer'):
            adam = tf.train.AdamOptimizer(1e-2)
            train_op = adam.minimize(loss)
        return train_op
if __name__ == '__main__':
    # Seed NumPy for reproducibility and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)

        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(
            tf.float32, [None, None, 30], 'features_data_ph'
        )
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')

        # Dataset, batched with the size fed through `batch_size_ph`
        dataset = tf.data.Dataset.from_tensor_slices(
            (features_data_ph, labels_data_ph)
        ).batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(
            dataset.output_types, dataset.output_shapes
        )
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            sess.run(tf.global_variables_initializer())

            # The same data is fed every time the dataset is (re)initialized.
            init_feed = {
                features_data_ph: features,
                labels_data_ph: labels,
                batch_size_ph: 32
            }
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(dataset_init_op, feed_dict=init_feed)
                values = []
                # Epochs 0 and 1 train; epoch 2 only runs inference.
                training = epoch < 2
                while True:
                    try:
                        if training:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                    except tf.errors.OutOfRangeError:
                        # The iterator is exhausted: this epoch is over.
                        break
                    batch += 1

            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            # Remove any previous export at this path before saving.
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {"logits": logits}
            tf.saved_model.simple_save(sess, path, inputs_dict, outputs_dict)
            print('Ok')

    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default(), tf.Session(graph=graph2) as sess:
        # Restore saved values
        print('\nRestoring...')
        tf.saved_model.loader.load(sess, [tag_constants.SERVING], path)
        print('Ok')

        # Get restored placeholders
        labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
        features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
        batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
        # Get restored model output
        restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
        # Get dataset initializing operation
        dataset_init_op = graph2.get_operation_by_name('dataset_init')

        # Initialize restored dataset
        restored_feed = {
            features_data_ph: features,
            labels_data_ph: labels,
            batch_size_ph: 32
        }
        sess.run(dataset_init_op, feed_dict=restored_feed)

        # Compute inference for both batches in dataset
        restored_values = []
        for _ in range(2):
            restored_value = sess.run(restored_logits)
            restored_values.append(restored_value)
            print('Restored values: ', restored_value[0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)
This will print:
$ python3 save_and_restore.py
Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595 0.12804556 0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045 -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792 -0.00602257 0.07465433 0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984 0.05981954 -0.15913513 -0.3244143 0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]
Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok
Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Restored values: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]
Inferences match: True