How do Monitored Training Sessions work?

太阳男子 2021-02-02 07:58

I'm trying to understand the difference between using tf.Session and tf.train.MonitoredTrainingSession, and where I might prefer one over the other. I

1 answer
  •  日久生厌
    2021-02-02 08:19

    I can't give much insight into how these classes were created, but here are a few things I think are relevant to how you could use them.

    The tf.Session is a low-level object in the Python TensorFlow API while, as you said, the tf.train.MonitoredTrainingSession comes with a lot of handy features that are especially useful in most common cases.

    Before describing some of the benefits of tf.train.MonitoredTrainingSession, let me answer the question about the graph used by the session. You can specify the tf.Graph used by the MonitoredTrainingSession by using a context manager with your_graph.as_default():

    from __future__ import print_function
    import tensorflow as tf
    
    def example():
        g1 = tf.Graph()
        with g1.as_default():
            # Define operations and tensors in `g1`.
            c1 = tf.constant(42)
            assert c1.graph is g1
    
        g2 = tf.Graph()
        with g2.as_default():
            # Define operations and tensors in `g2`.
            c2 = tf.constant(3.14)
            assert c2.graph is g2
    
        # MonitoredTrainingSession example
        with g1.as_default():
            with tf.train.MonitoredTrainingSession() as sess:
                print(c1.eval(session=sess))
                # Next line raises
                # ValueError: Cannot use the given session to evaluate tensor:
                # the tensor's graph is different from the session's graph.
                try:
                    print(c2.eval(session=sess))
                except ValueError as e:
                    print(e)
    
        # Session example
        with tf.Session(graph=g2) as sess:
            print(c2.eval(session=sess))
            # Next line raises
            # ValueError: Cannot use the given session to evaluate tensor:
            # the tensor's graph is different from the session's graph.
            try:
                print(c1.eval(session=sess))
            except ValueError as e:
                print(e)
    
    if __name__ == '__main__':
        example()
    

    So, as you said, the benefit of using MonitoredTrainingSession is that this object takes care of

    • initialising variables,
    • starting queue runners, as well as
    • setting up the file writers,

    but it also has the benefit of making your code easy to distribute, as it behaves differently depending on whether you specified the running process as a master or not.

    For example you could run something like:

    def run_my_model(train_op, session_args):
        with tf.train.MonitoredTrainingSession(**session_args) as sess:
            sess.run(train_op)
    

    that you would call in a non-distributed way:

    run_my_model(train_op, {})
    

    or in a distributed way (see the distributed doc for more information on the inputs):

    run_my_model(train_op, {"master": server.target,
                            "is_chief": (FLAGS.task_index == 0)})
    

    On the other hand, the benefit of using the raw tf.Session object is that you don't carry the extra machinery of tf.train.MonitoredTrainingSession, which is useful if you don't plan to use those features or if you want more control (for example, over how the queues are started).

    EDIT (as per comment): For the op initialisation, you would have to do something like this (cf. the official doc):

    # Define your graph and your ops
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(your_graph_ops,...)
    

    For the QueueRunner, I would refer you to the official doc where you will find more complete examples.
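
    To give a rough picture of what managing queue threads by hand involves, here is a pure-Python stand-in for the coordinator pattern (should_stop / request_stop / join). It does not use TensorFlow: ToyCoordinator and enqueue_loop are hypothetical names invented for this illustration, and the real tf.train.Coordinator does more (for example, exception propagation).

```python
import queue
import threading


class ToyCoordinator:
    """Hypothetical stand-in for the should_stop/request_stop/join pattern."""

    def __init__(self):
        self._stop_event = threading.Event()

    def should_stop(self):
        return self._stop_event.is_set()

    def request_stop(self):
        self._stop_event.set()

    def join(self, threads):
        for t in threads:
            t.join()


def enqueue_loop(coord, q, worker_id):
    """Plays the role of a queue runner thread: fill the queue until told to stop."""
    n = 0
    while not coord.should_stop():
        try:
            # Use a timeout so the loop re-checks should_stop() regularly.
            q.put((worker_id, n), timeout=0.05)
            n += 1
        except queue.Full:
            continue


coord = ToyCoordinator()
q = queue.Queue(maxsize=4)
threads = [threading.Thread(target=enqueue_loop, args=(coord, q, i))
           for i in range(2)]
for t in threads:
    t.start()

consumed = []
try:
    # The "training loop": pull ten items produced by the background threads.
    for _ in range(10):
        consumed.append(q.get(timeout=1.0))
finally:
    # With a raw session, stopping and joining the threads is your job.
    coord.request_stop()
    coord.join(threads)
```

    The try/finally at the end is the kind of bookkeeping that a monitored session is described as handling for you.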

    EDIT2:

    The main concept to understand, to get a sense of how tf.train.MonitoredTrainingSession works, is the _WrappedSession class:

    This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring, coordination, and recovery.

    The tf.train.MonitoredTrainingSession works (as of version 1.1) this way:

    • It first checks whether it is a chief or a worker (cf. the distributed doc for the terminology).
    • It begins the hooks which have been provided (for example, StopAtStepHook would just retrieve the global_step tensor at this stage).
    • It creates a session which is a Chief (or Worker session) wrapped into a _HookedSession wrapped into a _CoordinatedSession wrapped into a _RecoverableSession.
      The Chief/Worker sessions are in charge of running initialising ops provided by the Scaffold.
        scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      
    • The chief session also takes care of all the checkpoint parts: e.g. restoring from checkpoints using the Saver from the Scaffold.
    • The _HookedSession is basically there to decorate the run method: it calls the _call_hook_before_run and after_run methods when relevant.
    • At creation, the _CoordinatedSession builds a Coordinator which starts the queue runners and is responsible for closing them.
    • The _RecoverableSession ensures that there is a retry in case of tf.errors.AbortedError.
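
    The layering described above can be pictured with a toy model in plain Python. This is only a sketch under stated assumptions, not TensorFlow's implementation: every class here (AbortedError, BaseSession, WrappedSession, HookedSession, RecoverableSession, RecordingHook) is a hypothetical stand-in, the hook signatures are simplified, and the coordinated layer is left out for brevity.

```python
class AbortedError(Exception):
    """Stand-in for tf.errors.AbortedError (not the real class)."""


class BaseSession:
    """Hypothetical innermost session backed by shared state."""

    def __init__(self, backend):
        self._backend = backend

    def run(self, fetches):
        # Simulate a single preemption on the very first call.
        if self._backend["abort_once"]:
            self._backend["abort_once"] = False
            raise AbortedError("simulated preemption")
        self._backend["step"] += 1
        return fetches, self._backend["step"]


class WrappedSession:
    """Base wrapper: delegates run() to the session it wraps."""

    def __init__(self, sess):
        self._sess = sess

    def run(self, fetches):
        return self._sess.run(fetches)


class HookedSession(WrappedSession):
    """Decorates run() with before_run/after_run callbacks."""

    def __init__(self, sess, hooks):
        super().__init__(sess)
        self._hooks = hooks

    def run(self, fetches):
        for hook in self._hooks:
            hook.before_run(fetches)
        result = super().run(fetches)
        for hook in self._hooks:
            hook.after_run(result)
        return result


class RecoverableSession(WrappedSession):
    """Recreates the inner session and retries when run() is aborted."""

    def __init__(self, session_factory):
        self._factory = session_factory
        super().__init__(session_factory())

    def run(self, fetches):
        while True:
            try:
                return super().run(fetches)
            except AbortedError:
                self._sess = self._factory()


class RecordingHook:
    """Hypothetical hook that records when it is called."""

    def __init__(self):
        self.events = []

    def before_run(self, fetches):
        self.events.append(("before", fetches))

    def after_run(self, result):
        self.events.append(("after", result))


backend = {"step": 0, "abort_once": True}
hook = RecordingHook()


def make_session():
    # Hooked session around the base session, as in the chain above.
    return HookedSession(BaseSession(backend), [hook])


sess = RecoverableSession(make_session)
result = sess.run("train_op")
print(result)
print(hook.events)
```

    In this toy run the first attempt is aborted after before_run has fired, so the hook sees before_run twice and after_run once, and the caller only ever sees the successful result.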

    In conclusion, the tf.train.MonitoredTrainingSession avoids a lot of boilerplate code while being easily extendable through the hooks mechanism.
