Question
I would like to ask whether the tf.one_hot() function supports a SparseTensor as the "indices" parameter. I want to do multi-label classification (each example has multiple labels), which requires calculating a cross-entropy loss.
I tried passing the SparseTensor directly as the "indices" parameter, but it raises the following error:
TypeError: Failed to convert object of type to Tensor. Contents: SparseTensor(indices=Tensor("read_batch_features/fifo_queue_Dequeue:106", shape=(?, 2), dtype=int64, device=/job:worker), values=Tensor("string_to_index_Lookup:0", shape=(?,), dtype=int64, device=/job:worker), dense_shape=Tensor("read_batch_features/fifo_queue_Dequeue:108", shape=(2,), dtype=int64, device=/job:worker)). Consider casting elements to a supported type.
Any suggestions on the possible cause?
Thanks.
Answer 1:
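one_hot does not support a SparseTensor as the indices parameter. You can, however, pass the sparse tensor's indices or values tensor as the indices parameter, which might solve your problem. Note that one_hot(values) produces one one-hot row per label, not per example, so the rows belonging to the same example (the same indices[:, 0]) still have to be combined.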
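To illustrate what that suggestion computes, here is a plain-Python sketch (no TensorFlow, so it runs anywhere). Each value is one-hot encoded, giving one row per label, and then rows that share an example id are merged with an elementwise max. The helper names one_hot and multi_hot_from_sparse are hypothetical and only mimic the idea; in TensorFlow the merge step would need some segment reduction keyed on indices[:, 0], which this sketch does not show.

```python
# Plain-Python sketch (no TensorFlow): one-hot encode each nonzero value of a
# 2-D sparse label tensor, then merge the rows that belong to the same example.

def one_hot(index, depth):
    """Hypothetical helper: list of length `depth` with a 1 at `index`."""
    return [1 if i == index else 0 for i in range(depth)]


def multi_hot_from_sparse(indices, values, num_rows, depth):
    """indices: [row, col] pairs of the sparse tensor; values: class ids.
    Returns a num_rows x depth multi-hot matrix."""
    # Step 1: what one_hot(values) yields, one row per nonzero entry.
    per_entry = [one_hot(v, depth) for v in values]
    # Step 2: elementwise max over entries sharing a row id (indices[k][0]).
    dense = [[0] * depth for _ in range(num_rows)]
    for (row, _col), encoded in zip(indices, per_entry):
        dense[row] = [max(a, b) for a, b in zip(dense[row], encoded)]
    return dense


# Rows with class ids {0, 2}, {0, 1}, {1}, {2}
indices = [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [3, 0]]
values = [0, 2, 0, 1, 1, 2]
print(multi_hot_from_sparse(indices, values, num_rows=4, depth=3))
# [[1, 0, 1], [1, 1, 0], [0, 1, 0], [0, 0, 1]]
```

A max (rather than a sum) keeps the result 0/1 even if the same label appears twice for one example.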
Answer 2:
You could build another SparseTensor of shape (batch_size, num_classes) from the initial SparseTensor. For example, if you keep your classes in a single string feature column (separated by spaces), you could use the following:
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and tf.Session)

all_classes = ["class1", "class2", "class3"]
classes_column = ["class1 class3", "class1 class2", "class2", "class3"]

# Maps each class name to its position in all_classes (class1 -> 0, ...)
table = tf.contrib.lookup.index_table_from_tensor(
    mapping=tf.constant(all_classes)
)

classes = tf.constant(classes_column)
classes = tf.string_split(classes)  # splits on spaces -> SparseTensor of strings
idx = table.lookup(classes)  # SparseTensor of shape (4, 2), because each of the 4 rows has at most 2 classes
num_items = tf.cast(tf.shape(idx)[0], tf.int64)  # num items in batch
num_entries = tf.shape(idx.indices)[0]  # num nonzero entries

# Put a 1 at (row, class_id) for every label of every item
y = tf.SparseTensor(
    indices=tf.stack([idx.indices[:, 0], idx.values], axis=1),
    values=tf.ones(shape=(num_entries,), dtype=tf.int32),
    dense_shape=(num_items, len(all_classes)),
)
# validate_indices=False skips the sorted/unique check (class ids in a row may be out of order)
y = tf.sparse_tensor_to_dense(y, validate_indices=False)

with tf.Session() as sess:
    tf.tables_initializer().run()
    print(sess.run(y))
# Outputs:
# [[1 0 1]
# [1 1 0]
# [0 1 0]
# [0 0 1]]
Here idx is a SparseTensor. The first column of its indices, idx.indices[:, 0], contains the row number of each entry within the batch, and its values, idx.values, contain the class ids (positions in all_classes). Stacking these two gives the new y.indices, and every one of those positions gets the value 1.
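The same index construction can be traced in plain Python. This is a hedged sketch that mimics the TensorFlow graph above without TensorFlow: a dict stands in for the lookup table, str.split() roughly stands in for tf.string_split, and to_multi_hot is a hypothetical helper name.

```python
# Plain-Python trace of the TensorFlow graph above (no TensorFlow needed).
all_classes = ["class1", "class2", "class3"]
classes_column = ["class1 class3", "class1 class2", "class2", "class3"]

# Stand-in for the lookup table: class name -> position in all_classes.
class_id = {name: i for i, name in enumerate(all_classes)}

# Stand-in for tf.string_split + table.lookup: a sparse (indices, values)
# pair where indices are (row, position-within-row) and values are class ids.
sp_indices, sp_values = [], []
for row, text in enumerate(classes_column):
    for col, token in enumerate(text.split()):
        sp_indices.append((row, col))
        sp_values.append(class_id[token])


def to_multi_hot(sp_indices, sp_values, num_items, num_classes):
    """Hypothetical helper mirroring the construction of y in the answer."""
    # Pair each entry's row number (column 0 of the old indices) with its
    # class id, as tf.stack([idx.indices[:, 0], idx.values], axis=1) does.
    new_indices = [(pos[0], cid) for pos, cid in zip(sp_indices, sp_values)]
    # Densify: every listed (row, class_id) becomes 1, everything else stays 0.
    dense = [[0] * num_classes for _ in range(num_items)]
    for row, cid in new_indices:
        dense[row][cid] = 1
    return new_indices, dense


new_indices, y = to_multi_hot(sp_indices, sp_values, len(classes_column), len(all_classes))
print(new_indices)  # [(0, 0), (0, 2), (1, 0), (1, 1), (2, 1), (3, 2)]
print(y)  # [[1, 0, 1], [1, 1, 0], [0, 1, 0], [0, 0, 1]]
```

The column position within a row (the second column of the old indices) is simply dropped; only the row number and the class id matter for the multi-hot target.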
For a full implementation of multi-label classification, see "Option 2" of https://stackoverflow.com/a/47671503/507062
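For the loss the question asks about, a multi-hot target like y is commonly paired with a per-class sigmoid cross-entropy rather than a softmax one, since each class is an independent yes/no decision. Below is a plain-Python sketch of the numerically stable formula that the TensorFlow documentation gives for tf.nn.sigmoid_cross_entropy_with_logits, max(x, 0) - x * z + log(1 + exp(-|x|)). The helper names and the logits are made up for illustration, and summing over classes then averaging over examples is just one possible reduction.

```python
import math


def sigmoid_cross_entropy(x, z):
    """Per-class loss for logit x and 0/1 label z, in the numerically stable
    form max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def multi_label_loss(logits, multi_hot):
    """Hypothetical helper: sum the per-class losses of each example, then
    average over the batch (one possible reduction among several)."""
    per_example = [
        sum(sigmoid_cross_entropy(x, z) for x, z in zip(row_logits, row_labels))
        for row_logits, row_labels in zip(logits, multi_hot)
    ]
    return sum(per_example) / len(per_example)


# Multi-hot targets from the example above and some made-up logits.
y = [[1, 0, 1], [1, 1, 0], [0, 1, 0], [0, 0, 1]]
logits = [[2.0, -1.0, 0.5], [1.5, 0.3, -2.0], [-1.0, 2.5, -0.5], [-3.0, -1.0, 4.0]]
print(multi_label_loss(logits, y))
```

The stable form avoids computing log(sigmoid(x)) directly, which would overflow or underflow for large |x|.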
Source: https://stackoverflow.com/questions/45159133/does-tf-one-hot-supports-sparsetensor-as-indices-parameter