TensorFlow 1.14+: Make intentionally unbalanced mini batch with Dataset API

Submitted by 白昼怎懂夜的黑 on 2019-12-24 18:31:07

Question


This question is somewhat of an extension of Produce balanced mini batch with Dataset API and references the interleave function from the tf.data.Dataset documentation.

Context:

Suppose you have the following:

  • dataset with n=4 classes
  • a list of filenames where each file corresponds to a record
  • the label for each file

Then we can construct the labeled dataset as follows:

path_ds = tf.data.Dataset.from_tensor_slices(files)
indx_ds = tf.data.Dataset.from_tensor_slices(labels)
ds = tf.data.Dataset.zip((path_ds, indx_ds))

If I wanted to make balanced mini-batches for the n classes (where n>2, unlike the linked SO question), then:

# assuming class index starts at 0
class_ds = tf.data.Dataset.range(0, n).map(lambda e: tf.cast(e, tf.int32))

ids = class_ds.interleave(
    lambda index : filter_for_class(index, ds),
    cycle_length=n, block_length=1
)

would yield one example from each class in turn, where:

def filter_for_class(class_index, dataset):
    return dataset.filter(lambda path, label: tf.math.equal(label, class_index))
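To make the interleaving concrete, here is a plain-Python model (not TensorFlow; the filenames and the two-examples-per-class dataset are made up) of what filter_for_class plus interleave with cycle_length=n, block_length=1 produces:

```python
# Plain-Python model of filter + interleave with cycle_length=n, block_length=1.
# The dataset is a list of (path, label) pairs; paths and labels are hypothetical.
ds = [("a0.rec", 0), ("a1.rec", 0), ("b0.rec", 1), ("b1.rec", 1),
      ("c0.rec", 2), ("c1.rec", 2), ("d0.rec", 3), ("d1.rec", 3)]
n = 4

# filter_for_class: one sub-stream per class, as tf.data's filter would produce
per_class = [[ex for ex in ds if ex[1] == c] for c in range(n)]

# interleave with block_length=1: take one element from each sub-stream in turn
interleaved = [ex for group in zip(*per_class) for ex in group]

labels = [lab for _, lab in interleaved]
print(labels)  # one example from each class after another: [0, 1, 2, 3, 0, 1, 2, 3]
```

Any batch whose size is a multiple of n then contains the same number of examples of each class.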

More generally, with block_length=b (the example above uses b=1):

.interleave( 
   ...,
   cycle_length=n, 
   block_length=b
)

ensures that, as long as the mini-batch size is divisible by n*b, each batch contains the same number of examples per class (b of each, so long as there is enough data per class).
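The claim above can be checked with a small plain-Python simulation of block-wise interleaving (the label streams of four examples per class are made up):

```python
# Plain-Python check: with cycle_length=n and block_length=b, a batch of
# size n*b holds exactly b examples of each class (while data remains).
from collections import Counter

n, b = 4, 2
# hypothetical label streams: four examples per class
per_class = [[c] * 4 for c in range(n)]

interleaved = []
while any(per_class):
    for stream in per_class:
        block, stream[:] = stream[:b], stream[b:]  # take a block of b elements
        interleaved.extend(block)

batch = interleaved[:n * b]
print(Counter(batch))  # each class appears exactly b times in the first batch
```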

So my question is: how can I produce imbalanced mini-batches using built-in tf.data.Dataset operations?

e.g. suppose my mini-batch has m elements; for my n classes, I want each class represented in the following ratios:

class_ratios = {
    0: 0.6,
    1: 0.1,
    2: 0.2,
    3: 0.1
}

# if m = 100 then 60 examples from class 0, 10 from class 1, 20 from class 2 and 10 from class 3

One difference from the previously linked question is that each file contains exactly one record, and the record's label can be extracted from the file / path name.

Note: the above approach will result in imbalanced batches towards the end if data from one class runs out
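One possible direction (a sketch, not part of the original question): build a deterministic class-index pattern from the ratios. In tf.data, such a pattern could drive tf.data.experimental.choose_from_datasets over the per-class filtered datasets, or the ratios could instead be passed as `weights` to tf.data.experimental.sample_from_datasets for stochastic sampling. The pattern construction itself is plain Python:

```python
# Sketch: turn per-class ratios into a repeating class-index pattern.
# Caveat: int(round(...)) may not sum exactly to m for arbitrary ratios.
m = 10  # pattern length; per-batch counts scale with the batch size
class_ratios = {0: 0.6, 1: 0.1, 2: 0.2, 3: 0.1}

pattern = []
for cls, ratio in class_ratios.items():
    pattern.extend([cls] * int(round(ratio * m)))

print(pattern)  # [0, 0, 0, 0, 0, 0, 1, 2, 2, 3]
```

Repeating this pattern as a choice dataset would give a 60/10/20/10 split for m = 100 batches, matching the example above.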

Source: https://stackoverflow.com/questions/57331048/tensorflow-1-14-make-intentionally-unbalanced-mini-batch-with-dataset-api
