Validation during training of an Estimator

一整个雨季 2020-12-15 09:25

With TensorFlow r1.3, monitors are deprecated:

    "2016-12-05", "Monitors are deprecated. Please use tf.train.SessionRunHook."

and Estimator.train(input_fn, ...) accepts hooks instead. How can I run validation while training an Estimator?
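
For reference, the replacement that the deprecation message points to is the tf.train.SessionRunHook interface. A minimal sketch of a custom hook (the hook name and the "loss:0" tensor name are placeholders of my own, not from the question):

    import tensorflow as tf

    class LogLossHook(tf.train.SessionRunHook):
        """Fetches a named tensor alongside each training step and prints it."""

        def __init__(self, loss_tensor_name="loss:0"):  # placeholder tensor name
            self._loss_tensor_name = loss_tensor_name

        def before_run(self, run_context):
            # Ask the session to also fetch the loss tensor on this step.
            loss = run_context.session.graph.get_tensor_by_name(self._loss_tensor_name)
            return tf.train.SessionRunArgs(loss)

        def after_run(self, run_context, run_values):
            print("current loss:", run_values.results)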

2 Answers
  • 2020-12-15 09:38

    EDIT: As pointed out in the comments, this feels like the right thing to do, but it will reinitialize the weights every time evaluation runs, which makes it pretty much useless...


    I ended up being able to monitor my validation error (which is what I understand you are trying to do) using the train_and_evaluate function. The EvalSpec object you have to use has the parameters start_delay_secs (how long to wait before the first evaluation) and throttle_secs (the minimum time between evaluations), which control how often the error (or whatever you have defined in your estimator's EVAL mode) is computed.

    My code looks somewhat like this:

    import tensorflow as tf

    # model_fn, model_dir, params, input_fn and valid_input_fn are defined elsewhere.
    classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        params=params)

    train_spec = tf.estimator.TrainSpec(
        input_fn=input_fn,
    )

    eval_spec = tf.estimator.EvalSpec(
        input_fn=valid_input_fn,
        throttle_secs=120,      # evaluate at most once every 120 seconds
        start_delay_secs=120,   # wait 120 seconds before the first evaluation
    )

    tf.estimator.train_and_evaluate(
        classifier,
        train_spec,
        eval_spec)
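
    The model_fn and the two input functions above come from my project; as a rough sketch of what the input functions could look like (the tf.data pipeline and the random placeholder arrays below are my own assumptions, not part of the original code):

    import numpy as np
    import tensorflow as tf

    def make_input_fn(features, labels, shuffle=True):
        """Builds an input_fn usable by TrainSpec / EvalSpec."""
        def input_fn():
            ds = tf.data.Dataset.from_tensor_slices(({"x": features}, labels))
            if shuffle:
                ds = ds.shuffle(buffer_size=1000)
            ds = ds.batch(32)
            # Return (features, labels) tensors for the Estimator.
            return ds.make_one_shot_iterator().get_next()
        return input_fn

    # Placeholder data purely for illustration.
    x_train = np.random.rand(1000, 4).astype(np.float32)
    y_train = np.random.randint(0, 2, size=1000)
    x_valid = np.random.rand(200, 4).astype(np.float32)
    y_valid = np.random.randint(0, 2, size=200)

    input_fn = make_input_fn(x_train, y_train)
    valid_input_fn = make_input_fn(x_valid, y_valid, shuffle=False)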
    
  • 2020-12-15 09:56

    I have been using SummarySaverHook instead of Monitors. They are not "as powerful" just yet, and the training material has not been updated with a description of how exactly to replicate the Monitor functionality.

    Here is how I use it:

    import tensorflow as tf

    # SAVE_EVERY_N_STEPS, classifier and train_input_fn are defined elsewhere.
    summary_hook = tf.train.SummarySaverHook(
        save_steps=SAVE_EVERY_N_STEPS,        # write summaries every N training steps
        output_dir='./tmp/rnnStats',
        summary_op=tf.summary.merge_all())    # pass summary_op or scaffold, not both

    print("Classifier.train")
    classifier.train(input_fn=train_input_fn, steps=1000, hooks=[summary_hook])
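
    For summary_op=tf.summary.merge_all() to pick anything up, the model_fn has to record summaries; a rough sketch of the relevant part of a model_fn (this model_fn is my own placeholder, handling only the TRAIN case, and is not from the original answer):

    def model_fn(features, labels, mode, params):
        # Minimal placeholder network; assumes features come in under the key "x".
        logits = tf.layers.dense(features["x"], units=2)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

        # Anything recorded via tf.summary.* is collected by tf.summary.merge_all()
        # and written out by the SummarySaverHook above.
        tf.summary.scalar("training_loss", loss)

        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)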
    