I want to use one of the pre-built Keras models (VGG, Inception, ResNet, etc.) included in tf.keras.applications
for feature extraction, to save me some time training.
I am not aware of any built-in method that lets you create a custom model_fn
from a pretrained Keras model. An easier way is to use tf.keras.estimator.model_to_estimator():
import tensorflow as tf

model = tf.keras.applications.ResNet50(
    input_shape=(224, 224, 3),
    include_top=False,
    pooling='avg',
    weights='imagenet')
# Softmax classification head on top of the pooled features
predictions = tf.keras.layers.Dense(10, 'softmax')(model.layers[-1].output)
model = tf.keras.models.Model(model.inputs, predictions)
model.compile('adam', 'categorical_crossentropy', ['accuracy'])

# Convert Keras Model to tf.Estimator
estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
estimator.train(input_fn=...)  # supply your own input_fn here
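To make the shapes concrete: with include_top=False and pooling='avg', the backbone's final H x W x C feature map is averaged over its spatial positions into a single C-dimensional vector (2048 values for ResNet50), and the Dense(10, 'softmax') head maps that vector to 10 class probabilities. The stdlib-only sketch below mimics those two steps on a toy 2 x 2 x 3 feature map. The function names are hypothetical, and it only illustrates the arithmetic; it is not how Keras implements these layers.

```python
import math
from typing import List


def global_average_pool(feature_map: List[List[List[float]]]) -> List[float]:
    """Average an H x W x C feature map over H and W, giving a C-vector
    (the effect of pooling='avg' on the backbone output)."""
    h = len(feature_map)
    w = len(feature_map[0])
    c = len(feature_map[0][0])
    sums = [0.0] * c
    for row in feature_map:
        for pixel in row:
            for k in range(c):
                sums[k] += pixel[k]
    return [s / (h * w) for s in sums]


def dense_softmax(x: List[float], weights: List[List[float]],
                  bias: List[float]) -> List[float]:
    """Dense layer (weights is C x num_classes) followed by a softmax."""
    z = [sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
         for j in range(len(bias))]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


fmap = [[[1, 2, 3], [3, 4, 5]],
        [[5, 6, 7], [7, 8, 9]]]
pooled = global_average_pool(fmap)
print(pooled)  # -> [4.0, 5.0, 6.0]
print(dense_softmax(pooled, [[0, 0], [0, 0], [0, 0]], [0, 0]))  # -> [0.5, 0.5]
```

With all-zero weights every class gets the same score, so the softmax output is uniform. Trained weights are what make the probabilities differ.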
However, if you would like to create a custom model_fn to add more ops (e.g. summary ops), you can write it as follows:
import tensorflow as tf

# Module-level flag: load the ImageNet weights only on the first call to model_fn
_INIT_WEIGHT = True

def model_fn(features, labels, mode, params):
    global _INIT_WEIGHT

    # This is important: it puts the Keras model in training mode
    # (learning phase 1) only when training, so its weights can be updated
    tf.keras.backend.set_learning_phase(mode == tf.estimator.ModeKeys.TRAIN)

    # Assumes features is an image batch tensor, not a dict of tensors
    model = tf.keras.applications.MobileNet(
        input_tensor=features,
        include_top=False,
        pooling='avg',
        weights='imagenet' if _INIT_WEIGHT else None)

    # Only init weights on first run
    if _INIT_WEIGHT:
        _INIT_WEIGHT = False

    feature_map = model(features)
    # No activation here: the loss below expects raw logits
    logits = tf.keras.layers.Dense(units=params['num_classes'])(feature_map)

    # loss (labels must be one-hot encoded)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    ...
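tf.losses.softmax_cross_entropy takes one-hot labels and raw logits, and it applies the softmax itself. That is why the Dense layer above has no activation. With its defaults (weights=1.0, no label smoothing), the returned loss should be the batch mean of -sum_k y_k * log(softmax(z)_k). The stdlib-only sketch below writes that formula out using a log-sum-exp for stability. The function name is hypothetical and mirrors the formula, not TensorFlow's code.

```python
import math
from typing import List


def softmax_cross_entropy(onehot_labels: List[List[float]],
                          logits: List[List[float]]) -> float:
    """Batch mean of -sum_k y_k * log(softmax(z)_k)."""
    losses = []
    for y, z in zip(onehot_labels, logits):
        m = max(z)
        # log(sum(exp(z))) computed stably by factoring out the max logit
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in z))
        # log softmax(z)_k = z_k - log_sum_exp
        losses.append(-sum(yk * (zk - log_sum_exp) for yk, zk in zip(y, z)))
    return sum(losses) / len(losses)


# Two equal logits: the true class gets probability 0.5, so the loss is log(2)
print(round(softmax_cross_entropy([[1.0, 0.0]], [[0.0, 0.0]]), 4))  # -> 0.6931
```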
You can only have tensors in model_fn. Maybe you can try something like this, which can be considered a hack. The upside is that, apart from providing a model_fn, this code also stores the weights of the loaded model as a checkpoint in model_dir, so those weights are restored from the checkpoint when you call estimator.train(...) or estimator.evaluate(...).
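import tensorflow as tf

def get_model_fn(model_dir, run_config):
    # Returns a model_fn (it is not itself a model_fn): pass the result
    # to tf.estimator.Estimator(model_fn=...)

    # Import the pretrained model
    base_model = tf.keras.applications.InceptionV3(
        weights='imagenet',
        include_top=False,
        input_shape=(200, 200, 3)
    )

    # model_to_estimator() needs a compiled model: compile base_model
    # (or a model with your own head on top of it) before this check
    if getattr(base_model, 'optimizer', None) is None:
        raise ValueError(
            'Given keras model has not been compiled yet. '
            'Please compile first '
            'before creating the estimator.')

    # get estimator object from model
    keras_estimator_obj = tf.keras.estimator.model_to_estimator(
        keras_model=base_model,
        model_dir=model_dir,
        config=run_config,
    )

    # pull model_fn that we need (hack: _model_fn is a private attribute)
    return keras_estimator_obj._model_fn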
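Both the _INIT_WEIGHT flag in the previous answer and the checkpoint written by model_to_estimator() aim at the same goal. The pretrained weights are loaded once, and later calls (for example evaluate after train) pick them up from model_dir instead. The toy sketch below shows that control flow without TensorFlow. A JSON file stands in for a real checkpoint, and load_pretrained_weights() and make_model_fn() are hypothetical stand-ins for weights='imagenet' and the estimator machinery.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List, Tuple

Weights = Dict[str, List[float]]


def load_pretrained_weights() -> Weights:
    # Hypothetical stand-in for downloading weights='imagenet'
    return {"conv1": [0.1, 0.2], "dense": [0.3]}


def make_model_fn(model_dir: str) -> Callable[[], Tuple[str, Weights]]:
    ckpt_path = os.path.join(model_dir, "model.ckpt.json")

    def toy_model_fn() -> Tuple[str, Weights]:
        if os.path.exists(ckpt_path):
            # Later calls: restore the weights from the checkpoint in model_dir
            with open(ckpt_path) as f:
                return "checkpoint", json.load(f)
        # First call: start from the pretrained weights and save a checkpoint
        weights = load_pretrained_weights()
        with open(ckpt_path, "w") as f:
            json.dump(weights, f)
        return "pretrained", weights

    return toy_model_fn


with tempfile.TemporaryDirectory() as model_dir:
    model_fn = make_model_fn(model_dir)
    print(model_fn()[0])  # -> pretrained  (e.g. the call made by train)
    print(model_fn()[0])  # -> checkpoint  (e.g. the call made by evaluate)
```

In a real Estimator the checkpoint would also be updated during training, so later calls restore the fine-tuned weights rather than the original ones. The toy leaves that step out.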