How to check if two Torch tensors or matrices are equal?

一向 2021-02-04 23:50

I need a Torch command that checks whether two tensors have the same content and returns TRUE if they do.

For example:

local tens_a = torc         


        
5 answers
  •  北荒 (original poster)
    2021-02-05 00:30

    To compare tensors element by element, use torch.eq. It returns a boolean tensor that is True wherever the elements match:

    torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
    tensor([[True, False], [False, True]])
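    A minimal sketch with the same illustrative tensors: the == operator is equivalent to torch.eq, and reducing the boolean result with .all() (or .any()) turns it into a single Python bool, which is closer to the TRUE/FALSE the question asks for.

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1., 1.], [4., 4.]])

# Element-wise comparison: `a == b` gives the same result as torch.eq(a, b)
mask = a == b

# Reduce the boolean mask to a single Python bool
all_equal = mask.all().item()  # True only if every element matches
any_equal = mask.any().item()  # True if at least one element matches
print(all_equal, any_equal)    # False True
```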
    

    Or use torch.equal to check the whole tensors at once. It returns True only if both tensors have the same size and exactly the same elements:

    torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
    # False
    torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
    # True
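    One difference worth knowing, sketched below with made-up tensors: torch.equal does not broadcast, so tensors of different shapes are never equal, whereas torch.eq broadcasts its inputs first. As a result, torch.eq(...).all() can report True for two tensors whose shapes differ.

```python
import torch

row = torch.tensor([1., 2.])       # shape (2,)
matrix = torch.tensor([[1., 2.]])  # shape (1, 2)

# torch.equal requires the same size and the same elements (no broadcasting)
print(torch.equal(row, matrix))             # False

# torch.eq broadcasts (2,) against (1, 2) before comparing element-wise
print(torch.eq(row, matrix).all().item())   # True
```

    If the shapes can differ in your data, torch.equal is the safer choice when "equal" should also mean "same shape".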
    

    In practice, however, floating-point results often differ by tiny amounts that you want to ignore. For instance, 1.0 and 1.0000000001 are so close that you may want to treat them as equal. For that kind of comparison, use torch.allclose, which compares the tensors element-wise within a relative and an absolute tolerance.

    torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
    # True
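    A sketch of how the tolerances interact, using made-up values: torch.allclose treats two elements as close when |input - other| <= atol + rtol * |other|, with defaults rtol=1e-05 and atol=1e-08. A relative tolerance alone does not help near small values, an absolute tolerance alone does not help near large ones, and NaN only compares as close to NaN when equal_nan=True.

```python
import torch

a = torch.tensor([1.0, 1000.0])
b = torch.tensor([1.01, 1000.5])

# Default tolerances (rtol=1e-05, atol=1e-08) are too strict here
print(torch.allclose(a, b))                       # False

# rtol alone fails on the small element, atol alone fails on the large one
print(torch.allclose(a, b, rtol=1e-3))            # False
print(torch.allclose(a, b, atol=0.1))             # False
# Together they cover both elements
print(torch.allclose(a, b, rtol=1e-3, atol=0.1))  # True

# The same rule written out by hand: |a - b| <= atol + rtol * |b|
rtol, atol = 1e-3, 0.1
manual = ((a - b).abs() <= atol + rtol * b.abs()).all().item()
print(manual)                                     # True

# NaN is not considered close to NaN unless equal_nan=True
nan = torch.tensor([float("nan")])
print(torch.allclose(nan, nan), torch.allclose(nan, nan, equal_nan=True))  # False True
```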
    

    Sometimes it is also useful to count how many elements are equal, relative to the total number of elements. For two tensors dt1 and dt2, dt1.nelement() gives the number of elements in dt1.

    Dividing the number of matching elements by that total gives the fraction of equal elements (multiply by 100 for a percentage):

    print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())
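    The same idea as a small sketch, reusing the example tensors from the torch.eq call above as dt1 and dt2. The helper name match_fraction is hypothetical. Taking the mean of the boolean mask (cast to float) gives the fraction directly, and swapping torch.eq for torch.isclose gives a tolerance-aware count, which matters for floating-point data where exact equality undercounts.

```python
import torch

def match_fraction(dt1, dt2, tolerant=False):
    """Hypothetical helper: fraction of positions where dt1 and dt2 agree."""
    # torch.isclose applies the same rtol/atol rule as torch.allclose, per element
    mask = torch.isclose(dt1, dt2) if tolerant else torch.eq(dt1, dt2)
    # True counts as 1.0 and False as 0.0, so the mean is the matching fraction
    return mask.float().mean().item()

dt1 = torch.tensor([[1., 2.], [3., 4.]])
dt2 = torch.tensor([[1., 1.], [4., 4.]])

print(match_fraction(dt1, dt2))                               # 0.5
print(torch.sum(torch.eq(dt1, dt2)).item() / dt1.nelement())  # 0.5
print(f"{100 * match_fraction(dt1, dt2):.0f}%")               # 50%

# Tiny float noise: exact comparison finds no matches, isclose finds all of them
noisy = dt1 + 1e-6
print(match_fraction(dt1, noisy), match_fraction(dt1, noisy, tolerant=True))  # 0.0 1.0
```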
    
