How to get the index of a specific value in a PyTorch tensor

谎友^ 2020-12-14 15:34

With a Python list, we can use list.index(somevalue). How can this be done with a PyTorch tensor?
For example:

    a=[1,2,3]
    print(a.index(2))
5 answers
  • 2020-12-14 15:41

    This can be done by converting the tensor to a NumPy array:

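    import numpy as np
    import torch

    x = torch.arange(1., 5.)  # torch.range(1, 4) is deprecated; same values
    print(x)
    # ===> tensor([1., 2., 3., 4.])
    nx = x.numpy()
    print(np.where(nx == 3)[0][0])  # first position where the value equals 3
    # ===> 2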
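    Note that `np.where(nx == 3)[0][0]` raises an `IndexError` when the value is absent, whereas `list.index` raises `ValueError`. Below is a minimal sketch of a wrapper that mimics `list.index`, assuming a 1-D tensor; the helper name `index_of` is hypothetical. It calls `.detach().cpu()` first because `.numpy()` fails on GPU tensors and on tensors that require grad.

```python
import numpy as np
import torch


def index_of(t: torch.Tensor, value) -> int:
    """Hypothetical helper: first index of `value` in a 1-D tensor, like list.index."""
    # detach() and cpu() make .numpy() safe for tensors that require grad or live on a GPU
    nx = t.detach().cpu().numpy()
    hits = np.where(nx == value)[0]  # positions of all matches
    if hits.size == 0:
        # Match list.index, which raises ValueError for a missing value
        raise ValueError(f"{value} is not in tensor")
    return int(hits[0])


x = torch.arange(1., 5.)  # tensor([1., 2., 3., 4.])
print(index_of(x, 3))     # 2
```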
    
  • 2020-12-14 15:42

    I don't think there is a direct equivalent of list.index() in PyTorch. However, you can get a similar result by comparing the tensor with the value (tensor == number), which yields a boolean mask, and then calling nonzero() on that mask. For example:

    import torch

    t = torch.tensor([1, 2, 3])
    print((t == 2).nonzero())
    

    This prints the indices of all matches, one per row:

    tensor([[1]])

    i.e. a 1x1 LongTensor containing the index 1.
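    The same `(t == value).nonzero()` call also works on multi-dimensional tensors: each row of the result is the full coordinate of one match. The sketch below, using a small example tensor chosen for illustration, shows that form, the `as_tuple=True` form (one index tensor per dimension), and a way to take only the first match as plain Python ints.

```python
import torch

t = torch.tensor([[1, 2],
                  [2, 3]])

# Each row of the result is the (row, col) coordinate of one match, in row-major order
coords = (t == 2).nonzero()
print(coords.tolist())  # [[0, 1], [1, 0]]

# as_tuple=True returns one index tensor per dimension instead
rows, cols = (t == 2).nonzero(as_tuple=True)
print(rows.tolist(), cols.tolist())  # [0, 1] [1, 0]

# First match as a tuple of Python ints, or None when the value is absent
first = tuple(coords[0].tolist()) if coords.numel() > 0 else None
print(first)  # (0, 1)
```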

  • 2020-12-14 15:49

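    Comparing floating-point values with == is unreliable because of rounding, so for float tensors I check whether the absolute difference is below a small tolerance and then call nonzero() to get the index: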

    print((torch.abs(torch.max(your_tensor).item() - your_tensor) < 0.0001).nonzero())
    

    Here I get the index of the maximum value in the float tensor. You can substitute any value of your own to get the indices of the matching elements:

    print((torch.abs(YOUR_VALUE - your_tensor) < 0.0001).nonzero())
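    The same tolerance check can be written with `torch.isclose`, and for the maximum specifically `torch.argmax` returns the index directly. A minimal sketch, assuming the same absolute tolerance of 0.0001 used above; the helper name `close_indices` is hypothetical. Note that `isclose` tests `abs(t - value) <= atol + rtol * abs(value)`, so with `rtol=0.0` it is an absolute-tolerance check (using `<=` rather than `<`).

```python
import torch


def close_indices(t: torch.Tensor, value: float, atol: float = 1e-4) -> torch.Tensor:
    """Hypothetical helper: indices of all elements within `atol` of `value`."""
    # full_like gives `value` the same shape and dtype as t, which isclose expects
    target = torch.full_like(t, value)
    # rtol=0 makes this a pure absolute check: abs(t - value) <= atol
    return torch.isclose(t, target, rtol=0.0, atol=atol).nonzero().flatten()


t = torch.tensor([0.1, 0.7, 0.3])
print(close_indices(t, 0.3).tolist())  # [2]
print(torch.argmax(t).item())          # 1
```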
    
  • 2020-12-14 15:58

    To find the index of an element in a 1-D tensor, for example:

    import torch

    mat = torch.tensor([1, 8, 5, 3])
    

    to find the index of 5:

    five = 5

    # Scan every position and print those whose value matches
    for o in range(len(mat)):
        if mat[o] == five:
            print(torch.tensor([o]))
    # ===> tensor([2])
    

    To find the index of an element in a 2-D or 3-D tensor, first flatten it to 1-D, e.g. example.view(number_of_elements). The index you get is then a position in the flattened tensor.

    Example:

    mat = torch.tensor([[1, 2], [4, 3]])
    # to find the index of 2
    
    target = 2
    mat = mat.view(4)  # flatten the 2x2 tensor to 4 elements
    for o in range(len(mat)):
        if mat[o] == target:
            print(torch.tensor([o]))
    # ===> tensor([1])
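    The loop above can be replaced by a vectorized comparison, and the resulting flat index can be mapped back to a (row, col) coordinate with `divmod`, because flattening a contiguous tensor uses row-major order. A minimal sketch, assuming a 2-D tensor; `reshape(-1)` is used instead of `mat.view(4)` so the shape does not have to be hard-coded and non-contiguous tensors also work.

```python
import torch

mat = torch.tensor([[1, 2],
                    [4, 3]])
target = 2

flat = mat.reshape(-1)  # like mat.view(4), but works for any shape
flat_idx = (flat == target).nonzero().flatten()  # vectorized replacement for the loop
print(flat_idx.tolist())  # [1]

# Convert the first flat (row-major) index back to a (row, col) coordinate
n_cols = mat.shape[1]
row, col = divmod(flat_idx[0].item(), n_cols)
print(row, col)  # 0 1
```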
    
  • 2020-12-14 16:02
        import torch

        # Variable wrappers are deprecated; plain tensors support autograd directly
        x_data = torch.tensor([[1.0], [2.0], [3.0]])
        print(x_data[0])
        >>tensor([1.])
    