With a Python list, we can use list.index(somevalue). How can PyTorch do this?
For example:
a=[1,2,3]
print(a.index(2))
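For an integer tensor, where exact comparison is safe, a minimal sketch of the list.index equivalent compares with == and then calls nonzero(). The tensor a below simply mirrors the list from the question.

```python
import torch

a = torch.tensor([1, 2, 3])

# Boolean mask of positions equal to 2; nonzero() lists their indices
matches = (a == 2).nonzero()
print(matches)  # tensor([[1]])

# Mimic list.index: take the first match as a Python int
first = matches[0].item()
print(first)  # 1
```

Unlike list.index, nonzero() returns every matching position, and if nothing matches, matches[0] raises IndexError rather than ValueError.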
For floating-point tensors, exact == comparison is unreliable, so I use this to get the index of an element by checking whether it lies within a small tolerance:
print((torch.abs(torch.max(your_tensor).item() - your_tensor) < 0.0001).nonzero())
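When the value you are after is specifically the maximum, torch.argmax (a different technique from the tolerance comparison) returns its index directly. The tensor values below are made up for illustration.

```python
import torch

your_tensor = torch.tensor([0.5, 2.25, -1.0, 1.75])

# argmax returns the index of the largest element directly (an index into
# the flattened tensor when no dim is given), so no tolerance is needed
idx = torch.argmax(your_tensor).item()
print(idx)  # 1
```

argmax returns a single index even when the maximum occurs more than once, whereas the tolerance comparison lists every tied position.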
Here I want the index of the maximum value in the float tensor. You can also put your own value in, like this, to get the index of any element in the tensor (nonzero() returns every position that matches):
print((torch.abs(YOUR_VALUE - your_tensor) < 0.0001).nonzero())
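To get closer to list.index semantics (one integer, error when missing), the same tolerance test can be wrapped in a small function. This is a minimal sketch: tensor_index is a hypothetical helper name, the default tol of 1e-4 mirrors the 0.0001 used in the one-liner, and the sample tensor is made up.

```python
import torch


def tensor_index(t: torch.Tensor, value: float, tol: float = 1e-4) -> int:
    """Hypothetical helper mimicking list.index for tensors: return the
    flattened index of the first element within tol of value."""
    # Same test as the one-liner: |value - t| < tol, applied to a 1-D view
    mask = torch.abs(value - t.flatten()) < tol
    hits = mask.nonzero()  # shape (num_matches, 1)
    if hits.numel() == 0:
        # list.index raises ValueError when the value is missing; do the same
        raise ValueError(f"{value} is not in tensor")
    return hits[0].item()


your_tensor = torch.tensor([[0.1, 0.2], [0.3, 0.2]])
print(tensor_index(your_tensor, 0.3))  # 2
print(tensor_index(your_tensor, 0.2))  # 1 (first of two matches)
```

Flattening first means the result is a single integer even for multi-dimensional tensors; torch.unravel_index can convert it back to per-dimension coordinates in recent PyTorch versions if needed.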