Using a Keras model inside a TF estimator

醉梦人生 2020-12-30 16:41

I want to use one of the pre-built Keras models (VGG, Inception, ResNet, etc.) included in tf.keras.applications for feature extraction to save me some time training…

2 Answers
  • 2020-12-30 16:46

    I am not aware of any built-in way to create a custom model_fn from a pretrained Keras model. An easier way is to use tf.keras.estimator.model_to_estimator(), which converts a compiled Keras model into an Estimator:

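    import tensorflow as tf

    # Pretrained ResNet50 without its classifier; pooling='avg' makes it output a feature vector
    model = tf.keras.applications.ResNet50(
        input_shape=(224, 224, 3),
        include_top=False,
        pooling='avg',
        weights='imagenet')
    # New 10-class head; softmax is applied here, so the outputs are probabilities, not logits
    predictions = tf.keras.layers.Dense(10, activation='softmax')(model.output)
    model = tf.keras.models.Model(model.inputs, predictions)
    # model_to_estimator requires a compiled model
    model.compile('adam', 'categorical_crossentropy', ['accuracy'])
    
    # Convert Keras Model to tf.Estimator
    estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
    estimator.train(input_fn=....)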
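    The head above applies softmax inside the Dense layer and is compiled with categorical_crossentropy, which works on probabilities; the custom model_fn further down instead keeps raw logits and uses tf.losses.softmax_cross_entropy. Mathematically both give the same loss, because log softmax(z)_i = z_i - logsumexp(z). A small pure-Python check of that identity (standard library only; the function names are hypothetical helpers, not TensorFlow APIs):

```python
import math


def softmax(logits):
    # Subtract the max logit before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def loss_from_probs(one_hot, probs):
    # Cross-entropy on probabilities (softmax Dense head + categorical_crossentropy)
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)


def loss_from_logits(one_hot, logits):
    # Cross-entropy on raw logits via log-sum-exp (linear Dense head + softmax_cross_entropy)
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - log_sum_exp) for y, z in zip(one_hot, logits))


labels = [0.0, 1.0, 0.0]
logits = [2.0, 0.5, -1.0]
a = loss_from_probs(labels, softmax(logits))
b = loss_from_logits(labels, logits)
print(a, b)  # the two values agree up to floating-point rounding
```

    Working from logits is generally the more numerically stable route, which is one reason to keep the Dense head linear when you write the loss yourself.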
    

    However, if you want to write a custom model_fn to add more ops (e.g. summary ops), you can write it as follows:

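    import tensorflow as tf
    
    # Module-level flag: only the first graph built by the estimator loads the
    # ImageNet weights; later graphs restore the weights from the checkpoint
    _INIT_WEIGHT = True
    
    def model_fn(features, labels, mode, params):
      global _INIT_WEIGHT
    
      # This is important: it tells Keras layers such as BatchNormalization and
      # Dropout whether they are in training mode
      tf.keras.backend.set_learning_phase(mode == tf.estimator.ModeKeys.TRAIN)
    
      model = tf.keras.applications.MobileNet(
          input_tensor=features,
          include_top=False,
          pooling='avg',
          weights='imagenet' if _INIT_WEIGHT else None)
    
      # Only init weights on first run
      if _INIT_WEIGHT:
        _INIT_WEIGHT = False
    
      # The model was built on `features`, so its output already is the pooled feature vector
      feature_map = model.output
      # Linear head producing raw logits; softmax_cross_entropy applies softmax itself
      logits = tf.keras.layers.Dense(units=params['num_classes'])(feature_map)
    
      # loss
      loss = tf.losses.softmax_cross_entropy(labels=labels, logits=logits)
      ...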
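    The _INIT_WEIGHT global implements a "first call only" switch: the estimator calls model_fn again every time it builds a new graph (for example on each train or evaluate call), and only the first of those graphs should request the ImageNet weights. A minimal pure-Python sketch of the same logic, assuming you would rather keep the state in a closure than in a global; make_weights_selector and select_weights are hypothetical names:

```python
def make_weights_selector(pretrained='imagenet'):
    # Hypothetical helper: the returned function yields `pretrained` on its
    # first call and None on every later call, like the _INIT_WEIGHT flag
    first_call = True

    def select_weights():
        nonlocal first_call
        if first_call:
            first_call = False
            return pretrained
        return None

    return select_weights


select_weights = make_weights_selector()
print(select_weights())  # first graph: imagenet
print(select_weights())  # later graphs: None
```

    Inside model_fn this would be used as weights=select_weights(), with the selector created once at module level.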
    
  • 2020-12-30 16:47

    Only tensors can be handled inside model_fn, so you can try something like the following. It can be considered a hack. The upside is that, apart from providing a model_fn, this code also stores the weights of the loaded model as a checkpoint in the model directory (model_dir), so those weights are loaded from the checkpoint when you call estimator.train(...) or estimator.evaluate(...).

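    import tensorflow as tf

    def get_model_fn():
    
        # Import the pretrained model
        base_model = tf.keras.applications.InceptionV3(
            weights='imagenet',
            include_top=False,
            input_shape=(200, 200, 3)
        )
    
        # model_to_estimator needs a compiled model; the optimizer and loss
        # here are placeholders, replace them with what your task needs
        base_model.compile(optimizer='adam', loss='mse')
    
        # Sanity check: the Keras model must be compiled
        if getattr(base_model, 'optimizer', None) is None:
            raise ValueError(
                'Given keras model has not been compiled yet. '
                'Please compile first '
                'before creating the estimator.')
    
        # Get an Estimator object from the model; this also writes the
        # loaded weights as a checkpoint in model_dir
        keras_estimator_obj = tf.keras.estimator.model_to_estimator(
            keras_model=base_model,
            model_dir=<model_dir>,
            config=<run_config>,
        )
    
        # Pull out the model_fn we need (hack: _model_fn is a private attribute)
        return keras_estimator_obj._model_fn

    # Usage (sketch): estimator = tf.estimator.Estimator(model_fn=get_model_fn(), model_dir=<model_dir>)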
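    Note that, unlike the snippets in the other answer, this base_model is created without a pooling argument, so with include_top=False its output is a 4-D feature map (batch, height, width, channels) rather than a feature vector. Passing pooling='avg', as the other answer does, adds global average pooling, which averages each channel over all spatial positions. A pure-Python sketch of that operation on a single feature map without the batch dimension (global_average_pool is a hypothetical helper, not a Keras function):

```python
def global_average_pool(feature_map):
    # feature_map is a nested list of shape [height][width][channels]
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value
    # Average each channel over all height * width positions
    return [s / (height * width) for s in sums]


# A 2x2 feature map with 2 channels
fmap = [
    [[1.0, 10.0], [3.0, 30.0]],
    [[5.0, 50.0], [7.0, 70.0]],
]
print(global_average_pool(fmap))  # [4.0, 40.0]
```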
    