Extract target from Tensorflow PrefetchDataset

回眸只為那壹抹淺笑 submitted on 2021-02-07 12:59:41

Question


I am still learning TensorFlow and Keras, and I suspect this question has a very easy answer that I'm missing because I'm not yet familiar with them.

I have a PrefetchDataset object:

> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>

...made up of features and a target. I can iterate over it using a for loop:

> for example in tf_test:
>     print(example[0].numpy())                         # features of the first batch
>     print(example[1].numpy())                         # labels of the first batch
>     break                                             # stop after one batch
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
   [-0.22 -0.54 -0.14 ... 0.33 -0.55]
   [-0.60 -0.02 -1.41 ... 0.21 -0.63]
   ...
   [-0.03 -0.91 -0.12 ... 0.77 -0.23]
   [-0.76 -1.48 -0.15 ... 0.38 -0.35]
   [-0.55 -0.08 -0.69 ... 0.44 -0.36]]
  [0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
   ...
   0 1 1 0]
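To inspect a single batch, Dataset.take(1) limits the loop to the first element, so there is no need to break out of the loop (or call exit(), which would stop the whole interpreter). A minimal sketch, assuming TensorFlow 2.4 or later (for tf.data.AUTOTUNE) and a hypothetical toy dataset standing in for tf_test with the same shapes and dtypes:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tf_test: 100 examples with 99 float32 features and
# binary int64 labels, batched and prefetched like the dataset in the question.
features = np.random.rand(100, 99).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int64)
tf_test = (tf.data.Dataset.from_tensor_slices((features, labels))
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

# take(1) yields only the first (features, labels) batch; tuple unpacking
# replaces example[0] / example[1].
for batch_features, batch_labels in tf_test.take(1):
    print(batch_features.shape)   # (32, 99)
    print(batch_labels.numpy())   # 32 labels, each 0 or 1
```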

However, this is very slow. What I'd like to do is access the tensor corresponding to the class labels and turn that into a NumPy array, or a list, or any sort of iterable that can be fed into scikit-learn's classification report and/or confusion matrix:

> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
   [0.14]
   [0.00]
   ...
   [0.32]
   [0.03]
   [0.00]]
> y_pred_list = [int(x[0] >= 0.5) for x in y_pred]      # value >= 0.5 counts as a positive prediction
> y_true = []                                           # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list))
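One way to fill in y_true is to iterate over the dataset once, keep only the label half of each batch, and join the batches with np.concatenate. The sketch below assumes TensorFlow 2.4+, a hypothetical toy stand-in for tf_test, and random numbers in place of the model.predict output. It also assumes the dataset is not reshuffled between passes; otherwise the labels collected here will not line up with the order model.predict saw.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tf_test (batched, prefetched, NOT shuffled).
features = np.random.rand(100, 99).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int64)
tf_test = (tf.data.Dataset.from_tensor_slices((features, labels))
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

# One pass over the dataset, keeping only the labels of each batch,
# then join the batches into a single 1-D NumPy array.
y_true = np.concatenate([y.numpy() for _, y in tf_test])

# Hypothetical model output: one sigmoid probability per example, shape (100, 1).
y_pred = np.random.rand(100, 1)

# Threshold at 0.5 before converting: int(0.7) alone would give 0.
y_pred_list = (y_pred[:, 0] >= 0.5).astype(int)

print(y_true.shape, y_pred_list.shape)  # (100,) (100,)
```

Both results are plain 1-D NumPy arrays, so they can be passed to sklearn.metrics.confusion_matrix or sklearn.metrics.classification_report.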

...OR access the data such that it could be used in TensorFlow's confusion matrix:

> labels = []                                           # what I need help with
> predictions = y_pred_list                             # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions))
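To stay in TensorFlow, Dataset.map can drop the features inside the pipeline and tf.concat can join the label batches into one tensor. A sketch under the same assumptions (TensorFlow 2.4+, hypothetical toy data, random values standing in for model predictions); num_classes=2 is set on the assumption that the task is binary:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tf_test (batched, prefetched, not shuffled).
features = np.random.rand(100, 99).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int64)
tf_test = (tf.data.Dataset.from_tensor_slices((features, labels))
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

# Dataset.map keeps only the labels; tf.concat joins the label batches
# into one 1-D int64 tensor.
label_batches = tf_test.map(lambda x, y: y)
labels_tensor = tf.concat(list(label_batches), axis=0)

# Hypothetical predictions: sigmoid outputs of shape (100, 1), thresholded at 0.5.
y_pred = tf.random.uniform((100, 1))
predictions = tf.cast(y_pred[:, 0] >= 0.5, tf.int64)

# Both arguments are tensors, so no NumPy conversion is needed.
cm = tf.math.confusion_matrix(labels_tensor, predictions, num_classes=2)
print(cm.numpy())
```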

In both cases, the general ability to grab the target data from the original object in a way that isn't computationally expensive would be very helpful (and might help with my underlying intuitions re: TensorFlow and Keras).
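If the expensive part is iterating over the data twice (once inside model.predict and once to collect the labels), another option is a single manual pass that stores each batch's labels next to that batch's predictions. This also keeps the two aligned even if the dataset reshuffles on every iteration. A minimal sketch, assuming TensorFlow 2.4+; the untrained one-layer model and the toy data are hypothetical placeholders:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for tf_test.
features = np.random.rand(100, 99).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int64)
tf_test = (tf.data.Dataset.from_tensor_slices((features, labels))
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

# Hypothetical untrained binary classifier with the same input width.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(99,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# One pass: predictions and labels are collected from the same batch,
# so their order always matches.
y_true_parts, y_pred_parts = [], []
for x, y in tf_test:
    y_pred_parts.append(model(x, training=False).numpy()[:, 0])
    y_true_parts.append(y.numpy())

y_true = np.concatenate(y_true_parts)
y_pred = (np.concatenate(y_pred_parts) >= 0.5).astype(np.int64)
```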

Any advice would be greatly appreciated.


Answer 1:


You can convert it to a list with list(ds) and then rebuild it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again, but at least it's a nightmare that other people have had before.

Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.

This is far from a satisfying answer, but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.
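For the batched dataset in the question, list(ds) returns a list of (features, labels) tuples, and the last batch can be shorter than the rest, so some preprocessing is needed before from_tensor_slices. One hedged sketch of that step, assuming TensorFlow 2.4+ and a hypothetical toy dataset of 100 examples in batches of 32:

```python
import numpy as np
import tensorflow as tf

# Hypothetical batched, prefetched dataset of (features, labels) pairs.
features = np.random.rand(100, 99).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int64)
ds = (tf.data.Dataset.from_tensor_slices((features, labels))
      .batch(32)
      .prefetch(tf.data.AUTOTUNE))

# list(ds) materializes every batch as a (features, labels) tuple of tensors.
batches = list(ds)

# The last batch has 4 rows instead of 32, so the batches cannot be stacked
# directly; concatenate each side along the batch axis instead.
X = np.concatenate([feat.numpy() for feat, _ in batches])
y = np.concatenate([lbl.numpy() for _, lbl in batches])

# Rebuild an unbatched dataset of individual (features, label) examples.
rebuilt = tf.data.Dataset.from_tensor_slices((X, y))
print(rebuilt.element_spec)
```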




Answer 2:


You can use map to select either the input or the label from every (input, label) pair, and turn the result into a list:

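import tensorflow as tf
import numpy as np

# Toy data: 100 examples with 99 features each, and one target per example
inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

ds = tf.data.Dataset.from_tensor_slices((inputs, targets))

# Each element of ds is an (input, target) tuple of tensors; keep one side
X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))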
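The example above uses an unbatched dataset. On a batched PrefetchDataset like the one in the question, Python's built-in map would return one tensor per batch rather than one per example. A sketch of the same idea that yields per-example targets, assuming TensorFlow 2.4+, uses Dataset.map to keep only the targets and unbatch() to split the batches apart:

```python
import numpy as np
import tensorflow as tf

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

# Batched and prefetched, like the PrefetchDataset in the question.
ds = (tf.data.Dataset.from_tensor_slices((inputs, targets))
      .batch(32)
      .prefetch(tf.data.AUTOTUNE))

# Dataset.map keeps only the targets inside the pipeline, and unbatch()
# turns each batch back into one scalar per example.
y_ds = ds.map(lambda x, y: y).unbatch()
y_train = np.array(list(y_ds.as_numpy_iterator()))
print(y_train.shape)  # (100,)
```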


Source: https://stackoverflow.com/questions/62436302/extract-target-from-tensorflow-prefetchdataset
