Convert PyTorch tensor to Python list

野趣味 2021-02-06 22:23

How do I convert a PyTorch Tensor into a Python list?

My current use case is to convert a tensor of size [1, 2048, 1, 1] into a list of 2048 elements.

1 Answer
  •  情书的邮戳
    2021-02-06 22:36

    I found Tensor.tolist(), whose documentation gives the following usage example:

    >>> import torch
    >>> a = torch.randn(2, 2)
    >>> a.tolist()
    [[0.012766935862600803, 0.5415473580360413],
     [-0.08909505605697632, 0.7729271650314331]]
    >>> a[0,0].tolist()
    0.012766935862600803
    

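    So, to answer the question, use a.squeeze().tolist(): squeeze() removes all dimensions of size 1, turning the [1, 2048, 1, 1] tensor into a 1-D tensor of 2048 elements, and tolist() then converts it into a flat Python list.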
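A minimal sketch of the squeeze-then-tolist approach, assuming a tensor shaped like the one in the question. The name features is hypothetical, and the values come from torch.randn purely for illustration.

```python
import torch

# Hypothetical stand-in for the [1, 2048, 1, 1] tensor from the question.
features = torch.randn(1, 2048, 1, 1)

# squeeze() drops every dimension of size 1: [1, 2048, 1, 1] -> [2048].
squeezed = features.squeeze()
print(squeezed.shape)  # torch.Size([2048])

# tolist() converts the 1-D tensor into a flat list of Python floats.
values = squeezed.tolist()
print(len(values), type(values[0]))  # 2048 <class 'float'>
```

One caveat: if every dimension has size 1 (e.g. shape [1, 1, 1, 1]), squeeze() yields a 0-d tensor, and tolist() then returns a single Python number rather than a list, just like a[0,0].tolist() in the example above.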

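    Alternatively, use a.flatten().tolist() if you never want a list of lists: flatten() collapses every dimension into one, whereas squeeze() only drops dimensions of size 1 and can still leave a nested result.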
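To illustrate the difference, here is a sketch using a hypothetical tensor of shape [1, 2, 3, 1], where two dimensions are larger than 1. squeeze() keeps the [2, 3] structure, while flatten() collapses everything into one dimension.

```python
import torch

# Hypothetical tensor with two non-singleton dimensions; arange makes the values easy to follow.
t = torch.arange(6).reshape(1, 2, 3, 1)

# squeeze() only removes size-1 dimensions, so the result is still 2-D -> nested list.
nested = t.squeeze().tolist()
print(nested)  # [[0, 1, 2], [3, 4, 5]]

# flatten() collapses all dimensions into one, so tolist() yields a flat list.
flat = t.flatten().tolist()
print(flat)  # [0, 1, 2, 3, 4, 5]
```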


    Before I came across .tolist(), I was using:

    values = [element.item() for element in tensor.flatten()]  # avoid shadowing the built-in list
    

    This flattens the tensor into a single dimension then calls .item() to convert each element into a Python number.
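A quick sketch checking that the element-wise .item() loop and a single .tolist() call produce the same flat list (the input shape here is hypothetical). tolist() is generally the better choice, since it converts the whole tensor in one call instead of producing a 0-d tensor for every element and calling .item() on each.

```python
import torch

# Hypothetical input; any shape works because flatten() makes it 1-D first.
tensor = torch.randn(1, 4, 1, 1)

# Element-wise approach: iterating a 1-D tensor yields 0-d tensors,
# and .item() converts each one to a Python number.
values_loop = [element.item() for element in tensor.flatten()]

# Single-call approach: tolist() on the flattened tensor.
values_tolist = tensor.flatten().tolist()

print(values_loop == values_tolist)  # True
```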
