Question
The model_fn for the custom Estimator I have built is shown below:
def _model_fn(features, labels, mode):
    """
    Mask RCNN Model function
    """
    self.keras_model = self.build_graph(mode, config)
    outputs = self.keras_model(features)  # ERROR STATEMENT
    # outputs = self.keras_model(list(features.values()))  # Same ERROR with this statement

    # Predictions
    if mode == tf.estimator.ModeKeys.PREDICT:
        ...  # Defining Prediction Spec

    # Training
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Defining Loss and Training Spec
        ...

    # Evaluation
    ...
The _model_fn() receives the arguments features and labels from tf.data in the following form:
features = {
    'a': (batch_size, h, w, 3),  # dtype: float
    'b': (batch_size, n),        # dtype: float
}
# And
labels = []
The self.keras_model is built using the tensorflow.keras.models.Model API, with Input placeholders (defined using the tensorflow.keras.layers.Input() layer) named 'a' and 'b' with the respective shapes.
After running the estimator with train_and_evaluate(), the _model_fn runs fine and the graph is initialized, but when training starts I get the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'a' with dtype float and shape [?,128,128,3] [[{{node a}}]]
I have worked with custom estimators before, but this is the first time I am using the tensorflow.keras.models.Model API inside the _model_fn to compute the graph.
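Separately from the placeholder error, the commented-out variant self.keras_model(list(features.values())) depends on the dict's iteration order matching the order of the model's Input layers. The sketch below builds the input list by name instead. It assumes the Input names are known (tf.keras functional models usually expose them as model.input_names); features_to_model_inputs is a hypothetical helper, and plain strings stand in for the real tensors.

```python
from typing import Any, Dict, List, Sequence


def features_to_model_inputs(features: Dict[str, Any],
                             input_names: Sequence[str]) -> List[Any]:
    """Return the feature tensors ordered to match the model's Input names.

    Hypothetical helper: input_names would be the names given to the Input
    layers ('a' and 'b' in the question).
    """
    missing = [name for name in input_names if name not in features]
    if missing:
        raise KeyError(f"features is missing inputs: {missing}")
    # Index by name instead of relying on dict iteration order.
    return [features[name] for name in input_names]


# Strings stand in for tensors; the dict was filled in the "wrong" order.
features = {"b": "tensor_b", "a": "tensor_a"}
print(features_to_model_inputs(features, ["a", "b"]))  # ['tensor_a', 'tensor_b']
print(list(features.values()))  # ['tensor_b', 'tensor_a'] (insertion order)
```

Looking inputs up by name makes a mismatch fail loudly with a KeyError instead of silently feeding 'b' into the slot for 'a'.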
Answer 1:
This problem occurs only with this particular model (Mask-RCNN). To overcome it, a slight modification can be made to the method self.build_graph(), passing it the features dict, as follows:
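from tensorflow.keras import layers as KL  # KL alias, as used in the answer


def build_graph(self, features, mode, config):
    # features is the dict received by _model_fn.
    # Input(tensor=...) wraps the existing feature tensor instead of creating a placeholder.
    a = KL.Input(tensor=features['a'])
    # Earlier (creates a placeholder named 'a'; shape excludes the batch dimension):
    # a = KL.Input(shape=[h, w, 3], name='a')
    b = KL.Input(tensor=features['b'])
    # Earlier:
    # b = KL.Input(shape=[n], name='b')
    ...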
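These modifications wrap the feature tensors directly in tensorflow.keras.layers.Input(). Because Input(tensor=...) uses the existing tensor rather than creating a new placeholder, the resulting inputs can be used when defining the model with tensorflow.keras.models.Model, and no placeholder is left for the training loop to feed.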
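The answer does not explain why Mask-RCNN in particular triggers the error. One plausible explanation (an assumption, not stated in the answer) is that parts of the graph close over the Input tensor itself. For example, some Keras Mask R-CNN implementations normalize boxes inside a Lambda layer that reads K.shape(input_image). Calling the model on new tensors rewires the main path but not such captured references, so the original placeholder still has to be fed. The pure-Python mock below illustrates that mechanism without TensorFlow; Placeholder, DataTensor and build_graph here are all hypothetical stand-ins.

```python
class Placeholder:
    """Hypothetical stand-in for an unfed placeholder, like KL.Input(name='a')."""

    def __init__(self, name):
        self.name = name

    def evaluate(self, feed):
        if self.name not in feed:
            raise ValueError(f"You must feed a value for placeholder tensor '{self.name}'")
        return feed[self.name]


class DataTensor:
    """Hypothetical stand-in for a tensor that already carries data, like features['a']."""

    def __init__(self, value):
        self.value = value

    def evaluate(self, feed):
        return self.value


def build_graph(input_node):
    """Hypothetical build step returning a callable 'model'.

    The scale factor is read from input_node, the tensor captured at build
    time, the way a Lambda layer can close over an Input tensor.
    """
    def model(x):
        # The main path uses x, but the closure still evaluates input_node.
        def run(feed):
            scale = len(input_node.evaluate(feed))
            return [v / scale for v in x.evaluate(feed)]
        return run
    return model


features_a = DataTensor([2.0, 4.0])

# Question's approach: build on a placeholder, then call the model on the features.
model = build_graph(Placeholder("a"))
try:
    model(features_a)(feed={})  # the input pipeline never feeds 'a'
except ValueError as err:
    print(err)  # You must feed a value for placeholder tensor 'a'

# Answer's approach: build directly on the feature tensor, like Input(tensor=...).
fixed = build_graph(features_a)
print(fixed(features_a)(feed={}))  # [1.0, 2.0]
```

In the mock, building on the data tensor works because every reference captured at build time already points at data the pipeline supplies, which mirrors what Input(tensor=features['a']) does in the answer.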
Source: https://stackoverflow.com/questions/59046447/invalid-argument-error-while-using-keras-model-api-inside-an-estimator-model-fn