TensorFlow之SparseTensor对象

三世轮回 提交于 2020-01-29 03:57:48

在TensorFlow中,SparseTensor对象表示稀疏矩阵。SparseTensor对象通过3个稠密矩阵indices, values及dense_shape来表示稀疏矩阵,这三个稠密矩阵的含义介绍如下:

1. indices:数据类型为int64的二维Tensor对象,它的Shape为[N, ndims]。indices保存的是非零值的索引,即稀疏矩阵中除了indices保存的位置之外,其他位置均为0。

2. values:一维Tensor对象,其Shape为[N]。它对应的是稀疏矩阵中indices索引位置中的值。

3. dense_shape:数据类型为int64的一维Tensor对象,其维度为[ndims],用于指定当前稀疏矩阵对应的Shape。

举个例子,稀疏矩阵对象SparseTensor(indices=[[0, 0],[1, 1]], value=[1, 1], dense_shape=[3, 3])对应的矩阵如下:

 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]

函数tf.sparse_tensor_to_dense用于将稀疏矩阵SparseTensor对象转换为稠密矩阵,函数tf.sparse_tensor_to_dense的原型如下:

tf.sparse_tensor_to_dense(
     sp_input,
     default_value=0,
     validate_indices=True,
     name=None
)

各个参数的类型及其含义介绍如下:

sp_input: SparseTensor对象,用于作为转换稠密矩阵的输入。

default_value: 标量类型。稀疏矩阵sp_input中的indices没有指定位置的元素值,默认为0。

validate_indices: bool类型,用于设置是否对索引值按照字典顺序排序。

name: string类型。返回的Tensor对象的名称的前缀。

 

示例代码:

import tensorflow as tf

# 定义Tensor对象
indices_tf = tf.constant([[0, 0], [1, 1]], dtype=tf.int64)
values_tf = tf.constant([1, 2], dtype=tf.float32)
dense_shape_tf = tf.constant([3, 3], dtype=tf.int64)

sparse_tf = tf.SparseTensor(indices=indices_tf,
                            values=values_tf,
                            dense_shape=dense_shape_tf)
dense_tf = tf.sparse_tensor_to_dense(sparse_tf, default_value=0)
with tf.Session() as sess:
    sparse, dense = sess.run([sparse_tf, dense_tf])
    print('sparse:\n', sparse)
    print('dense:\n', dense)

输出:

sparse:
 SparseTensorValue(indices=array([[0, 0],
       [1, 1]], dtype=int64), values=array([1., 2.], dtype=float32), dense_shape=array([3, 3], dtype=int64))

dense:
 [[1. 0. 0.]
 [0. 2. 0.]
 [0. 0. 0.]]

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!