With a Python list, we can use list.index(somevalue). How can PyTorch do this?
For example:
a=[1,2,3]
print(a.index(2))
This can be done by converting to NumPy as follows:
import numpy as np
import torch
x = torch.arange(1, 5, dtype=torch.float32)  # torch.range is deprecated
print(x)
# ===> tensor([1., 2., 3., 4.])
nx = x.numpy()
print(np.where(nx == 3)[0][0])
# ===> 2
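The same lookup can be sketched without the NumPy round trip (x.numpy() only works on CPU tensors that do not require grad). This assumes a PyTorch version where torch.where accepts a single condition argument, in which case it returns the same tuple of index tensors as nonzero(as_tuple=True):

```python
import torch

x = torch.arange(1, 5, dtype=torch.float32)  # tensor([1., 2., 3., 4.])

# torch.where(cond) returns a tuple with one index tensor per dimension
matches = torch.where(x == 3)[0]
first = matches[0].item()  # first matching position as a Python int
print(first)  # 2
```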
I think there is no direct translation of list.index() to a PyTorch function. However, you can achieve a similar result by comparing the tensor with the value (tensor == number) and then calling nonzero() on the resulting mask. For example:
import torch
t = torch.tensor([1, 2, 3])
print((t == 2).nonzero())
This prints a 1x1 LongTensor holding the index 1 (shown as tensor([[1]]) in current PyTorch versions).
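A small helper that mimics list.index (first match only, ValueError when the value is absent) can be sketched on top of the same ==/nonzero() idea. The name tensor_index is hypothetical and not part of PyTorch:

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper: position of the first element equal to `value`
    in a 1-D tensor, raising ValueError like list.index when there is no match."""
    hits = (t == value).nonzero(as_tuple=True)[0]  # all matching positions
    if hits.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    return int(hits[0])  # first match, as a Python int


t = torch.tensor([1, 2, 3, 2])
print(tensor_index(t, 2))  # 1
```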
For floating-point tensors, I use a tolerance instead of exact equality to get the index of an element:
print((torch.abs(torch.max(your_tensor).item() - your_tensor) < 0.0001).nonzero())
Here I get the index of the maximum value in the float tensor. You can also substitute your own value, like this, to get the index of any element in the tensor:
print((torch.abs(YOUR_VALUE - your_tensor) < 0.0001).nonzero())
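As a sketch of built-in alternatives (the sample values below are made up for illustration): torch.argmax gives the index of the maximum directly, and torch.isclose expresses the tolerance check, here with rtol=0.0 and atol=1e-4 chosen to match the 0.0001 threshold above:

```python
import torch

your_tensor = torch.tensor([0.1, 0.5, 0.30001, 0.2])  # illustrative data

# Index of the maximum value, no tolerance needed
max_idx = torch.argmax(your_tensor).item()
print(max_idx)  # 1

# Tolerance-based match: |a - b| <= atol + rtol * |b|
value = 0.3
close = torch.isclose(your_tensor, torch.full_like(your_tensor, value),
                      rtol=0.0, atol=1e-4)
idx = close.nonzero(as_tuple=True)[0]
print(idx)  # tensor([2])
```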
To find the index of an element in a 1-D tensor/array, for example:
import torch
mat = torch.tensor([1, 8, 5, 3])
# to find the index of 5
five = 5
numb_of_col = 4  # number of elements in mat
for o in range(numb_of_col):
    if mat[o] == five:
        print(torch.tensor([o]))
# ===> tensor([2])
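The Python loop can likely be replaced by a single vectorized comparison, which also reports every position when the value repeats. The extra 5 appended to mat here is only for illustration:

```python
import torch

mat = torch.tensor([1, 8, 5, 3, 5])

# Element-wise comparison gives a boolean mask; nonzero returns the True positions
idx = (mat == 5).nonzero(as_tuple=True)[0]
print(idx)  # tensor([2, 4])
```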
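To find the index of an element in a 2-D/3-D tensor, first convert it to 1-D, i.e. example.view(number_of_elements). The index you get is then the position in the flattened tensor.
Example:
import torch
mat = torch.tensor([[1, 2], [4, 3]])
# to find the index of 2
target = 2
mat = mat.view(4)  # flatten the 2x2 tensor to 4 elements
numb_of_col = 4
for o in range(numb_of_col):
    if mat[o] == target:
        print(torch.tensor([o]))
# ===> tensor([1])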
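Flattening is optional: calling nonzero() on a 2-D mask returns one (row, col) pair per match. The sketch below shows that, plus converting a flat index back to coordinates with divmod, assuming a row-major (contiguous) 2-D tensor:

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])

# One row per match, columns are (row, col)
coords = (mat == 2).nonzero()
print(coords)  # tensor([[0, 1]])

# Flat index in the flattened tensor, then back to (row, col)
flat = (mat.view(-1) == 2).nonzero(as_tuple=True)[0][0].item()
row, col = divmod(flat, mat.shape[1])
print(flat, row, col)  # 1 0 1
```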
import torch
x_data = torch.tensor([[1.0], [2.0], [3.0]])  # Variable wrappers are no longer needed
print(x_data[0])
# >> tensor([1.])