How to get mini-batches in PyTorch in a clean and efficient way?

前端 未结 6 1440
北恋
北恋 2021-01-30 08:48

I was trying to do a simple thing: train a linear model with stochastic gradient descent (SGD) using PyTorch:

import numpy as np

import torch
from torch.         


        
6条回答
  •  借酒劲吻你
    2021-01-30 09:17

    Create a subclass of torch.utils.data.Dataset and pass an instance of it to a torch.utils.data.DataLoader, which then yields shuffled mini-batches for you. Below is an example from my project.

    class CandidateDataset(Dataset):
        """Map-style dataset pairing a feature matrix with integer class labels.

        Both inputs are converted to tensors once, in the constructor, so
        ``__getitem__`` only indexes into already-built tensors.

        Args:
            x: array-like of features; row ``i`` is sample ``i``. Stored as
                ``torch.float``.
            y: array-like of labels with the same number of rows as ``x``.
                Stored as ``torch.long`` (integer class indices).
            device: where to place both tensors. ``None`` (the default) picks
                ``'cuda'`` when ``torch.cuda.is_available()`` and ``'cpu'``
                otherwise.

        Raises:
            ValueError: if ``x`` and ``y`` do not have the same number of rows.
        """

        def __init__(self, x, y, device=None):
            if device is None:
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # NOTE(review): PyTorch's data-loading docs advise against CUDA
            # tensors in a Dataset used with num_workers > 0, and pin_memory=True
            # only accepts CPU tensors. In those setups pass device='cpu' and
            # move each batch to the GPU inside the training loop instead.
            self.x_data = torch.as_tensor(x, device=device, dtype=torch.float)
            self.y_data = torch.as_tensor(y, device=device, dtype=torch.long)
            # Fail early: a length mismatch would otherwise either silently drop
            # labels or raise IndexError deep inside a DataLoader iteration.
            if len(self.x_data) != len(self.y_data):
                raise ValueError(
                    f"x has {len(self.x_data)} rows but y has {len(self.y_data)}"
                )
            self.len = len(self.x_data)

        def __getitem__(self, index):
            """Return the ``(features, label)`` pair at ``index``."""
            return self.x_data[index], self.y_data[index]

        def __len__(self):
            """Return the number of samples (rows of ``x``)."""
            return self.len
    
    def fit(self, candidate_count):
        """Train ``self.model`` on ``candidate_count`` samples in shuffled mini-batches.

        Builds a ``(candidate_count, 600)`` feature matrix and a
        ``(candidate_count, 1)`` target matrix via ``fill_matrices``, wraps
        them in a ``CandidateDataset`` and, for ``self.N_EPOCHS`` epochs,
        iterates over shuffled batches of ``self.BATCH_SIZE`` rows, taking one
        optimizer step per batch.

        Reads these attributes of ``self``: ``BATCH_SIZE``, ``N_EPOCHS``,
        ``model`` (called as ``model(inputs, hidden)`` and providing
        ``initHidden(batch_size)``), ``optimizer`` and ``loss_f``.
        """
        n_features = 600
        feature_matrix = np.empty(shape=(candidate_count, n_features))
        target_matrix = np.empty(shape=(candidate_count, 1))
        # np.empty does not initialise memory; presumably fill_matrices
        # overwrites every entry — confirm against its implementation.
        fill_matrices(feature_matrix, target_matrix)
        candidate_ds = CandidateDataset(feature_matrix, target_matrix)
        train_loader = DataLoader(
            dataset=candidate_ds, batch_size=self.BATCH_SIZE, shuffle=True
        )
        for epoch in range(self.N_EPOCHS):
            print(f'starting epoch {epoch}')
            for batch_idx, (inputs, labels) in enumerate(train_loader):
                print(f'starting batch {batch_idx} epoch {epoch}')
                # Tensors take part in autograd directly since PyTorch 0.4, so
                # no Variable wrapper is needed.
                self.optimizer.zero_grad()
                # (batch_size, n_features) -> (1, batch_size, n_features);
                # presumably a sequence length of 1 for a recurrent model — confirm.
                inputs = inputs.view(1, inputs.size(0), n_features)
                # init hidden with number of rows in input
                y_pred = self.model(inputs, self.model.initHidden(inputs.size(1)))
                # labels should be a tensor with batch_size rows holding the
                # class index (0 or 1). view(-1) flattens (batch_size, 1) ->
                # (batch_size,). squeeze_() would also drop the batch dimension
                # when the final batch has exactly one row.
                labels = labels.view(-1)
                loss = self.loss_f(y_pred, labels)
                loss.backward()
                self.optimizer.step()
                print(f'done batch {batch_idx} epoch {epoch}')
    

提交回复
热议问题