Is there a way to save only the best model with tensorflow.estimator.train_and_evaluate()?

耶瑟儿~ 2021-02-09 10:46

I am trying to retrain a TF Object Detection API model from a checkpoint, using an existing .config file for the training pipeline, with the tf.estimator.train_and_evaluate() method as in models/research

3 Answers
  • 2021-02-09 11:07

    If you are training with the tensorflow/models repo, you can modify the create_train_and_eval_specs function in models/research/object_detection/model_lib.py to include a best exporter alongside the final exporter:

    final_exporter = tf.estimator.FinalExporter(
        name=final_exporter_name, serving_input_receiver_fn=predict_input_fn)
    
    best_exporter = tf.estimator.BestExporter(
        name="best_exporter",
        serving_input_receiver_fn=predict_input_fn,
        event_file_pattern='eval_eval/*.tfevents.*',
        exports_to_keep=5)
    exporters = [final_exporter, best_exporter]
    
    train_spec = tf.estimator.TrainSpec(
        input_fn=train_input_fn, max_steps=train_steps)
    
    eval_specs = [
        tf.estimator.EvalSpec(
            name=eval_spec_name,
            input_fn=eval_input_fn,
            steps=eval_steps,
            exporters=exporters)
    ]
    
  • 2021-02-09 11:13

    You can try using BestExporter. As far as I know, it's the only option for what you're trying to do.

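    # compare_fn takes (best_eval_result, current_eval_result) and returns True
    # when the current result is better; _loss_smaller stands for a function
    # that compares their "loss" values. serving_input_receiver_fn is a
    # placeholder that you must define for your own model.
    exporter = tf.estimator.BestExporter(
          name="best_exporter",
          serving_input_receiver_fn=serving_input_receiver_fn,
          compare_fn=_loss_smaller,
          exports_to_keep=5)
    
    # Pass steps and exporters by keyword: the third positional argument of
    # EvalSpec is `name`, not `exporters`.
    eval_spec = tf.estimator.EvalSpec(
        input_fn=input_fn,
        steps=steps,
        exporters=exporter)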
    

    https://www.tensorflow.org/api_docs/python/tf/estimator/BestExporter
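
    For reference, here is a minimal sketch of what a compare_fn for BestExporter could look like. BestExporter expects the two parameters to be named best_eval_result and current_eval_result, and both are dictionaries of evaluation metrics. The make_compare_fn helper is hypothetical, and the "DetectionBoxes_Precision/mAP" key is an assumption about what your Object Detection API evaluation reports; check the keys in your own eval results.

```python
def make_compare_fn(metric_key="loss", higher_is_better=False):
    """Build a comparison function usable as BestExporter's compare_fn.

    make_compare_fn is a hypothetical helper, not part of TensorFlow.
    """

    def compare_fn(best_eval_result, current_eval_result):
        # Both arguments are dicts of evaluation metrics.
        for result in (best_eval_result, current_eval_result):
            if not result or metric_key not in result:
                raise ValueError(
                    "Metric %r is missing from an eval result." % metric_key)
        best = best_eval_result[metric_key]
        current = current_eval_result[metric_key]
        # Return True when the current evaluation beats the best one so far.
        if higher_is_better:
            return current > best
        return current < best

    return compare_fn


# Lower loss is better.
loss_smaller = make_compare_fn("loss")

# Higher mAP is better; the metric key is an assumption, see above.
map_larger = make_compare_fn(
    "DetectionBoxes_Precision/mAP", higher_is_better=True)
```

    Either function can then be passed as compare_fn=... when constructing the exporter.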

  • 2021-02-09 11:14

    I have been using https://github.com/bluecamel/best_checkpoint_copier which works well for me.

    Example:

    best_copier = BestCheckpointCopier(
       name='best', # directory within model directory to copy checkpoints to
       checkpoints_to_keep=10, # number of checkpoints to keep
       score_metric='metrics/total_loss', # metric to use to determine "best"
       compare_fn=lambda x,y: x.score < y.score, # comparison function used to determine "best" checkpoint (x is the current checkpoint; y is the previously copied checkpoint with the highest/worst score)
       sort_key_fn=lambda x: x.score,
       sort_reverse=False) # sort order when discarding excess checkpoints
    

    Then pass it to your eval_spec:

    eval_spec = tf.estimator.EvalSpec(
       ...
       exporters=best_copier,
       ...)
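
    To illustrate how compare_fn, sort_key_fn, sort_reverse and checkpoints_to_keep interact, here is a rough pure-Python sketch of the selection logic as described in the comments above. It is not the library's actual implementation; the Checkpoint record and the keep_best function are hypothetical names used only for this illustration.

```python
from collections import namedtuple

# Hypothetical record: a checkpoint path together with its metric score.
Checkpoint = namedtuple("Checkpoint", ["path", "score"])


def keep_best(kept, candidate, checkpoints_to_keep=10,
              compare_fn=lambda x, y: x.score < y.score,
              sort_key_fn=lambda x: x.score,
              sort_reverse=False):
    """Return the updated list of kept checkpoints, best first."""
    kept = sorted(kept, key=sort_key_fn, reverse=sort_reverse)
    # Accept the candidate if there is still room, or if it beats the
    # worst checkpoint kept so far (the last one after sorting).
    if len(kept) < checkpoints_to_keep or compare_fn(candidate, kept[-1]):
        kept.append(candidate)
    kept = sorted(kept, key=sort_key_fn, reverse=sort_reverse)
    # Discard the excess checkpoints from the end of the sorted list.
    return kept[:checkpoints_to_keep]


kept = []
for path, loss in [("ckpt-1", 0.9), ("ckpt-2", 0.5),
                   ("ckpt-3", 0.7), ("ckpt-4", 0.95)]:
    kept = keep_best(kept, Checkpoint(path, loss), checkpoints_to_keep=2)

print([c.path for c in kept])  # ['ckpt-2', 'ckpt-3']
```

    For a metric where higher is better, you would flip the comparison to x.score > y.score and set sort_reverse=True.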
    