Question
I'm using the module bert-for-tf2 in order to wrap a BERT model as a Keras layer in TensorFlow 2.0. I've followed your guide for implementing a BERT model as a Keras layer.
I'm trying to extract embeddings from a sentence; in my case, the sentence is "Hello".
I have a question about the output of the model prediction. I've written this model:
model_word_embedding = tf.keras.Sequential([
tf.keras.layers.Input(shape=(4,), dtype='int32', name='input_ids'),
bert_layer
])
model_word_embedding.build(input_shape=(None, 4))
Then I want to extract the embeddings for the sentence written above:
sentences = ["Hello"]
predict = model_word_embedding.predict(sentences)
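Note that predict receives raw strings here, while the Input layer declares int32 token IDs of length 4. A minimal sketch of what that input looks like, assuming a bert-base-uncased vocabulary in which [CLS] = 101, [SEP] = 102, [PAD] = 0 and "hello" = 7592 (these IDs are illustrative; the real ones must come from the tokenizer that matches your checkpoint). The helper pad_to_max_len is hypothetical:

```python
import numpy as np

# Hypothetical token IDs, assuming a bert-base-uncased vocabulary:
# [CLS] = 101, [SEP] = 102, [PAD] = 0, "hello" = 7592.
# In practice, produce these with the tokenizer matching your checkpoint.
token_ids = [[101, 7592, 102]]  # "Hello" -> [CLS] hello [SEP]

MAX_LEN = 4  # must match Input(shape=(4,)) in the model above


def pad_to_max_len(batch, max_len, pad_id=0):
    """Hypothetical helper: right-pad (or naively truncate) each ID list to max_len."""
    rows = []
    for ids in batch:
        ids = ids[:max_len]  # naive truncation; would drop [SEP] on long inputs
        rows.append(ids + [pad_id] * (max_len - len(ids)))
    return np.asarray(rows, dtype="int32")


input_ids = pad_to_max_len(token_ids, MAX_LEN)
print(input_ids.tolist())  # [[101, 7592, 102, 0]]
print(input_ids.shape)     # (1, 4)
# An array like input_ids, not the raw strings, is what the model's Input layer expects.
```

Under that assumption, the 4 vectors in the output would correspond to [CLS], hello, [SEP] and one padding position: the model returns one 768-dimensional vector per token position, not one per sentence.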
The object predict contains 4 arrays of 768 elements each:
print(predict)
print(len(predict))
print(len(predict[0][0]))
...
[[[-0.02768866 -0.7341324 1.9084396 ... -0.65953904 0.26496622
1.1610721 ]
[-0.19322394 -1.3134469 0.10383344 ... 1.1250225 -0.2988368
-0.2323082 ]
[-1.4576151 -1.4579685 0.78580517 ... -0.8898649 -1.1016986
0.6008501 ]
[ 1.41647 -0.92478925 -1.3651332 ... -0.9197768 -1.5469263
0.03305872]]]
4
768
I know that these 4 arrays represent my original sentence, but I want to obtain a single array as the embedding of my original sentence. So, my question is: how can I obtain the embedding for a sentence?
In the BERT source code I read this:
For classification tasks, the first vector (corresponding to [CLS]) is used as the "sentence vector." Note that this only makes sense because the entire model is fine-tuned.
So should I take the first array from the prediction output, since it represents my sentence vector?
Thank you for your support.
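Assuming the output has shape (batch, seq_len, hidden) = (1, 4, 768), as the triple brackets in the printed output suggest, taking the [CLS] vector is a slice at position 0 of the sequence axis. A minimal sketch with a random stand-in array in place of the real prediction:

```python
import numpy as np

# Random stand-in for the real output, assumed shape (batch, seq_len, hidden) = (1, 4, 768).
rng = np.random.default_rng(0)
predict = rng.standard_normal((1, 4, 768)).astype("float32")

# Position 0 on the sequence axis holds the [CLS] token's vector.
cls_vectors = predict[:, 0, :]    # shape (1, 768): one vector per sentence in the batch
sentence_vector = cls_vectors[0]  # shape (768,): the vector for "Hello" (same as predict[0][0])

print(cls_vectors.shape)      # (1, 768)
print(sentence_vector.shape)  # (768,)
```

The slice predict[:, 0, :] keeps the batch axis, so the same line works unchanged for many sentences at once.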
Answer 1:
We should use the [CLS] vector from the last hidden states as the sentence embedding from BERT. According to the BERT paper, the final hidden state of the [CLS] token serves as the aggregate representation of the whole sequence; for BERT-base its dimension is 768. The following figure shows the use of [CLS] in more detail. The code below assumes you have 2000 sentences.
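# input_ids: all 2000 sentences, tokenized and padded to max_len
# model: assumed to return an output whose element [0] is the last hidden state, shape (batch, max_len, 768)
last_hidden_states = model(input_ids)
features = last_hidden_states[0][:, 0, :].numpy()  # keep only the [CLS] vector (position 0) of each sentence
features.shape
# (2000, 768)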
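One common alternative, not part of the answer above, is masked mean pooling: average the token vectors of each sentence while ignoring padding positions. It is commonly used when the model has not been fine-tuned, which is the caveat in the BERT source note quoted in the question. A minimal sketch with random stand-in arrays and a hypothetical attention mask, computing both [CLS] pooling and mean pooling:

```python
import numpy as np

# Stand-ins: 3 sentences padded to max_len = 5, BERT-base hidden size 768.
rng = np.random.default_rng(0)
last_hidden = rng.standard_normal((3, 5, 768)).astype("float32")
# Hypothetical attention mask: 1 = real token, 0 = padding.
attention_mask = np.array([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1],
                           [1, 1, 0, 0, 0]], dtype="float32")

# Option 1 ([CLS] pooling, as in the answer): first position of every sequence.
cls_embeddings = last_hidden[:, 0, :]  # (3, 768)

# Option 2 (masked mean pooling): average only over non-padding tokens.
mask = attention_mask[:, :, None]               # (3, 5, 1), broadcasts over hidden axis
summed = (last_hidden * mask).sum(axis=1)       # (3, 768)
counts = np.clip(mask.sum(axis=1), 1e-9, None)  # (3, 1), guards against division by zero
mean_embeddings = summed / counts               # (3, 768)

print(cls_embeddings.shape, mean_embeddings.shape)  # (3, 768) (3, 768)
```

Without the mask, padding vectors would be averaged in and shorter sentences would be pulled toward whatever the model outputs at [PAD] positions.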
Source: https://stackoverflow.com/questions/59330597/bert-sentence-embeddings-how-to-obtain-sentence-embeddings-vector