Pytorch RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select

Submitted by 折月煮酒 on 2021-02-11 14:35:05


I am training a model that takes tokenized strings and passes them through an embedding layer followed by an LSTM. However, something seems to be wrong with the input: it fails as soon as it is passed to the embedding layer.

class DrugModel(nn.Module):
    """Siamese character-level LSTM encoder for tokenized drug strings.

    A padded batch of token ids is embedded with ``nn.Embedding`` and encoded
    by a unidirectional ``nn.LSTM``. The final hidden state is used as the
    sequence representation, and both members of a pair share the encoder.

    Device handling: the model never moves itself. Construct it, then call
    ``model.to(device)``. Assigning ``self.device`` alone does NOT move the
    submodules, and that is what produced "Expected object of device type cuda
    but got device type cpu" inside the embedding lookup (CPU weight, CUDA
    input). Every tensor the model creates or moves follows the device of its
    own parameters, so the same code runs on CPU or GPU.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, drug_embed_dim,
                 lstm_layer, lstm_dropout, bi_lstm, linear_dropout, char_vocab_size,
                 char_embed_dim, char_dropout, dist_fn, learning_rate,
                 binary, is_mlp, weight_decay, is_graph, g_layer,
                 g_hidden_dim, g_out_dim, g_dropout):

        super(DrugModel, self).__init__()

        # Save model configs
        self.drug_embed_dim = drug_embed_dim
        self.lstm_layer = lstm_layer
        self.char_dropout = char_dropout
        self.dist_fn = dist_fn
        self.binary = binary
        self.is_mlp = is_mlp
        self.is_graph = is_graph
        self.g_layer = g_layer
        self.g_dropout = g_dropout
        # Kept for backward compatibility only: this records a preferred device
        # but does not move any submodule. Use model.to(device) for that.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Character-level encoder (not built for the MLP variant)
        if not is_mlp:
            # NOTE(review): this call was cut off after char_embed_dim in the
            # source; confirm whether extra arguments (e.g. padding_idx) were
            # intended.
            self.char_embed = nn.Embedding(char_vocab_size, char_embed_dim)
            # Built without bidirectional=True, so bi_lstm has no effect here.
            self.lstm = nn.LSTM(char_embed_dim, drug_embed_dim, lstm_layer,
                                batch_first=True, dropout=lstm_dropout)
        # Distance function
        self.dist_fc = nn.Linear(drug_embed_dim, 1)

        if binary:
            # Binary Cross Entropy
            # NOTE(review): as written this is the Bernoulli log-likelihood
            # (not negated, not averaged); confirm the caller negates it.
            self.criterion = lambda x, y: y*torch.log(x) + (1-y)*torch.log(1-x)

    def init_lstm_h(self, batch_size):
        """Return zero initial states ``(h_0, c_0)`` for ``self.lstm``.

        Each tensor has shape (lstm_layer * 1, batch_size, drug_embed_dim). The
        ``* 1`` is the number of directions (the LSTM is unidirectional). The
        tensors are allocated with the device and dtype of the model's
        parameters instead of a hard-coded ``.cuda()``.
        """
        ref = next(self.parameters())
        shape = (self.lstm_layer * 1, batch_size, self.drug_embed_dim)
        return ref.new_zeros(shape), ref.new_zeros(shape)

    # Set Siamese network as basic LSTM
    def siamese_sequence(self, inputs, length):
        """Encode a padded batch of token ids into one vector per row.

        Args:
            inputs: 2-D tensor of token ids, one sequence per row. It is cast to
                long and moved to the embedding's device here, so the caller
                need not move it.
            length: 1-D tensor with the unpadded length of each row of
                ``inputs``.

        Returns:
            The LSTM's final hidden state reshaped to (-1, drug_embed_dim),
            with rows in the same order as ``inputs``.
        """
        # Follow the embedding weight's device rather than calling .cuda().
        inputs = inputs.long().to(self.char_embed.weight.device)

        # Character embedding
        c_embed = self.char_embed(inputs)
        # c_embed = F.dropout(c_embed, self.char_dropout)

        # NOTE(review): the condition guarding the packed path was truncated in
        # the source ("if not"); reconstructed as "not self.training" - confirm
        # against the original model.
        pack = not self.training
        if pack:
            # pack_padded_sequence expects rows sorted by decreasing length;
            # keep the inverse permutation to restore the caller's order.
            _, sort_idx = torch.sort(length, dim=0, descending=True)
            _, unsort_idx = torch.sort(sort_idx, dim=0)

            # Pack padded sequence
            c_embed = c_embed.index_select(0, sort_idx.to(c_embed.device))
            sorted_len = length.index_select(0, sort_idx).tolist()
            c_packed = pack_padded_sequence(c_embed, sorted_len, batch_first=True)
        else:
            c_packed = c_embed

        # Run LSTM
        init_lstm_h = self.init_lstm_h(inputs.size(0))
        _, states = self.lstm(c_packed, init_lstm_h)

        # states[0] is h_n, laid out (num_layers, batch, hidden) even with
        # batch_first=True.
        # NOTE(review): for lstm_layer > 1 this yields batch * lstm_layer rows
        # interleaved per sample, and the unsort below then selects the wrong
        # rows; states[0][-1] (last layer) was probably intended - confirm.
        hidden = torch.transpose(states[0], 0, 1).contiguous().view(
            -1, 1 * self.drug_embed_dim)
        if pack:
            # Unsort hidden states back to the input order.
            outputs = hidden.index_select(0, unsort_idx.to(hidden.device))
        else:
            outputs = hidden

        return outputs

    def forward(self, key1, key2, targets, key1_len, key2_len, status, predict=False):
        """Encode both members of each pair with the shared sequence encoder.

        NOTE(review): the remainder of this method (using targets, status and
        predict) is not shown; as visible it computes both encodings and
        returns None.
        """
        if not self.is_mlp:
            output1 = self.siamese_sequence(key1, key1_len)
            output2 = self.siamese_sequence(key2, key2_len)

After instantiating the class I get the following error when passing the input through the embedding layer:

<ipython-input-128-432fcc7a1e39> in forward(self, key1, key2, targets, key1_len, key2_len, status, predict)
    129     def forward(self, key1, key2, targets, key1_len, key2_len, status, predict = False):
    130         if not self.is_mlp:
--> 131             output1 = self.siamese_sequence(key1, key1_len)
    132             output2 = self.siamese_sequence(key2, key2_len)
    133             set_trace()

<ipython-input-128-432fcc7a1e39> in siamese_sequence(self, inputs, length)
     74         inputs = inputs.cuda()
---> 76         self.char_embed = self.char_embed(
     77         set_trace()
     78         c_embed = self.char_embed(inputs)

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/ in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/ in forward(self, input)
    112         return F.embedding(
    113             input, self.weight, self.padding_idx, self.max_norm,
--> 114             self.norm_type, self.scale_grad_by_freq, self.sparse)
    116     def extra_repr(self):

~/miniconda3/lib/python3.7/site-packages/torch/nn/ in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1482         # remove once script supports set_grad_enabled
   1483         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select

This happens even though the input (e.g. key1) has already been moved to CUDA and converted to long (int64) dtype:

tensor([[25, 33, 30,  ...,  0,  0,  0],
        [25,  7,  7,  ...,  0,  0,  0],
        [25,  7, 30,  ...,  0,  0,  0],
        [25,  7, 33,  ...,  0,  0,  0],
        [25, 33, 41,  ...,  0,  0,  0],
        [25, 33, 41,  ...,  0,  0,  0]], device='cuda:0')


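Setting `model.device` to CUDA does not move your model's submodules, so `self.lstm`, `self.char_embed`, and `self.dist_fc` are all still on the CPU. The embedding weight is therefore on the CPU while the input is on the GPU, which is exactly the mismatch the error reports. The correct way is to move the whole model with `DrugModel(...).to(device)`.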

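In general, it's better not to pass a device into your model and to write it in a device-agnostic way instead. To make your `init_lstm_h` function device-agnostic, create the initial states with the same device and dtype as the model's parameters, for example:

    def init_lstm_h(self, batch_size):
        weight = next(self.parameters())
        shape = (self.lstm_layer, batch_size, self.drug_embed_dim)
        return (weight.new_zeros(shape), weight.new_zeros(shape))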

