How to get the last global_step from an tf.estimator.Estimator

删除回忆录丶 提交于 2019-12-05 21:50:28
crafet

recently, I found estimator has the api get_variable_value

global_step = estimator.get_variable_value("global_step")

Simply create a hook before the training loop:

class GlobalStepHook(tf.train.SessionRunHook):
    def __init__(self):
        self._global_step_tensor = None
        self.value = None

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()

    def after_run(self, run_context, run_values):
        self.value = run_context.session.run(self._global_step_tensor)

    def __str__(self):
        return str(self.value)

global_step = GlobalStepHook()
for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn, hooks=[global_step])
    # Now the global_step hook contains the latest value of global_step
    my_custom_eval_method(global_step.value)
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!