-
对于标签分类:最后生成的可以是类别标签索引,也可以是one-hot向量(独热编码)
- 我们举一个五分类的例子:
- 可以用[3]:表示第三种分类
- 也可以用one-hot向量[0,0,1,0,0]:表示第三种分类
- 我们举一个五分类的例子:
-
那么我们接下来用pytorch和tensorflow这两个深度学习框架来生成one-hot向量
Pytorch 生成one-hot向量
import torch
from torch.nn import functional
label = torch.tensor([2]) # 2显示的是索引
num_class = 5
label2one_hot = functional.one_hot(label, num_classes=num_class)
print("LongTensor:", label2one_hot) #LongTensor类型
print("ndarray:", label2one_hot.numpy()) # ndarray 类型
print("list:", label2one_hot.numpy().tolist()) # list 类型
Tensorflow生成one-hot向量
- tf.one_hot
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
- indices:输入的数值,可以是list,矩阵
- depth:one-hot的深度,分成几类就有多长
- on_value:如[1,0,0,0] 1的那位用什么数字表示,默认是1
- off_value:如[1,0,0,0] 0的那位用什么数字表示,默认是0
返回 one-hot tensor
import tensorflow as tf
list = [1] # 可以是list,可以是数组,可以是ndarray类型, 可以是tensor类型
num_class = 5
onehot_1 = tf.one_hot(list, num_class)
onehot_2 = tf.one_hot(list, num_class, on_value=2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("onehot_1:", sess.run(onehot_1))
print("onehot_2:", sess.run(onehot_2))
来源:CSDN
作者:troublemaker、
链接:https://blog.csdn.net/weixin_44912159/article/details/104344266