Predictor prediction function to classify text using a saved BERT model

Submitted by 六月ゝ 毕业季﹏ on 2020-07-22 06:41:34

Question


I have created a BERT model that classifies a user-generated text string as FAQ or not FAQ, and I saved the model with the export_savedmodel() function. Now I want to write a prediction function that takes a list of new strings as input and returns the predicted output for each one.

I tried the predictor.from_saved_model() method, but the resulting predictor requires an input dict with key-value pairs for the input IDs, segment IDs, label IDs and input mask. I am a beginner and do not fully understand what to pass there.
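A minimal sketch of what those four values usually contain, assuming the model was trained with the single-sentence preprocessing from the google-research/bert classifier code: each string is tokenized, wrapped in [CLS] ... [SEP], converted to vocabulary IDs and zero-padded to MAX_SEQ_LENGTH; input_mask is 1 for real tokens and 0 for padding, segment_ids are all 0 for a single sentence, and label_ids can be a dummy 0 because the label does not influence the predicted class. The function name encode_texts is hypothetical, and tokenizer is assumed to be a tokenization.FullTokenizer (or any object with tokenize() and convert_tokens_to_ids()) built from the same vocabulary file as the trained model.

```python
def encode_texts(texts, tokenizer, max_seq_length):
    """Turn raw strings into the four inputs of the exported BERT model.

    Hypothetical helper; assumes single-sentence BERT preprocessing and a
    tokenizer with tokenize() and convert_tokens_to_ids() methods.
    """
    input_ids, input_mask, segment_ids = [], [], []
    for text in texts:
        tokens = tokenizer.tokenize(text)
        # Truncate to leave room for the [CLS] and [SEP] markers.
        tokens = tokens[: max_seq_length - 2]
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        mask = [1] * len(ids)  # 1 marks a real (non-padding) token
        # Zero-pad both IDs and mask up to max_seq_length.
        padding = [0] * (max_seq_length - len(ids))
        input_ids.append(ids + padding)
        input_mask.append(mask + padding)
        # Single-sentence input: every position belongs to segment 0.
        segment_ids.append([0] * max_seq_length)
    return {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        # Dummy labels: required by the serving signature, not used
        # to compute the prediction.
        "label_ids": [0] * len(texts),
    }
```

Because the returned dict uses the same keys as the serving signature, it should be accepted directly by the callable from predictor.from_saved_model(), e.g. predict_fn(encode_texts(["How do I reset my password?"], tokenizer, MAX_SEQ_LENGTH)).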

Exporting or saving the model

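import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.estimator)

def serving_input_fn():
    # MAX_SEQ_LENGTH must be the same value that was used for training.
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids')
    # The keys of this dict become the input keys of the exported model, so
    # the predictor later expects exactly these four names.
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn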

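export_dir = "..."
# Export a regular (non-TPU) graph from the BERT TPUEstimator.
estimator._export_to_tpu = False
# Writes the SavedModel to a new timestamped subdirectory of export_dir.
estimator.export_savedmodel(export_dir, serving_input_fn)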

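# Predicting
from tensorflow.contrib import predictor

# The predictor loads the SavedModel into its own session,
# so no explicit tf.Session() is needed here.
predict_fn = predictor.from_saved_model('...')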
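predictor.from_saved_model() needs the timestamped directory that export_savedmodel() created under export_dir, not export_dir itself. export_savedmodel() returns that path, so capturing its return value is the simplest option; for a later session, a small sketch like the one below picks the newest export. The helper name latest_export_dir is hypothetical, and it assumes the usual layout of one all-digit (Unix timestamp) subdirectory per export.

```python
import os


def latest_export_dir(export_base):
    """Return the newest timestamped SavedModel directory under export_base.

    Hypothetical helper; assumes each export lives in a subdirectory whose
    name is a Unix timestamp, as created by Estimator.export_savedmodel.
    """
    candidates = [
        name for name in os.listdir(export_base)
        # Skip files and non-numeric entries such as temporary directories.
        if name.isdigit() and os.path.isdir(os.path.join(export_base, name))
    ]
    if not candidates:
        raise FileNotFoundError("no timestamped export found in %r" % export_base)
    # A larger timestamp means a more recent export.
    newest = max(candidates, key=int)
    return os.path.join(export_base, newest)
```

Usage: predict_fn = predictor.from_saved_model(latest_export_dir(export_dir)).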

Data description
My data is a simple table with one column for the input string and another for the output label.

Error
ValueError: Got unexpected keys in input_dict: {'pred'}
expected: {'label_ids', 'input_mask', 'segment_ids', 'input_ids'}
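The error means the dict passed to predict_fn contained the key 'pred', which the exported signature does not declare; the signature expects the four keys defined in serving_input_fn. A minimal sketch of a correct call, assuming the three arrays were already produced with the model's tokenizer and MAX_SEQ_LENGTH. The helper name classify is hypothetical, and because the output keys depend on the predictions dict returned by your model_fn, the output dict is returned unchanged rather than guessing a key.

```python
def classify(predict_fn, input_ids, input_mask, segment_ids):
    """Call an exported BERT predictor with the keys its signature declares.

    Hypothetical helper; predict_fn is assumed to be the callable returned
    by tf.contrib.predictor.from_saved_model(...).
    """
    feed = {
        "input_ids": input_ids,
        "input_mask": input_mask,
        "segment_ids": segment_ids,
        # label_ids is part of the serving signature, so it must be fed,
        # but a dummy 0 per example does not change the predicted class.
        "label_ids": [0] * len(input_ids),
    }
    # Returns a dict keyed by the model_fn's prediction names.
    return predict_fn(feed)
```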

Thank you for any help!

Source: https://stackoverflow.com/questions/56867804/predictor-prediction-function-to-classify-text-using-a-saved-bert-model
