How to get number of rows, columns/dimensions of tensorflow.data.Dataset?

Submitted by 主宰稳场 on 2019-12-12 12:19:59

Question


Is there an equivalent of pandas_df.shape for tensorflow.data.Dataset? Thanks.


Answer 1:


I'm not aware of anything built in, but the shapes can be retrieved from the private Dataset._tensors attribute. Example:

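import tensorflow as tf

def dataset_shapes(dataset):
    # _tensors is a private attribute, so it may change between TensorFlow versions.
    try:
        # Several components (e.g. features and labels): one shape per tensor.
        return [x.get_shape().as_list() for x in dataset._tensors]
    except TypeError:
        # A single tensor that cannot be iterated: return its shape directly.
        return dataset._tensors.get_shape().as_list()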

And usage:

from sklearn.datasets import make_blobs

x_train, y_train = make_blobs(n_samples=10,
                              n_features=2,
                              centers=[[1, 1], [-1, -1]],
                              cluster_std=0.5)
dataset = tf.data.Dataset.from_tensor_slices(x_train)
print(dataset_shapes(dataset)) # [10, 2]

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
print(dataset_shapes(dataset)) # [[10, 2], [10]]
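
As an alternative sketch that avoids the private attribute: assuming TensorFlow 2.3 or later with eager execution, the public element_spec property and cardinality() method give similar information. The zero-filled arrays are placeholder data, not the blobs from the example above, and the variable names are only illustrative. Note that element_spec describes a single element, so the leading "rows" dimension is not part of the shapes and has to come from cardinality().

```python
import numpy as np
import tensorflow as tf

# Placeholder data (an assumption for illustration): 10 rows, 2 features, 10 labels.
x = np.zeros((10, 2), dtype=np.float32)
y = np.zeros((10,), dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices((x, y))

# Number of elements ("rows"). cardinality() returns a scalar tensor; it can be
# negative when the size is unknown or infinite (e.g. after filter() or repeat()).
num_rows = int(dataset.cardinality().numpy())

# Shape of each component of a single element, taken from the public element_spec.
element_shapes = tf.nest.map_structure(
    lambda spec: spec.shape.as_list(), dataset.element_spec
)

print(num_rows)
print(element_shapes)
```

Combining the two gives something close to pandas_df.shape: num_rows is the row count, and each entry of element_shapes is the shape of one row for that component.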



Answer 2:


To add to Vlad's answer: for datasets downloaded via tfds, one option is to use the dataset information object that is returned along with the data:

info.features['image'].shape # shape of 1 feature in dataset
info.features['label'].num_classes # number of classes
info.splits['train'].num_examples # number of training examples

E.g. tf_flowers:

import tensorflow as tf
import tensorflow_datasets as tfds 

dataset, info = tfds.load("tf_flowers", with_info=True) # download data with info

image_size = info.features['image'].shape # (None, None, 3)
num_classes = info.features['label'].num_classes # 5
data_size = info.splits['train'].num_examples # 3670

E.g. fashion_mnist:

import tensorflow as tf
import tensorflow_datasets as tfds 

dataset, info = tfds.load("fashion_mnist", with_info=True) # download data with info

image_size = info.features['image'].shape # (28, 28, 1)
num_classes = info.features['label'].num_classes # 10
data_splits = {k:v.num_examples for k,v in info.splits.items()} # {'test': 10000, 'train': 60000}

Hope this helps.



Source: https://stackoverflow.com/questions/55618892/how-to-get-number-of-rows-columns-dimensions-of-tensorflow-data-dataset
