Question
This question is an extension of "Produce balanced mini batch with Dataset API" and refers to the interleave function from the tf.data.Dataset documentation.
Context:
Suppose you have the following:
- a dataset with n=4 classes
- a list of filenames where each file corresponds to a record
- the label for each file
Then we can construct the labeled dataset as follows:
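import tensorflow as tf

# files: list of file paths; labels: list of integer class indices (same order)
path_ds = tf.data.Dataset.from_tensor_slices(files)
indx_ds = tf.data.Dataset.from_tensor_slices(labels)
ds = tf.data.Dataset.zip((path_ds, indx_ds))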
If I wanted to make balanced mini-batches for the n classes (where n>2, unlike the linked SO question), then:
# assuming class index starts at 0
class_ds = tf.data.Dataset.range(0, n).map(lambda e: tf.cast(e, tf.int32))
ids = class_ds.interleave(
    lambda index: filter_for_class(index, ds),
    cycle_length=n, block_length=1
)
This produces one example from each class in turn, where:
def filter_for_class(class_index, dataset):
    return dataset.filter(lambda path, label: tf.math.equal(label, class_index))
More generally, the example above uses b=1 (block_length). For an arbitrary b:
.interleave(
    ...,
    cycle_length=n*b,
    block_length=b
)
This ensures that, as long as the mini-batch size is divisible by n*b, each mini-batch contains the same number of examples from every class (so long as there is enough data per class).
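To make the round-robin behaviour described above concrete, here is a small pure-Python model of it. It is only a sketch, not TensorFlow code: the helper name interleave_blocks and the fake file names are hypothetical, and it assumes every per-class stream is open at the same time (that is, cycle_length is at least the number of classes).

```python
from itertools import islice


def interleave_blocks(sources, block_length):
    """Round-robin over `sources`, taking `block_length` items from each in turn.

    Hypothetical helper: it only models the case where all per-class
    streams are open simultaneously. Exhausted streams are dropped.
    """
    iterators = [iter(s) for s in sources]
    while iterators:
        still_open = []
        for it in iterators:
            block = list(islice(it, block_length))
            yield from block
            # A short (or empty) block means this stream has run out.
            if len(block) == block_length:
                still_open.append(it)
        iterators = still_open


n, b = 4, 2
# Six made-up (path, label) records per class.
per_class = [[(f"file_{c}_{i}", c) for i in range(6)] for c in range(n)]
stream = list(interleave_blocks(per_class, block_length=b))

# The first n*b elements form one balanced mini-batch.
batch = stream[: n * b]
labels_in_batch = [label for _, label in batch]
print(labels_in_batch)  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Each class contributes exactly b elements to every window of n*b elements, which is why a batch size divisible by n*b stays balanced until one class runs out.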
So my question is: how could I, using built-in tf.data.Dataset operations, produce intentionally imbalanced mini-batches?
For example, suppose my mini-batch has m elements and, for my n classes, I want each class to appear in the following ratios:
class_ratios = {
    0: 0.6,
    1: 0.1,
    2: 0.2,
    3: 0.1
}
# if m = 100 then 60 examples from class 0, 10 from class 1, 20 from class 2 and 10 from class 3
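The arithmetic in the comment above can be sketched in plain Python. The helper name per_class_counts is hypothetical, and the largest-remainder rounding is an assumption on my part for batch sizes where the ratios do not divide m evenly; it is not something the tf.data API provides.

```python
import math


def per_class_counts(class_ratios, m):
    """Turn class ratios into integer per-class counts that sum to m.

    Hypothetical helper: floors each exact share, then hands the
    leftover slots to the classes with the largest fractional parts.
    """
    if not math.isclose(sum(class_ratios.values()), 1.0):
        raise ValueError("class ratios must sum to 1")
    exact = {c: r * m for c, r in class_ratios.items()}
    counts = {c: int(v) for c, v in exact.items()}
    leftover = m - sum(counts.values())
    by_remainder = sorted(exact, key=lambda c: exact[c] - counts[c], reverse=True)
    for c in by_remainder[:leftover]:
        counts[c] += 1
    return counts


class_ratios = {0: 0.6, 1: 0.1, 2: 0.2, 3: 0.1}
counts = per_class_counts(class_ratios, 100)
print(counts)  # {0: 60, 1: 10, 2: 20, 3: 10}
```

These counts are what each mini-batch of m elements would need to contain per class; how to get tf.data.Dataset to emit them is the open question.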
One restriction: unlike the previously linked question, each file is exactly one record, and the record label can be extracted from the file or path name.
Note: the above approach will produce imbalanced batches towards the end if the data for one class runs out.
Source: https://stackoverflow.com/questions/57331048/tensorflow-1-14-make-intentionally-unbalanced-mini-batch-with-dataset-api