Pytorch数据类型
IntTensor
FloatTensor
需要注意,没有string对应的数据类型,可以用one-hot或者embedding支持。
放置在GPU上时,数据类型存在区别,如:
torch.FloatTensor (CPU) 与 torch.cuda.FloatTensor (GPU)
isinstance检验合法性
data = data.cuda() 可以搬运到GPU
torch.tensor(1.0) 表示标量
常量是0维向量。
.shape是成员,.size() 是方法。
dim = 1
torch.tensor([1.1]) 直接指定数据
torch.FloatTensor(2) 指定维度,dim=1,size=2
常用于bias的表示。
dim = 2 的样例:a = torch.randn(2,3)
各个维度的区分:
如果是两行两列的tensor,则 dim() = 2; size() / shape = [2,2] ; a.numel() = 4 dim即size()的长度。
创建Tensor
torch.from_numpy(a) 从numpy直接导入。
torch.tensor([…]) 从List导入,tensor只接受现有数据,大写的Tensor等可以接受数据或者维度
torch.empty() 未初始化空间
torch.Tensor() 将生成设置的默认类型
随机初始化
torch.rand() [0,1]均匀分布初始化
torch.randint(1,10,…) 指定最小值与最大值
torch.randn() 服从N(0,1)的正态分布
torch.normal() 指定均值与方差,比较复杂
torch.full( [size], val ) 全部一致,类型默认
torch.arange(0,10,2) 左闭右开,2为阶梯
torch.linespace(0,10,steps = 4) 指定数量值
torch.ones()
torch.zeros()
torch.eye()
torch.randperm(2) 得到索引值的shuffle
索引与切片
a[:2] 取前两个元素
a[ :2, :1, : , :] 前2个元素,第1个通道,所有元素
a[0:28:2] 隔行采样
a.index_select(0,[0,2]) 对0维度的第0、2个元素进行采样
a[:,1,…] …表示任意长
.masked_select()
torch.take(src,torch.tensor([…])) 先打平,再取元素
维度变换
.view() && .reshape() 参数为新的维度信息,需要保证维度正确
.unsqueeze() 参数为一个索引值,在该处添加一个维度,如0,-1,3,该值为插入后新维度的位置
.squeeze(idx) 维度挤压,如果缺省参数,则删减所有维度,变成一维
.expand() 广播,不主动复制数据,将【1】变成【4】,参数是同dim的,并且只能改变原shape为1的地方,可以用-1代替不想变的部分
.repeat() 主动拷贝内存,参数为需要拷贝的次数,不需要原来为1
.t() 转置,只能适用于2维矩阵
.transpose() 参数为需交换的维度,使用之后需要重新使数据连续,.contiguous()
第一种用法错误,会产生维度错乱
a.permute(0,2,3,1) 改变维度信息,参数为新维度的原索引,同样需要重新使数据连续
Broadcast
实现不拷贝数据的维度扩展,会在前面插入维度,完成后从后面的维度开始拓展。
使用需要满足小维度匹配,可全部指定,或者为1个元素,不能有歧义。
拼接与拆分
torch.cat([a,b],dim=0) 按照指定维度合并,确保其余维度相同
torch.stack([a,b],dim=2) 会在dim处插入一个新的维度,要求原维度完全一致
a,b = c.split(1,dim=0) 在c的0维度处,按1的长度拆分,这是按长度来拆分
a,b = c.split([2,1],dim=0) 在c的0维度处,按指定的长度拆分,两种做法都得保证结果的个数的正确性
a,b = c.chunk(2,dim=0) 在0维度拆分成2个,按数量来拆分
数学运算
+ - /
torch.add(), sub() mul() div()
* 是对应元素相乘,element_wise
@ 或 torch.matmul() 表示矩阵相乘
torch.mm() 矩阵相乘,但只适用于二维向量
对于大于二维的矩阵,只会对最后二维进行乘法,高维度需满足broadcast 原理上等同于同时多个进行计算
a.pow(2) 或 a xx (乘号)2
a.sqrt()
a.rsqrt() 先平方根,再取倒数
torch.exp(a) 元素级取指数
torch.log(a) 其中的log可以换成log2等
a.floor(), a.ceil()
a.trunc() 裁剪出整数部分
a.frac() 裁剪出小数部分
a.round() 四舍五入
a.clamp(10) 将<10的变为10
a.clamp(0,10) 调整为该区间
统计属性
a.norm(2) 求2范数,可以用dim= 指定维度,则在该维度上进行范数求解,其余维度不变
a.min(), a.max(), a.mean()
a.prod() 元素累乘
a.sum()
a.argmax(), a.argmin() 最大/小元素的索引,返回打平为1维,也可以指定维度
上图返回每个第一维上的最大值索引
dim与keepdim
keepdim消除统计信息所产生的维度消失的问题
a.topk(3,dim=1) 返回top值与索引,比max返回更多的数据,可以使用参数largest = False 来求最小topk
a.kthvalue(8,dim=1) 第8小的,只能设置为小
** >, torch.gt(), !=** 等都是元素级
torch.eq(a,b) 返回的都是0,1
torch.equal(a,b) 返回的T/F
高级操作
torch.where(condition,x,y) -> Tensor 要求shape相同,为1取x,否则取y
来源:CSDN
作者:bit_codertoo
链接:https://blog.csdn.net/bit_codertoo/article/details/103588196