How do I create padded batches in Tensorflow for tf.train.SequenceExample data using the DataSet API?

余生分开走 2020-12-31 02:53

For training an LSTM model in Tensorflow, I have structured my data into a tf.train.SequenceExample format and stored it i

4 Answers
  •  野趣味 2020-12-31 03:23

    `padded_shapes` must be a tuple with one shape per component of each dataset element. In your case, pass:

    dataset = dataset.padded_batch(4, padded_shapes=([vectorSize],[None]))
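    A minimal end-to-end sketch, assuming TensorFlow 2.x. It serializes a few toy SequenceExamples with a hypothetical fixed-size context feature `context_vec` and a hypothetical variable-length sequence feature `tokens`, parses them, and batches them with the tuple form of `padded_shapes`. `VECTOR_SIZE` stands in for `vectorSize`; the feature names and data are illustrative, not the asker's. The key point is that the parse function returns a tuple, so `padded_shapes` must be a tuple with the same structure.

```python
import tensorflow as tf

VECTOR_SIZE = 3  # hypothetical stand-in for `vectorSize`


def make_sequence_example(context_vec, tokens):
    """Build a SequenceExample with one fixed-size context feature and one
    variable-length sequence feature (feature names are hypothetical)."""
    context = tf.train.Features(feature={
        "context_vec": tf.train.Feature(
            float_list=tf.train.FloatList(value=context_vec)),
    })
    # One int64 Feature per time step.
    steps = [tf.train.Feature(int64_list=tf.train.Int64List(value=[t]))
             for t in tokens]
    feature_lists = tf.train.FeatureLists(feature_list={
        "tokens": tf.train.FeatureList(feature=steps),
    })
    return tf.train.SequenceExample(context=context, feature_lists=feature_lists)


def parse(serialized):
    context, sequence = tf.io.parse_single_sequence_example(
        serialized,
        context_features={
            "context_vec": tf.io.FixedLenFeature([VECTOR_SIZE], tf.float32),
        },
        sequence_features={
            # Each step holds one scalar, so the parsed tensor has shape [T].
            "tokens": tf.io.FixedLenSequenceFeature([], tf.int64),
        },
    )
    # Return a tuple: padded_shapes below must mirror this structure.
    return context["context_vec"], sequence["tokens"]


examples = [
    make_sequence_example([0.1, 0.2, 0.3], [1, 2]),
    make_sequence_example([0.4, 0.5, 0.6], [3, 4, 5, 6]),
    make_sequence_example([0.7, 0.8, 0.9], [7]),
]
dataset = tf.data.Dataset.from_tensor_slices(
    [ex.SerializeToString() for ex in examples])
dataset = dataset.map(parse)
# Context vector keeps its fixed size; the sequence is padded to the
# longest sequence in each batch.
dataset = dataset.padded_batch(4, padded_shapes=([VECTOR_SIZE], [None]))

batch_vecs, batch_tokens = next(iter(dataset))
print(batch_vecs.numpy().shape)        # (3, 3)
print(batch_tokens.numpy().tolist())   # [[1, 2, 0, 0], [3, 4, 5, 6], [7, 0, 0, 0]]
```

    Only three examples exist, so the single batch holds 3 elements even though the batch size is 4 (`drop_remainder` defaults to `False`).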
    

    or try

    dataset = dataset.padded_batch(4, padded_shapes=([None],[None]))
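    With `[None]` for every component, each component is padded independently to its own longest length within the batch. A small sketch, assuming TensorFlow 2.4+ (for `output_signature`) and made-up toy data. It also passes `padding_values` as a tuple matching the element structure, which helps when the default pad value of 0 could be confused with a real id; the explicit `tf.constant`s keep each padding value's dtype equal to its component's dtype.

```python
import tensorflow as tf


def gen():
    # Hypothetical toy data: both components vary in length per element.
    yield [1.0, 2.0], [10]
    yield [3.0], [20, 30, 40]


dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=[None], dtype=tf.float32),
        tf.TensorSpec(shape=[None], dtype=tf.int64),
    ),
)
# Each component is padded to the longest length of that component in the batch.
dataset = dataset.padded_batch(
    4,
    padded_shapes=([None], [None]),
    padding_values=(tf.constant(0.0, tf.float32), tf.constant(-1, tf.int64)),
)

values, ids = next(iter(dataset))
print(values.numpy().tolist())  # [[1.0, 2.0], [3.0, 0.0]]
print(ids.numpy().tolist())     # [[10, -1, -1], [20, 30, 40]]
```

    In recent TF 2.x releases `padded_shapes` can also be left unset, in which case every dimension of every component is padded to the batch maximum.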
    

    Check this code for more details. I had to debug this method to figure out why it wasn't working for me.
