How to compute the cosine_similarity in pytorch for all rows in a matrix with respect to all rows in another matrix

谎友^ 2021-02-14 13:49

In PyTorch, given two matrices, how would I compute the cosine similarity of every row in one matrix with every row in the other?
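
Put concretely: for an n x d matrix `a` and an m x d matrix `b`, the result should be an n x m matrix whose entry (i, j) is the cosine similarity of row i of `a` with row j of `b`. The deliberately slow, loop-based sketch below spells that definition out. The function name `pairwise_cosine_loop` and the small input tensors are hypothetical, chosen only for illustration; each pair of rows goes through `torch.nn.functional.cosine_similarity`.

```python
import torch
import torch.nn.functional as F

def pairwise_cosine_loop(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: out[i, j] = cosine similarity of a[i] and b[j]."""
    n, m = a.shape[0], b.shape[0]
    out = torch.empty(n, m, dtype=a.dtype)
    for i in range(n):
        for j in range(m):
            # dim=0 because each argument is a single 1-D row vector
            out[i, j] = F.cosine_similarity(a[i], b[j], dim=0)
    return out

# Hypothetical example inputs: 2 rows in a, 3 rows in b, both with 2 columns
a = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
b = torch.tensor([[0.0, 2.0], [3.0, 0.0], [1.0, 1.0]])
print(pairwise_cosine_loop(a, b))
```

For these inputs the result is approximately [[0, 1, 0.7071], [0.7071, 0.7071, 1]]. The double loop is only useful as a correctness reference, since it makes n x m separate calls.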

For example

Given the input =

2 Answers
  •  梦如初夏
    2021-02-14 14:20

    Adding eps for numerical stability, based on benjaminplanche's answer:

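    import torch

    def sim_matrix(a, b, eps=1e-8):
        """
        Cosine similarity between every row of a (n x d) and every row of b (m x d).
        Returns an n x m matrix; eps is added for numerical stability.
        """
        # Row-wise L2 norms, shaped (n, 1) and (m, 1) so they broadcast across columns
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        # Clamp each norm to at least eps so all-zero rows do not divide by zero
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        # (n, d) @ (d, m) -> (n, m) matrix of pairwise cosine similarities
        sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
        return sim_mt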
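
The same computation can likely be written more compactly with `torch.nn.functional.normalize`, which divides each row by `max(norm, eps)`, i.e. the same clamping that `sim_matrix` performs with `torch.max`. The sketch below assumes that equivalence; `sim_matrix_normalize` is a hypothetical name, and the random inputs are illustrative only. As a cross-check it also compares against a broadcast call to `F.cosine_similarity`, which accepts inputs that are broadcastable outside the reduced dimension.

```python
import torch
import torch.nn.functional as F

def sim_matrix_normalize(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # F.normalize divides each row by max(L2 norm, eps), like the torch.max clamp above
    a_norm = F.normalize(a, p=2, dim=1, eps=eps)
    b_norm = F.normalize(b, p=2, dim=1, eps=eps)
    # (n, d) @ (d, m) -> (n, m) matrix of row-by-row cosine similarities
    return a_norm @ b_norm.T

torch.manual_seed(0)
a = torch.randn(4, 5)  # illustrative input: 4 rows
b = torch.randn(3, 5)  # illustrative input: 3 rows
sim = sim_matrix_normalize(a, b)

# Cross-check: broadcast (4, 1, 5) against (1, 3, 5) and reduce over the last dim
sim_bcast = F.cosine_similarity(a[:, None, :], b[None, :, :], dim=-1)
print(sim.shape)  # torch.Size([4, 3])
print(torch.allclose(sim, sim_bcast, atol=1e-5))
```

The broadcast form is handy for a quick check, but it probably materializes an n x m x d intermediate, so the matrix-product form is usually the better fit for large inputs.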
    
