tf.datasets input_fn getting error after 1 epoch

南楼画角 submitted on 2019-12-13 03:39:54

Question


So I am trying to switch to an input_fn() using tf.datasets as described in this question. While I have been able to get superior steps/sec using tf.datasets with the input_fn() below, I appear to run into an error after 1 epoch when running this experiment on GCMLE. Consider this input_fn():

def input_fn(...):
    files = tf.data.Dataset.list_files(filenames).shuffle(num_shards)

    dataset = files.apply(tf.contrib.data.parallel_interleave(lambda filename: tf.data.TextLineDataset(filename).skip(1), cycle_length=num_shards))
    dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda row:
        parse_csv_dataset(row, hparams = hparams), 
        batch_size = batch_size, 
        num_parallel_batches = multiprocessing.cpu_count())) 
    dataset = dataset.prefetch(1)
    if shuffle:
        dataset = dataset.shuffle(buffer_size = 10000)
    dataset = dataset.repeat(num_epochs)

    iterator = dataset.make_initializable_iterator()
    features = iterator.get_next()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

    labels = {key: features.pop(key) for key in LABEL_COLUMNS}

    return features, labels

I receive the following error on GCMLE:

disable=protected-access InvalidArgumentError (see above for traceback): Inputs to operation loss/sparse_softmax_cross_entropy_loss/num_present/Select of type Select must have the same size and shape. Input 0: [74] != input 1: [110] [[Node: loss/sparse_softmax_cross_entropy_loss/num_present/Select = Select[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](loss/sparse_softmax_cross_entropy_loss/num_present/Equal, loss/sparse_softmax_cross_entropy_loss/num_present/zeros_like, loss/sparse_softmax_cross_entropy_loss/num_present/ones_like)]] [[Node: global_step/add/_1509 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_3099_global_step/add", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

This implies that there is a shape mismatch (Input 0: [74] != input 1: [110]). However, my old queue-based input_fn() works fine on exactly the same data, so I do not believe the underlying data is the issue. The failure occurs at what I believe to be the end of the epoch, because the step count at which the GCMLE job fails is right around num_train_examples/batch_size. My guess is that the final batch is not equal to the batch_size of 110 (the value that shows up in the error) and instead contains only 74 examples. Can anybody confirm that this is the cause? Assuming that it is, is there some other flag that I need to set so that the last batch can be something other than the specified batch size of 110?

For what it's worth, I have replicated this behavior with two different datasets: each trains for multiple epochs with the old queue-based input_fn, but gets hung up at the end of the first epoch with the tf.datasets input_fn.


Answer 1:


As Robbie suggests in the other answer, it looks like your old implementation used fixed batch sizes throughout (presumably using an API like tf.train.batch() or one of its wrappers with the default argument of allow_smaller_final_batch=False), and the default behavior of batching in tf.data (via tf.data.Dataset.batch() and tf.contrib.data.map_and_batch()) is to include the smaller final batch.
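To make the difference between the two batching behaviors concrete, here is a minimal plain-Python sketch (it is not TensorFlow code). The helper name batch_sizes and the dataset size of 404 are hypothetical, chosen only so that the leftover matches the 74 seen in the error message.

```python
def batch_sizes(num_examples, batch_size, drop_remainder=False):
    """Return the sizes of the batches produced by one pass over the data."""
    full_batches, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full_batches
    # Keeping the remainder mimics the tf.data default described above;
    # dropping it yields fixed-size batches only.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes


# Hypothetical dataset size: 3 full batches of 110 plus 74 leftover examples.
print(batch_sizes(404, 110))                       # [110, 110, 110, 74]
print(batch_sizes(404, 110, drop_remainder=True))  # [110, 110, 110]
```

Any code downstream that assumes every batch has exactly 110 rows would only break on the last element of the first list, which is consistent with the failure appearing at the end of the epoch.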

The bug is most likely in the model_fn. Without seeing that function, it is difficult to guess, but I suspect that there is either an explicit (and incorrect) assertion of a tensor's shape via Tensor.set_shape() (possibly in library code) or a bug in the implementation of tf.losses.sparse_softmax_cross_entropy().

First, I am assuming that the features and labels tensors returned from input_fn() have statically unknown batch size. Can you confirm that by printing the features and labels objects, and ensuring that their reported Tensor.shape properties have None for the 0th dimension?

Next, locate the call to tf.losses.sparse_softmax_cross_entropy() in your model_fn. Print the object that is passed as the weights argument to this function, which should be a tf.Tensor, and locate its static shape. Given the error you are seeing, I suspect it will have a shape like (110,), where 110 is your specified batch size. If that is the case, there is a bug in model_fn that incorrectly asserts that the shape of the weights is a full batch, when it might not be. (If that is not the case, then there's a bug in tf.losses.sparse_softmax_cross_entropy()! Please open a GitHub issue with an example that enables us to reproduce the problem.)

Aside: Why would this explain the bug? The code that calls the failing tf.where() op looks like this (edited for readability):

num_present = tf.where(tf.equal(weights, 0.0),  # This input is shape [74]
                       tf.zeros_like(weights),  # This input is shape [110]
                       tf.ones_like(weights)    # This input is probably [110]
)

This flavor of tf.where() op (named "Select" in the error message for historical reasons) requires that all three inputs have the same size. Superficially, tf.equal(weights, 0.0), tf.ones_like(weights), and tf.zeros_like(weights) all have the same shape, which is the shape of weights. However, if the static shape (the result of Tensor.shape) differs from the dynamic shape, then the behavior is undefined.

What actually happens? In this particular case, let's say the static shape of weights is [110], but the dynamic shape is [74]. The static shape of our three arguments to tf.where() will be [110]. The implementation of tf.equal() doesn't care that there's a mismatch, so its dynamic shape will be [74]. The implementations of tf.zeros_like() and tf.ones_like() use an optimization that ignores that dynamic shape when the static shape is fully defined, and so their dynamic shapes will be [110], causing the error you are seeing.
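The mechanism above can be imitated with a toy model in plain Python. This is only an illustration under stated assumptions, not TensorFlow's implementation: ToyTensor, equal, fill_like, select and num_present are all hypothetical names, and a "shape" here is just a single integer length.

```python
class ToyTensor:
    """Toy stand-in for a tensor: a declared (static) size plus actual data."""

    def __init__(self, values, static_shape=None):
        self.values = list(values)
        # static_shape=None means the batch size is unknown ahead of time.
        self.static_shape = static_shape

    @property
    def dynamic_shape(self):
        return len(self.values)


def equal(t, scalar):
    # Elementwise op: the output length follows the actual data.
    return ToyTensor([v == scalar for v in t.values], t.static_shape)


def fill_like(t, value):
    # Mimics the shortcut described above: trust the static size when it is
    # known, and fall back to the actual data length only when it is not.
    n = t.static_shape if t.static_shape is not None else t.dynamic_shape
    return ToyTensor([value] * n, t.static_shape)


def select(cond, if_true, if_false):
    sizes = (cond.dynamic_shape, if_true.dynamic_shape, if_false.dynamic_shape)
    if len(set(sizes)) != 1:
        raise ValueError("inputs must have the same size, got %s" % (sizes,))
    picked = [a if c else b
              for c, a, b in zip(cond.values, if_true.values, if_false.values)]
    return ToyTensor(picked, cond.static_shape)


def num_present(weights):
    return select(equal(weights, 0.0),
                  fill_like(weights, 0.0),
                  fill_like(weights, 1.0))


final_batch = [1.0] * 74  # the short batch at the end of the epoch

# Unknown static size: all three inputs end up with 74 elements.
print(num_present(ToyTensor(final_batch)).dynamic_shape)  # 74

# Static size wrongly asserted as 110: the filled inputs get 110 elements.
try:
    num_present(ToyTensor(final_batch, static_shape=110))
except ValueError as err:
    print(err)  # inputs must have the same size, got (74, 110, 110)
```

The second call fails with the same 74 versus 110 split as the Select error, which is why an incorrect static shape on weights is the prime suspect.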

The proper fix is to locate the code that is asserting a fixed batch size in your model_fn, and remove it. The optimization and evaluation logic in TensorFlow is robust to variable batch sizes, and this will ensure that all of your data is used in the training and evaluation processes.

A less desirable short-term fix would be to drop the small batch at the end of the data. There are a couple of options here:

  • Drop some data randomly at the end of each epoch:

    • With TF 1.8 or later, pass drop_remainder=True to tf.contrib.data.map_and_batch().
    • With TF 1.7 or earlier, use dataset = dataset.filter(lambda features: tf.equal(tf.shape(features[LABEL_COLUMNS[0]])[0], batch_size)) after the map_and_batch.
  • Drop the very last batch of data:

    • Move the dataset.repeat(NUM_EPOCHS) before the map_and_batch() and then apply one of the two fixes mentioned above.
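The difference between the two options comes down to simple arithmetic, sketched below in plain Python. The function name dropped_examples and the numbers (404 examples, 5 epochs) are hypothetical and used purely for illustration.

```python
def dropped_examples(num_examples, batch_size, num_epochs, repeat_before_batch):
    """Count the examples discarded when only full batches are kept."""
    if repeat_before_batch:
        # One long stream of num_examples * num_epochs elements: only the
        # single partial batch at the very end is discarded.
        return (num_examples * num_epochs) % batch_size
    # Batching each epoch separately discards a partial batch every epoch.
    return (num_examples % batch_size) * num_epochs


print(dropped_examples(404, 110, 5, repeat_before_batch=False))  # 370
print(dropped_examples(404, 110, 5, repeat_before_batch=True))   # 40
```

Note that when the repeat comes first, a batch can straddle an epoch boundary and mix examples from two consecutive epochs.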



Answer 2:


It seems that some operation in your graph (from the error message, likely sparse_softmax_cross_entropy_loss), is expecting a fixed batch size. It may be your code (not part of the input_fn) that is enforcing this (e.g. passing batch_size as the shape of some tensor that is used in an op), or it may be one of the TF libraries.

A fixed batch size is not always a problem per se. However, the documented behavior of tf.data.Dataset.batch is:

NOTE: If the number of elements (N) in this dataset is not an exact multiple of batch_size, the final batch contain smaller tensors with shape N % batch_size in the batch dimension. If your program depends on the batches having the same shape, consider using the tf.contrib.data.batch_and_drop_remainder transformation instead.

As currently written, your (non-input_fn) code falls into the category of programs that depend on every batch having the same shape.

Your options are to track down where the code is passing through a static batch size or to "drop the remainder". I believe the former is preferable, but more work.

If you choose the latter, note that you are not actually using tf.data.Dataset.batch, but rather tf.contrib.data.map_and_batch, which accepts a drop_remainder parameter; pass drop_remainder=True to discard the final partial batch.



Source: https://stackoverflow.com/questions/50263985/tf-datasets-input-fn-getting-error-after-1-epoch
