关于tf.one_hot

不羁岁月 提交于 2020-01-29 04:31:06
import tensorflow as tf

labels = [[0,1,1], [1,0,1]]

res = tf.one_hot(labels,2 )
with tf.Session() as sess:
    print(sess.run(res)print(res)

输出为

[[[1. 0.]
  [0. 1.]
  [0. 1.]]

 [[0. 1.]
  [1. 0.]
  [0. 1.]]]
Tensor("one_hot:0", shape=(2, 3, 2), dtype=float32)

里面的2是指长度,一般有几类就写成几就行了

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!