How to get mini-batches in pytorch in a clean and efficient way?

北恋 2021-01-30 08:48

I was trying to do something simple: train a linear model with Stochastic Gradient Descent (SGD) using torch:

import numpy as np

import torch
from torch.         
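
For reference, the answer below assumes tensors X_mdl and y, weights W, a dataset size N, and a mini-batch size M; a minimal setup consistent with those names might look like the following sketch (every shape here, and the use of the old Variable API, is an assumption, not taken from the original post):

import numpy as np
import torch
from torch.autograd import Variable

N, D, M = 100, 3, 5   # rows, features, mini-batch size (all assumed)

X_mdl = Variable(torch.randn(N, D), requires_grad=False)  # design matrix
y = Variable(torch.randn(N, 1), requires_grad=False)      # targets
W = Variable(torch.randn(D, 1), requires_grad=True)       # weights to learn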


        
6 Answers
  •  后悔当初
    2021-01-30 09:23

    Not sure what you were trying to do. W.r.t. batching, you wouldn't have to convert to numpy. You could just use index_select(), e.g.:

    for epoch in range(500):
        k = 0
        loss = 0
        while k < X_mdl.size(0):

            # Draw M random row indices (with replacement). The original
            # indexed a fixed-size list with range(k, k+M), which goes out
            # of bounds once k > 0, and np.random.choice(N-1) silently
            # excluded the last row.
            random_batch = [0]*M
            for i in range(M):
                random_batch[i] = np.random.choice(N)
            random_batch = torch.LongTensor(random_batch)
            batch_xs = X_mdl.index_select(0, random_batch)
            batch_ys = y.index_select(0, random_batch)

            # Forward pass: compute predicted y using operations on Variables
            y_pred = batch_xs.mul(W)  # elementwise; for a (D, 1) weight vector you'd use .mm(W)
            # etc..

            k += M  # advance, otherwise the while loop never terminates
    

    The rest of the code would have to be changed as well though.
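
    As an aside (not from this answer): np.random.choice as used above samples rows with replacement, so within an epoch some rows repeat and others are skipped. To visit every row exactly once per epoch, a common pattern is to shuffle indices with torch.randperm and slice them:

    indices = torch.randperm(X_mdl.size(0))    # shuffled row indices 0..N-1
    for k in range(0, X_mdl.size(0), M):
        batch_idx = indices[k:k+M]             # next M shuffled indices
        batch_xs = X_mdl.index_select(0, batch_idx)
        batch_ys = y.index_select(0, batch_idx)
        # forward/backward pass on this mini-batch ...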


    My guess is that you would like to create a get_batch function that concatenates your X tensors and y tensors. Something like:

    def make_batch(list_of_tensors):
        # seed the batch with the first (sample, label) pair
        X, y = list_of_tensors[0]
        # may need to unsqueeze X and y to get right dimensions
        for sample, label in list_of_tensors[1:]:
            # append each remaining pair along the batch dimension
            X = torch.cat((X, sample), dim=0)
            y = torch.cat((y, label), dim=0)
        return X, y
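
    (Aside, not from the answer: concatenating one sample at a time is quadratic in the list length. If all samples share a shape, torch.stack builds the batch in one call and adds the batch dimension itself, so no unsqueeze is needed. A hypothetical variant:)

    def make_batch_stacked(list_of_tensors):
        # one-shot alternative to make_batch (assumes uniform sample shapes)
        X = torch.stack([sample for sample, _ in list_of_tensors], dim=0)
        y = torch.stack([label for _, label in list_of_tensors], dim=0)
        return X, y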
    

    Then during training you select, e.g., max_batch_size = 32 examples at a time through slicing.

    for epoch in range(num_epochs):  # num_epochs: however many passes you want
        X, y = make_batch(list_of_tensors)
        X = Variable(X, requires_grad=False)
        y = Variable(y, requires_grad=False)

        k = 0
        while k < X.size(0):
            inputs = X[k:k+max_batch_size, :]
            labels = y[k:k+max_batch_size, :]
            # some computation
            k += max_batch_size
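
    For the "clean and efficient" part of the question: in current PyTorch the standard tool is torch.utils.data.TensorDataset plus DataLoader, which does the shuffling and slicing above for you. A minimal sketch, assuming X and y are the stacked tensors built by make_batch (this uses today's tensor API rather than the Variable-era code above):

    from torch.utils.data import TensorDataset, DataLoader

    dataset = TensorDataset(X, y)  # pairs row i of X with row i of y
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    for epoch in range(num_epochs):
        for inputs, labels in loader:
            # some computation on one mini-batch
            pass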
    
