pytorch、tensorflow之生成one-hot向量

浪子不回头ぞ 提交于 2020-02-16 19:23:19
  • 对于标签分类:最后生成的可以是类别标签索引,也可以是one-hot向量(独热编码)

    • 我们举一个五分类的例子:
      • 可以用[3]:表示第三种分类
      • 也可以用one-hot向量[0,0,1,0,0]:表示第三种分类
  • 那么我们接下来用pytorchtensorflow这两个深度学习框架来生成one-hot向量

Pytorch 生成one-hot向量

import torch
from torch.nn import functional
label = torch.tensor([2])  # 2显示的是索引
num_class = 5
label2one_hot = functional.one_hot(label, num_classes=num_class)
print("LongTensor:", label2one_hot)       #LongTensor类型
print("ndarray:", label2one_hot.numpy())  # ndarray 类型
print("list:", label2one_hot.numpy().tolist())  # list 类型

在这里插入图片描述

Tensorflow生成one-hot向量

  • tf.one_hot
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
	- indices:输入的数值,可以是list,矩阵
	- depth:one-hot的深度,分成几类就有多长
	- on_value:如[1,0,0,0]  1的那位用什么数字表示,默认是1
    - off_value:如[1,0,0,0]  0的那位用什么数字表示,默认是0
   返回 one-hot tensor
import tensorflow as tf
list = [1]   # 可以是list,可以是数组,可以是ndarray类型, 可以是tensor类型
num_class = 5
onehot_1 = tf.one_hot(list, num_class)
onehot_2 = tf.one_hot(list, num_class, on_value=2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("onehot_1:", sess.run(onehot_1))
    print("onehot_2:", sess.run(onehot_2))
    

在这里插入图片描述

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!