Question
I am using the Chainer deep learning framework. Suppose I want to stop iterating once the gap between the objective function values of two consecutive iterations becomes small: `f - old_f < eps`. However, the `stop_trigger` of `chainer.training.Trainer` is an `(args.epoch, 'epoch')` tuple. How can I trigger early stopping?
Answer 1:
I implemented an `EarlyStoppingTrigger` example based on @Seiya Tokui's answer, adapted to your situation.
```python
from chainer import reporter
from chainer.training import util


class EarlyStoppingTrigger(object):
    """Early stopping trigger.

    It observes the value specified by `key` and fires only when the
    observed value satisfies `stop_condition`.
    The trigger can be passed as the `stop_trigger` option of the Trainer
    to stop training early.

    Args:
        max_epoch (int or float): Maximum number of epochs. Even if the
            observed value never satisfies `stop_condition`, training
            finishes once it reaches `max_epoch` epochs.
        key (str): Key of the value to observe for `stop_condition`.
        stop_condition (callable): Takes the previous value and the current
            value and decides whether to stop early. Defaults to `None`, in
            which case the internal `_stop_condition` method is used.
        eps (float): Threshold used by the internal `_stop_condition`.
        trigger: Trigger that decides the comparison interval between the
            previous value and the current value. This must be a tuple in
            the form of ``<int>, 'epoch'`` or ``<int>, 'iteration'``, which
            is passed to :class:`~chainer.training.triggers.IntervalTrigger`.
    """

    def __init__(self, max_epoch, key, stop_condition=None, eps=0.01,
                 trigger=(1, 'epoch')):
        self.max_epoch = max_epoch
        self.eps = eps
        self._key = key
        self._current_value = None
        self._interval_trigger = util.get_trigger(trigger)
        self._init_summary()
        self.stop_condition = stop_condition or self._stop_condition

    def __call__(self, trainer):
        """Decides whether training should stop at this iteration.

        Args:
            trainer (~chainer.training.Trainer): Trainer object that this
                trigger is associated with. The ``observation`` of this
                trainer is used to determine if the trigger should fire.

        Returns:
            bool: ``True`` if training should stop at this iteration.
        """
        epoch_detail = trainer.updater.epoch_detail
        if self.max_epoch <= epoch_detail:
            print('Reached max_epoch.')
            return True

        # Accumulate every reported value of `key` within the interval.
        observation = trainer.observation
        summary = self._summary
        key = self._key
        if key in observation:
            summary.add({key: observation[key]})

        if not self._interval_trigger(trainer):
            return False

        stats = summary.compute_mean()
        self._init_summary()
        if key not in stats:
            # Nothing was reported under `key` during this interval.
            return False
        value = float(stats[key])  # copy to CPU

        if self._current_value is None:
            self._current_value = value
            return False
        else:
            if self.stop_condition(self._current_value, value):
                # print('Previous value {}, Current value {}'
                #       .format(self._current_value, value))
                print('Invoke EarlyStoppingTrigger...')
                self._current_value = value
                return True
            else:
                self._current_value = value
                return False

    def _init_summary(self):
        self._summary = reporter.DictSummary()

    def _stop_condition(self, current_value, new_value):
        # Stop when the value decreased by less than `eps`.
        return current_value - new_value < self.eps
```
Usage: you can pass the trigger to the `stop_trigger` option of the trainer:

```python
early_stop = EarlyStoppingTrigger(args.epoch, key='validation/main/loss', eps=0.01)
trainer = training.Trainer(updater, stop_trigger=early_stop, out=args.out)
```
See this gist for a complete working example.
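[Note] I noticed that if we use a customized `stop_trigger`, we also need to pass `training_length` to the `ProgressBar` extension explicitly (e.g. `extensions.ProgressBar(training_length=(args.epoch, 'epoch'))`), since by default it tries to read the training length from the stop trigger.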
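The `stop_condition` argument of the trigger above accepts any callable that takes the previous and the current value. As a sketch (the function name and the relative-threshold rule are hypothetical, not part of Chainer or of the answer above), a condition that stops when the *relative* improvement falls below a threshold could look like this; `functools.partial` binds the threshold so the callable still takes exactly two arguments:

```python
import functools


def relative_stop_condition(previous_value, current_value, rel_eps=1e-3):
    """Hypothetical stop_condition: stop when relative improvement < rel_eps.

    The (previous, current) argument order matches how EarlyStoppingTrigger
    calls `stop_condition`. Lower values are assumed to be better (a loss).
    """
    # Avoid dividing by zero when the previous value is exactly 0.
    scale = max(abs(previous_value), 1e-12)
    improvement = (previous_value - current_value) / scale
    return improvement < rel_eps


# Bind a custom threshold; the result takes exactly two arguments.
stop_condition = functools.partial(relative_stop_condition, rel_eps=0.005)

print(stop_condition(1.0, 0.9))    # False: 10% improvement, keep training
print(stop_condition(1.0, 0.999))  # True: 0.1% improvement, stop
```

It could then be passed as `EarlyStoppingTrigger(args.epoch, key='validation/main/loss', stop_condition=stop_condition)`. A relative threshold avoids having to retune `eps` when the loss scale changes.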
Answer 2:
You can pass a callable object to the `stop_trigger` option. The callable is called at each iteration with the `Trainer` object as its argument and should return a boolean value; when it returns `True`, training stops. To implement early stopping, you can write your own trigger function and pass it to the `stop_trigger` option of `Trainer`.
Other APIs that accept a trigger object also accept a callable; see the documentation of `get_trigger` for details.
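As a minimal sketch of such a callable, assuming only the `trainer.observation` dict and `trainer.updater.epoch_detail` attribute (the same ones the `EarlyStoppingTrigger` in Answer 1 reads), the closure below stops once the observed value decreases by less than `eps` between two consecutive reports, with a `max_epoch` cap. The factory name and its parameters are hypothetical, and a `SimpleNamespace` stands in for a real `Trainer` in the usage lines:

```python
from types import SimpleNamespace


def make_early_stop_trigger(key, eps=0.01, max_epoch=100):
    """Hypothetical stop trigger built from a plain closure.

    Returns a function of the trainer that is True (stop) once `max_epoch`
    epochs are reached, or once the value reported under `key` decreases by
    less than `eps` between two consecutive reports.
    """
    previous = None

    def trigger(trainer):
        nonlocal previous
        if trainer.updater.epoch_detail >= max_epoch:
            return True
        if key not in trainer.observation:
            return False  # nothing reported under `key` in this iteration
        raw = trainer.observation[key]
        # Assumption: a chainer.Variable is unwrapped via `.array`; plain
        # numbers and NumPy scalars pass through unchanged.
        current = float(getattr(raw, 'array', raw))
        last, previous = previous, current
        return last is not None and last - current < eps

    return trigger


# Usage with a stand-in object instead of a real chainer.training.Trainer.
trigger = make_early_stop_trigger('validation/main/loss', eps=0.01, max_epoch=10)
trainer = SimpleNamespace(updater=SimpleNamespace(epoch_detail=1.0),
                          observation={'validation/main/loss': 0.50})
print(trigger(trainer))  # False: first report, nothing to compare yet
trainer.observation = {'validation/main/loss': 0.40}
print(trigger(trainer))  # False: decreased by 0.10
trainer.observation = {'validation/main/loss': 0.395}
print(trigger(trainer))  # True: decreased by only 0.005
```

With Chainer, the returned function would be passed as `training.Trainer(updater, stop_trigger=trigger, out=args.out)`. Since a value such as `validation/main/loss` is typically present only in iterations where the evaluator ran, the comparison effectively happens once per evaluation.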
Note: a tuple value for `stop_trigger` is shorthand for using `chainer.training.triggers.IntervalTrigger` as the callable.
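Because any callable is accepted, separate stopping rules can also be combined. The helper below is a hypothetical sketch, not a Chainer API. It calls every trigger on every iteration instead of short-circuiting, so stateful triggers (for example ones that remember the previously observed value) keep their state up to date. With Chainer installed, one argument could be `chainer.training.triggers.IntervalTrigger(args.epoch, 'epoch')`; here plain functions stand in for real triggers:

```python
from types import SimpleNamespace


def any_trigger(*triggers):
    """Hypothetical helper: a stop trigger that fires if any trigger fires."""
    def combined(trainer):
        # Build a list first (not a generator passed to any()) so that no
        # trigger is skipped by short-circuiting.
        results = [t(trainer) for t in triggers]
        return any(results)
    return combined


def reached_max(trainer):
    # Stand-in for the stopping behaviour of a (20, 'epoch') tuple.
    return trainer.updater.epoch_detail >= 20


def converged(trainer):
    # Stand-in early-stop rule reading a hypothetical 'converged' flag.
    return bool(trainer.observation.get('converged', False))


stop_trigger = any_trigger(reached_max, converged)

trainer = SimpleNamespace(updater=SimpleNamespace(epoch_detail=3.0),
                          observation={})
print(stop_trigger(trainer))  # False
trainer.observation = {'converged': True}
print(stop_trigger(trainer))  # True
```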
Source: https://stackoverflow.com/questions/45891924/in-chainer-how-to-early-stop-iteration-using-chainer-training-trainer