import tensorflow as tf
labels = [[0,1,1], [1,0,1]]
res = tf.one_hot(labels,2 )
with tf.Session() as sess:
print(sess.run(res))
print(res)
输出为
[[[1. 0.]
[0. 1.]
[0. 1.]]
[[0. 1.]
[1. 0.]
[0. 1.]]]
Tensor("one_hot:0", shape=(2, 3, 2), dtype=float32)
里面的2是指长度,一般有几类就写成几就行了
来源:CSDN
作者:weixin_43444314
链接:https://blog.csdn.net/weixin_43444314/article/details/103734172