Implement early stopping in tf.estimator.DNNRegressor using the available training hooks

野趣味 2021-02-20 12:43

I am new to TensorFlow and want to implement early stopping in tf.estimator.DNNRegressor with the available training hooks for the MNIST dataset. The ear

1 Answer
  • 2021-02-20 13:12

    Here is a sample EarlyStoppingHook implementation:

    import numpy as np
    import tensorflow as tf
    import logging
    from tensorflow.python.training import session_run_hook
    
    
    class EarlyStoppingHook(session_run_hook.SessionRunHook):
        """Hook that requests a stop when the monitored tensor stops improving."""
    
        def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                     mode='auto'):
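            """
            Args:
                monitor: name of the tensor (or op) to monitor, e.g. the loss.
                min_delta: minimum change that counts as an improvement.
                patience: number of steps without improvement before stopping.
                mode: 'min', 'max' or 'auto' (inferred from the monitor name).
            """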
            self.monitor = monitor
            self.patience = patience
            self.min_delta = min_delta
            self.wait = 0
            if mode not in ['auto', 'min', 'max']:
                # Only one %s placeholder, so pass only `mode` as an argument.
                logging.warning('EarlyStopping mode %s is unknown, '
                                'fallback to auto mode.', mode)
                mode = 'auto'
    
            if mode == 'min':
                self.monitor_op = np.less
            elif mode == 'max':
                self.monitor_op = np.greater
            else:
                if 'acc' in self.monitor:
                    self.monitor_op = np.greater
                else:
                    self.monitor_op = np.less
    
            if self.monitor_op == np.greater:
                self.min_delta *= 1
            else:
                self.min_delta *= -1
    
            # np.inf (lowercase) also works on NumPy 2.x, where np.Inf was removed.
            self.best = np.inf if self.monitor_op == np.less else -np.inf
    
        def begin(self):
            # Convert names to tensors if given
            graph = tf.get_default_graph()
            self.monitor = graph.as_graph_element(self.monitor)
            if isinstance(self.monitor, tf.Operation):
                self.monitor = self.monitor.outputs[0]
    
        def before_run(self, run_context):  # pylint: disable=unused-argument
            return session_run_hook.SessionRunArgs(self.monitor)
    
        def after_run(self, run_context, run_values):
            current = run_values.results
    
            if self.monitor_op(current - self.min_delta, self.best):
                self.best = current
                self.wait = 0
            else:
                self.wait += 1
                if self.wait >= self.patience:
                    run_context.request_stop()
    

    This implementation is based on the Keras EarlyStopping callback.
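
To see how patience and min_delta interact, the comparison made in after_run can be traced on a plain list of loss values, without TensorFlow. This is only a minimal sketch of the 'min' mode; the function name steps_until_stop and the sample losses are made up for illustration.

```python
def steps_until_stop(losses, min_delta=0.0, patience=0):
    """Return the 0-based step at which a stop would be requested, or None.

    Mirrors the 'min' mode of the hook above: a value counts as an
    improvement only if it beats the best value so far by more than min_delta.
    """
    best = float("inf")
    wait = 0
    for step, current in enumerate(losses):
        if current + min_delta < best:
            # Improvement: remember it and reset the patience counter.
            best = current
            wait = 0
        else:
            # No improvement: count the step and stop once patience runs out.
            wait += 1
            if wait >= patience:
                return step
    return None


# Hypothetical per-step losses: improvement stalls after the third value.
losses = [1.0, 0.8, 0.7, 0.71, 0.69, 0.72, 0.70, 0.73]
print(steps_until_stop(losses, patience=3))                  # plain comparison
print(steps_until_stop(losses, min_delta=0.05, patience=3))  # stricter threshold
```

Note that after_run is called after every training step, so in the hook patience is counted in steps, not epochs, and the monitored value is whatever the training graph computes on the current batch.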

    To use it with the CNN MNIST example, create the hook and pass it to train.

    early_stopping_hook = EarlyStoppingHook(monitor='sparse_softmax_cross_entropy_loss/value', patience=10)
    
    mnist_classifier.train(
      input_fn=train_input_fn,
      steps=20000,
      hooks=[logging_hook, early_stopping_hook])
    

    Here, sparse_softmax_cross_entropy_loss/value is the name of the loss op in that example.

    EDIT 1:

    It looks like there is no "official" way of finding the loss node when using estimators (or I can't find it).

    For the DNNRegressor, this node is named dnn/head/weighted_loss/Sum.

    Here is how to find it in the graph:

    1. Start TensorBoard in the model directory. In my case I didn't set a directory, so the estimator used a temporary one and printed this line:
      WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpInj8SC
      Then start TensorBoard:

      tensorboard --logdir /tmp/tmpInj8SC
      
    2. Open it in a browser and navigate to the GRAPHS tab.

    3. Find the loss in the graph. Expand the blocks in this order: dnn → head → weighted_loss, then click on the Sum node (note that a summary node named loss is connected to it).

    4. The name shown in the info panel on the right is the name of the selected node; this is what needs to be passed to the monitor argument of EarlyStoppingHook.

    The loss node of DNNClassifier has the same name by default. Both DNNClassifier and DNNRegressor take an optional loss_reduction argument that influences the loss node's name and behavior (it defaults to losses.Reduction.SUM).
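
As an alternative to browsing the graph in TensorBoard, the op names can also be filtered in code. The following is a rough sketch, assuming the names are collected inside a hook's begin (where the graph is already built), for instance with [op.name for op in tf.get_default_graph().get_operations()]. The helper find_loss_candidates and the sample names are illustrative and not part of TensorFlow.

```python
def find_loss_candidates(op_names, keyword="loss"):
    """Return the names containing `keyword` (case-insensitive), in order."""
    keyword = keyword.lower()
    return [name for name in op_names if keyword in name.lower()]


# Hypothetical op names standing in for the real graph contents.
names = [
    "dnn/hiddenlayer_0/kernel",
    "dnn/head/weighted_loss/Sum",
    "dnn/head/weighted_loss/Mul",
    "global_step",
]
for candidate in find_loss_candidates(names):
    print(candidate)
```

The filter only narrows the search; which of the candidates is the final reduced loss still has to be checked by hand.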

    EDIT 2:

    There is a way to find the loss without looking at the graph.
    You can use the GraphKeys.LOSSES collection to get the loss. However, this only works after training has started, so you can use it only inside a hook.

    For example, you can remove the monitor argument from the EarlyStoppingHook class and change its begin function to always use the first loss in the collection:

    self.monitor = tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES)[0]
    

    You should probably also check that the collection actually contains a loss.
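
One possible way to add that check is a small guard around the lookup. This is a sketch only; the helper name first_loss and the error message are made up, and the list passed in is assumed to be the result of get_collection(tf.GraphKeys.LOSSES), which is an empty list when no loss has been registered.

```python
def first_loss(losses):
    """Return the first entry of a losses collection, or raise if it is empty."""
    if not losses:
        raise ValueError("No loss found in the LOSSES collection; "
                         "pass the tensor name explicitly instead.")
    return losses[0]


# Inside EarlyStoppingHook.begin this could be used as (assumed usage):
#   self.monitor = first_loss(
#       tf.get_default_graph().get_collection(tf.GraphKeys.LOSSES))
print(first_loss(["dnn/head/weighted_loss/Sum:0"]))
```

Failing in begin means the problem surfaces before the first training step rather than as an IndexError from an empty list.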
