I'm trying to understand the difference between using tf.Session and tf.train.MonitoredTrainingSession, and where I might prefer one over the other.
I can't give much insight into how these classes were created, but here are a few things that I think are relevant to how you could use them.
The tf.Session is a low-level object in the Python TensorFlow API while, as you said, tf.train.MonitoredTrainingSession comes with a lot of handy features that are useful in most common cases.
Before describing some of the benefits of tf.train.MonitoredTrainingSession, let me answer the question about the graph used by the session. You can specify the tf.Graph used by the MonitoredTrainingSession by using the context manager with your_graph.as_default():
from __future__ import print_function
import tensorflow as tf

def example():
    g1 = tf.Graph()
    with g1.as_default():
        # Define operations and tensors in `g1`.
        c1 = tf.constant(42)
        assert c1.graph is g1

    g2 = tf.Graph()
    with g2.as_default():
        # Define operations and tensors in `g2`.
        c2 = tf.constant(3.14)
        assert c2.graph is g2

    # MonitoredTrainingSession example
    with g1.as_default():
        with tf.train.MonitoredTrainingSession() as sess:
            print(c1.eval(session=sess))
            # Next line raises
            # ValueError: Cannot use the given session to evaluate tensor:
            # the tensor's graph is different from the session's graph.
            try:
                print(c2.eval(session=sess))
            except ValueError as e:
                print(e)

    # Session example
    with tf.Session(graph=g2) as sess:
        print(c2.eval(session=sess))
        # Next line raises
        # ValueError: Cannot use the given session to evaluate tensor:
        # the tensor's graph is different from the session's graph.
        try:
            print(c1.eval(session=sess))
        except ValueError as e:
            print(e)

if __name__ == '__main__':
    example()
So, as you said, the benefit of using MonitoredTrainingSession is that this object takes care of housekeeping such as initialising variables and starting queue runners. It also makes your code easier to distribute, since it behaves differently depending on whether you marked the running process as the chief (master) or not.
For example you could run something like:
def run_my_model(train_op, session_args):
    with tf.train.MonitoredTrainingSession(**session_args) as sess:
        sess.run(train_op)
that you would call in a non-distributed way:
run_my_model(train_op, {})
or in a distributed way (see the distributed doc for more information on the inputs):
run_my_model(train_op, {"master": server.target,
                        "is_chief": (FLAGS.task_index == 0)})
On the other hand, the benefit of using the raw tf.Session object is that you don't carry the extra machinery of tf.train.MonitoredTrainingSession, which can be useful if you don't plan to use those features or if you want more control (for example over how the queues are started).
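Returning to the two run_my_model calls above: as a rough sketch, the keyword dictionary could be built by one small helper so that the local and distributed cases share a code path. build_session_args is a hypothetical name, not part of TensorFlow, and the address below is a made-up placeholder; only the "master" and "is_chief" keys, and the convention that task 0 is the chief, come from the snippet above.

```python
def build_session_args(master=None, task_index=0):
    """Build kwargs for tf.train.MonitoredTrainingSession (hypothetical helper).

    With no master target, return an empty dict so the session runs locally.
    """
    if master is None:
        return {}
    # Assumption: task 0 acts as the chief, as in the snippet above.
    return {"master": master, "is_chief": task_index == 0}


# Placeholder target; in a real job this would be server.target.
PLACEHOLDER_TARGET = "grpc://localhost:2222"

local_args = build_session_args()
chief_args = build_session_args(PLACEHOLDER_TARGET, task_index=0)
worker_args = build_session_args(PLACEHOLDER_TARGET, task_index=1)
```

Each result could then be passed as the session_args argument of run_my_model.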
EDIT (as per comment): For the op initialisation, you would have to do something like this (cf. the official doc):
# Define your graph and your ops
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    sess.run(your_graph_ops, ...)
For the QueueRunner, I would refer you to the official doc where you will find more complete examples.
EDIT2:
The main concept to understand in order to get a sense of how tf.train.MonitoredTrainingSession works is the _WrappedSession class:
This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring, coordination, and recovery.
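To make the wrapping idea concrete, here is a minimal pure-Python sketch of a session wrapper that delegates to an inner session. It is not TensorFlow's _WrappedSession: FakeSession, WrappedSession and CountingSession are hypothetical names, and the run/close interface is an assumption chosen to mirror a session's basic methods.

```python
class FakeSession(object):
    """Stand-in for a real session: 'runs' a fetch by calling it."""

    def __init__(self):
        self.closed = False

    def run(self, fetch):
        return fetch()

    def close(self):
        self.closed = True


class WrappedSession(object):
    """Base wrapper: forwards run() and close() to the wrapped session."""

    def __init__(self, sess):
        self._sess = sess

    def run(self, fetch):
        return self._sess.run(fetch)

    def close(self):
        self._sess.close()


class CountingSession(WrappedSession):
    """Example subclass that adds behaviour (a run counter) around run()."""

    def __init__(self, sess):
        super(CountingSession, self).__init__(sess)
        self.num_runs = 0

    def run(self, fetch):
        self.num_runs += 1
        return super(CountingSession, self).run(fetch)


inner = FakeSession()
sess = CountingSession(WrappedSession(inner))
result = sess.run(lambda: 42)
sess.close()
```

Because every wrapper exposes the same interface as the session it wraps, wrappers can be stacked in any depth, which is the property the rest of this answer relies on.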
The tf.train.MonitoredTrainingSession works (as of version 1.1) this way:
- The hooks that were provided are begun (their begin() method is called); for example, StopAtStepHook would just retrieve the global_step tensor at this stage.
- The session itself is a Chief (or Worker) session wrapped into a _HookedSession, wrapped into a _CoordinatedSession, wrapped into a _RecoverableSession.
- The Chief/Worker sessions are in charge of running the initialising ops provided by the Scaffold. From the docs: "scaffold: A `Scaffold` used for gathering or building supportive ops. If not specified a default one is created. It's used to finalize the graph."
- The chief session also takes care of all the checkpoint parts, e.g. restoring from checkpoints using the Saver from the Scaffold.
- The _HookedSession is basically there to decorate the run method: it calls the _call_hook_before_run and after_run methods when relevant.
- The _CoordinatedSession builds a Coordinator, which starts the queue runners and is responsible for closing them.
- The _RecoverableSession ensures that there is a retry in case of tf.errors.AbortedError.
In conclusion, tf.train.MonitoredTrainingSession avoids a lot of boilerplate code while being easily extendable through the hooks mechanism.
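The layering described above can be sketched in plain Python. This illustrates the pattern only and is not TensorFlow's code: AbortedError, FlakySession, LoggingHook, HookedSession and RecoverableSession are hypothetical stand-ins, the max_retries bound is my own assumption, and the coordinator layer is left out for brevity.

```python
class AbortedError(Exception):
    """Stand-in for tf.errors.AbortedError."""


class FlakySession(object):
    """Innermost session; fails on its first call to simulate an abort."""

    def __init__(self):
        self.calls = 0

    def run(self, fetch):
        self.calls += 1
        if self.calls == 1:
            raise AbortedError("simulated failure")
        return fetch()


class LoggingHook(object):
    """Hook that records when it is called around each run()."""

    def __init__(self):
        self.events = []

    def before_run(self):
        self.events.append("before")

    def after_run(self, value):
        self.events.append("after:%s" % value)


class HookedSession(object):
    """Decorates run(): calls every hook before and after the wrapped run()."""

    def __init__(self, sess, hooks):
        self._sess = sess
        self._hooks = hooks

    def run(self, fetch):
        for hook in self._hooks:
            hook.before_run()
        value = self._sess.run(fetch)
        for hook in self._hooks:
            hook.after_run(value)
        return value


class RecoverableSession(object):
    """Recreates the wrapped session and retries when run() is aborted."""

    def __init__(self, create_session, max_retries=3):
        self._create_session = create_session
        # Assumption: bound the retries so the sketch cannot loop forever.
        self._max_retries = max_retries
        self._sess = create_session()

    def run(self, fetch):
        for _ in range(self._max_retries):
            try:
                return self._sess.run(fetch)
            except AbortedError:
                self._sess = self._create_session()
        raise AbortedError("gave up after %d attempts" % self._max_retries)


hook = LoggingHook()
flaky = FlakySession()
sess = RecoverableSession(lambda: HookedSession(flaky, [hook]))
result = sess.run(lambda: 7)
```

In this sketch the first attempt aborts after the hook's before_run has fired, so the hook sees before_run twice but after_run only once. Extending the behaviour means adding another hook object rather than touching the session classes, which is the appeal of the hooks mechanism.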