TensorFlow: next_batch function for a NumPy array

暗喜 2021-01-26 03:32

I have training data as follows:

xTrain = numpy.asarray([100, 1, 5, 6 ...])
yTrain = numpy.asarray([200, 2, 10, 12 ...])

How can I define a next_batch(size) method for these arrays?

1 Answer
  • 2021-01-26 04:00

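    You can use the following generator as your next-batch function. It shuffles source and target with the same permutation, then yields (source, target) batches of batch_size samples: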

    import numpy as np

    def batch_data(source, target, batch_size):
        # Shuffle source and target with the same permutation so pairs stay aligned
        shuffle_indices = np.random.permutation(len(target))
        source = source[shuffle_indices]
        target = target[shuffle_indices]

        # Yield only full batches; a trailing partial batch is dropped
        for batch_i in range(len(source) // batch_size):
            start_i = batch_i * batch_size
            source_batch = source[start_i:start_i + batch_size]
            target_batch = target[start_i:start_i + batch_size]

            yield np.array(source_batch), np.array(target_batch)
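    If you want a callable next_batch(size), as asked in the question, rather than a generator that stops after one pass, one option is to keep a cursor into a shuffled index array and reshuffle when an epoch runs out. This is a minimal sketch, assuming xTrain and yTrain are NumPy arrays of equal length; the class name BatchIterator and its epochs_completed attribute are hypothetical, and the sample values below are made up following the y = 2x pattern in the question.

```python
import numpy as np


class BatchIterator:
    """Hypothetical helper exposing a stateful next_batch(size).

    Walks through (x, y) in shuffled order. When fewer than `size`
    samples remain in the current epoch, it reshuffles and starts a
    new epoch, so every returned batch has exactly `size` samples.
    """

    def __init__(self, x, y, seed=None):
        x = np.asarray(x)
        y = np.asarray(y)
        if len(x) != len(y):
            raise ValueError("x and y must have the same length")
        self._x = x
        self._y = y
        self._rng = np.random.default_rng(seed)
        # One shared permutation keeps each x aligned with its y
        self._order = self._rng.permutation(len(x))
        self._pos = 0
        self.epochs_completed = 0

    def next_batch(self, size):
        if size > len(self._x):
            raise ValueError("batch size is larger than the dataset")
        # Not enough samples left in this epoch: reshuffle and restart
        # (the leftover samples are skipped for this epoch).
        if self._pos + size > len(self._x):
            self._order = self._rng.permutation(len(self._x))
            self._pos = 0
            self.epochs_completed += 1
        idx = self._order[self._pos:self._pos + size]
        self._pos += size
        return self._x[idx], self._y[idx]


# Illustrative data only (the question's arrays are elided with "...")
xTrain = np.asarray([100, 1, 5, 6, 7, 3, 9, 4])
yTrain = 2 * xTrain

data = BatchIterator(xTrain, yTrain, seed=0)
for _ in range(5):
    xb, yb = data.next_batch(3)
    print(xb, yb)
```

    Unlike the generator above, this version keeps producing batches across epochs. Skipping the leftover samples at the end of each epoch keeps every batch the same size; if you would rather use every sample, you could instead return the short final batch before reshuffling.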
    