How do I convert a PyTorch Tensor into a Python list?
My current use case is to convert a tensor of size [1, 2048, 1, 1] into a list of 2048 elements.
I found Tensor.tolist(), whose documentation gives the following usage example:
>>> import torch
>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
[-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803
So, to answer the question, use a.squeeze().tolist(): squeeze() removes all dimensions of size 1, so a [1, 2048, 1, 1] tensor becomes a flat list of 2048 elements.
Alternatively, consider a.flatten().tolist(), which always returns a single flat list, if a list of lists is not desired.
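A minimal sketch of both options on a tensor shaped like the one in the question. The random values and the extra tensor `m` are illustrative assumptions, used only to show where squeeze() and flatten() differ:

```python
import torch

# A tensor shaped like the one in the question: [1, 2048, 1, 1]
t = torch.randn(1, 2048, 1, 1)

# squeeze() drops every dimension of size 1, leaving shape [2048]
squeezed = t.squeeze().tolist()

# flatten() collapses all dimensions into one, also giving shape [2048]
flat = t.flatten().tolist()

print(len(squeezed), len(flat))  # 2048 2048
print(type(squeezed[0]))         # <class 'float'>

# The two differ when more than one dimension has size > 1:
m = torch.zeros(1, 2, 3)
print(len(m.squeeze().tolist()))  # 2 (a list of 2 lists of 3 floats)
print(len(m.flatten().tolist()))  # 6 (one flat list)
```

One design consideration: if every dimension has size 1, for example a [1, 1, 1, 1] tensor, squeeze() yields a 0-d tensor and tolist() returns a bare Python number rather than a list, whereas flatten().tolist() still returns a one-element list.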
Before I came across .tolist(), I was using:
values = [element.item() for element in tensor.flatten()]
This flattens the tensor into a single dimension then calls .item() to convert each element into a Python number.
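A quick check, as a sketch on a random [1, 2048, 1, 1] tensor (an illustrative assumption), that the element-wise .item() approach and a single tolist() call produce the same list. The tolist() version avoids creating one tensor object per element inside a Python loop, so it is typically the faster and simpler choice:

```python
import torch

t = torch.randn(1, 2048, 1, 1)

# Element-wise approach: iterating a 1-D tensor yields 0-d tensors,
# and .item() converts each one to a Python float
values_item = [element.item() for element in t.flatten()]

# Bulk approach: one call converts the whole flattened tensor at once
values_tolist = t.flatten().tolist()

# Both yield the same Python floats in the same order
assert values_item == values_tolist
```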