Cosine distance of arrays of different shapes in Python?

為{幸葍}努か 提交于 2020-07-10 06:52:34

问题


Yesterday I asked "How is the cosine distance calculated for two arrays with different shapes in Python?" and got a precise answer. However, when I try to use the same function to get the row-wise cosine distance between `xx` and `yy` below (the cosine distance of each row of `xx` to each row of `yy`, returned as a list), I cannot get it to work: the two arrays have different shapes. What would the solution be? I must be missing something from the earlier example:

# xx: 12 rows of 2 values each, i.e. xx.shape == (12, 2).
xx = np.array([[ 8.07105800e-01,  2.87828956e-01],
       [ 8.97970426e-01,  3.27145176e-01],
       [ 8.97970426e-01,  3.27145176e-01],
       [ 8.97970426e-01,  3.27145176e-01],
       [ 2.00071147e-01,  2.59049753e-02],
       [ 8.97970426e-01,  3.27145176e-01],
       [ 5.83740391e-01,  1.84404675e-01],
       [ 1.97769348e-03, -5.97115379e-02],
       [ 2.72919656e-03, -6.57503753e-02],
       [-8.47442660e-04, -9.16800956e-02],
       [ 1.23314809e-02, -1.53661279e-01],
       [ 1.40920559e-03, -1.80723646e-01]])
# yy: a single row of 92 values, i.e. yy.shape == (1, 92). Its row length (92)
# differs from the row length of xx (2), so a row of xx and a row of yy cannot
# be compared with cosine similarity until yy is laid out with 2 columns.
yy = np.array([[8.44094478e-03,  3.35563887e-03,  6.07153217e-18,  2.44468631e-02,
        3.03576608e-18,  5.76097014e-03,  2.60208521e-18, -1.73472348e-18,
        4.04756763e-03,  3.75670839e-02,  2.64618673e-03,  2.60208521e-18,
        3.15227439e-03,  4.68980598e-03,  6.99511952e-03, -8.67361738e-19,
        3.95149202e-03,  9.08766600e-03, -1.90819582e-17,  3.46944695e-18,
        4.64043525e-03, -5.63785130e-18,  5.38781640e-03,  3.95149202e-03,
        6.16573618e-03, -2.60208521e-18,  2.87670232e-03, -2.16840434e-18,
        5.46844915e-03,  1.40806489e-02,  6.71015253e-03,  2.70077046e-03,
        3.92620082e-03, -1.30104261e-18,  7.35027324e-03, -5.20417043e-18,
        1.96262984e-02, -3.46944695e-18,  4.17690832e-03,  3.42000191e-03,
       -5.20417043e-18,  6.31808776e-03,  6.63353766e-03,  2.64618673e-03,
        2.64618673e-03, -3.03576608e-18,  1.73472348e-18,  3.95149202e-03,
       -5.63785130e-18,  3.08286809e-03,  1.30104261e-18,  6.32236500e-03,
        4.50714698e-03,  3.23573661e-02,  4.15821958e-03,  2.83429102e-03,
        3.08286809e-03,  3.18612020e-02,  4.10100115e-03, -1.30104261e-18,
        8.67361738e-19,  3.33684303e-02,  3.95149202e-03,  5.76097014e-03,
        5.74691221e-03, -5.20417043e-18,  4.29728958e-03, -2.60208521e-18,
       -6.07153217e-18,  2.70077046e-03,  4.71137641e-03, -8.67361738e-19,
        4.17690832e-03,  2.91338264e-03,  5.38781640e-03,  8.64510290e-03,
        3.86036578e-03,  9.89490959e-03,  5.76097014e-03,  1.37527752e-02,
        6.42464160e-03,  7.32098404e-03,  4.71137641e-03,  3.25127413e-03,
        4.60430401e-03,  4.71137641e-03,  3.42000191e-03, -4.33680869e-18,
        6.07153217e-18, -3.03576608e-18,  6.30454878e-03,  4.34769443e-03]])

I'd be glad to see a solution.


回答1:


def cosine_similarity_vec(x, y):
    """Return the cosine similarity of every row of ``x`` with every row of ``y``.

    Parameters
    ----------
    x : array_like, shape (n, d) or (d,)
        First set of vectors, one per row. A 1-D input is treated as a
        single row.
    y : array_like, shape (m, d) or (d,)
        Second set of vectors, one per row. Must have the same number of
        columns ``d`` as ``x``.

    Returns
    -------
    numpy.ndarray, shape (n * m,)
        Flattened similarity matrix in row-major order: element
        ``i * m + j`` is the cosine similarity of ``x[i]`` and ``y[j]``.
        The cosine *distance* is ``1 - cosine_similarity_vec(x, y)``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different numbers of columns.

    Notes
    -----
    A row with zero norm yields ``nan`` for every pair it takes part in,
    and NumPy emits a ``RuntimeWarning`` for the 0/0 division.
    """
    x = np.atleast_2d(np.asarray(x))
    y = np.atleast_2d(np.asarray(y))
    # One matrix product gives all pairwise dot products, without building
    # the (n, m, d) temporary that broadcasting x[:, np.newaxis] * y needs.
    dots = x @ y.T
    # Take each square root before multiplying. Forming |x|^2 * |y|^2 first
    # underflows to 0 (or overflows to inf) once components reach magnitudes
    # of roughly 1e-77 (or 1e77), which corrupts the result for otherwise
    # perfectly representable vectors.
    x_norms = np.sqrt(np.einsum('ij,ij->i', x, x))
    y_norms = np.sqrt(np.einsum('ij,ij->i', y, y))
    return (dots / (x_norms[:, np.newaxis] * y_norms)).reshape(-1)

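The function returns the cosine *similarity*; the cosine distance asked for in the question is `1 - cosine_similarity_vec(x, y)`. The result is flattened, so element `i * len(y) + j` belongs to row `i` of `x` and row `j` of `y`.

Data: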

# a: one row of 2 values, a.shape == (1, 2) (the row is wrapped in a 1-tuple).
a = np.array(([7.26741212e-01, -9.80825232e-17], ))
# b: ten rows of 2 values, b.shape == (10, 2).
b = np.array(([-3.82390578e-01, -1.48157964e-17],
              [-3.82390578e-01,  7.87310307e-01],
              [7.26741212e-01, -9.80825232e-17],
              [7.26741212e-01, -9.80825232e-17],
              [-3.82390578e-01, -2.06286905e-01],
              [7.26741212e-01, -9.80825232e-17],
              [-2.16887107e-01,  6.84509305e-17],
              [-3.82390578e-01, -5.81023402e-01],
              [-2.16887107e-01,  6.84509305e-17],
              [-2.16887107e-01,  6.84509305e-17]))

Output:

cosine_similarity_vec(a, b)
 
array([-1.        , -0.43688798,  1.        ,  1.        , -0.88010163,
        1.        , -1.        , -0.5497553 , -1.        , -1.        ])

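Data (`yy` holds the question's 92 values reshaped into 46 rows of 2 columns, i.e. `yy.reshape(-1, 2)`, so that its rows have the same length as the rows of `xx`; cosine similarity is only defined between vectors of equal length):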

# xx: the same 12 x 2 array as in the question.
xx = np.array([[8.07105800e-01,  2.87828956e-01],
               [8.97970426e-01,  3.27145176e-01],
               [8.97970426e-01,  3.27145176e-01],
               [8.97970426e-01,  3.27145176e-01],
               [2.00071147e-01,  2.59049753e-02],
               [8.97970426e-01,  3.27145176e-01],
               [5.83740391e-01,  1.84404675e-01],
               [1.97769348e-03, -5.97115379e-02],
               [2.72919656e-03, -6.57503753e-02],
               [-8.47442660e-04, -9.16800956e-02],
               [1.23314809e-02, -1.53661279e-01],
               [1.40920559e-03, -1.80723646e-01]])

# yy: the question's 92 values laid out as 46 rows of 2 columns (equivalent to
# calling yy.reshape(-1, 2) on the question's (1, 92) array), so every row of
# yy has the same length as a row of xx.
yy = np.array([[8.44094478e-03,  3.35563887e-03],
            [6.07153217e-18,  2.44468631e-02],
            [3.03576608e-18,  5.76097014e-03],
            [2.60208521e-18, -1.73472348e-18],
            [4.04756763e-03,  3.75670839e-02],
            [2.64618673e-03,  2.60208521e-18],
            [3.15227439e-03,  4.68980598e-03],
            [6.99511952e-03, -8.67361738e-19],
            [3.95149202e-03,  9.08766600e-03],
            [-1.90819582e-17,  3.46944695e-18],
            [4.64043525e-03, -5.63785130e-18],
            [5.38781640e-03,  3.95149202e-03],
            [6.16573618e-03, -2.60208521e-18],
            [2.87670232e-03, -2.16840434e-18],
            [5.46844915e-03,  1.40806489e-02],
            [6.71015253e-03,  2.70077046e-03],
            [3.92620082e-03, -1.30104261e-18],
            [7.35027324e-03, -5.20417043e-18],
            [1.96262984e-02, -3.46944695e-18],
            [4.17690832e-03,  3.42000191e-03],
            [-5.20417043e-18,  6.31808776e-03],
            [6.63353766e-03,  2.64618673e-03],
            [2.64618673e-03, -3.03576608e-18],
            [1.73472348e-18,  3.95149202e-03],
            [-5.63785130e-18,  3.08286809e-03],
            [1.30104261e-18,  6.32236500e-03],
            [4.50714698e-03,  3.23573661e-02],
            [4.15821958e-03,  2.83429102e-03],
            [3.08286809e-03,  3.18612020e-02],
            [4.10100115e-03, -1.30104261e-18],
            [8.67361738e-19,  3.33684303e-02],
            [3.95149202e-03,  5.76097014e-03],
            [5.74691221e-03, -5.20417043e-18],
            [4.29728958e-03, -2.60208521e-18],
            [-6.07153217e-18,  2.70077046e-03],
            [4.71137641e-03, -8.67361738e-19],
            [4.17690832e-03,  2.91338264e-03],
            [5.38781640e-03,  8.64510290e-03],
            [3.86036578e-03,  9.89490959e-03],
            [5.76097014e-03,  1.37527752e-02],
            [6.42464160e-03,  7.32098404e-03],
            [4.71137641e-03,  3.25127413e-03],
            [4.60430401e-03,  4.71137641e-03],
            [3.42000191e-03, -4.33680869e-18],
            [6.07153217e-18, -3.03576608e-18],
            [6.30454878e-03,  4.34769443e-03]])

Output (12 × 46 = 552 values, one per pair of a row of `xx` and a row of `yy`):

cosine_similarity_vec(xx, yy)                                                                                 
 
array([ 0.99935826,  0.33589844,  0.33589844,  0.59738375,  0.43486405,
        0.94189821,  0.80421365,  0.94189821,  0.6836243 , -0.86661797,
        0.94189821,  0.95817482,  0.94189821,  0.94189821,  0.65410303,
        0.99919641,  0.94189821,  0.94189821,  0.94189821,  0.9415704 ,
        0.33589844,  0.99931529,  0.94189821,  0.33589844,  0.33589844,
        0.33589844,  0.46263151,  0.96748169,  0.42505076,  0.94189821,
        0.33589844,  0.80977235,  0.94189821,  0.94189821,  0.33589844,
        0.94189821,  0.96470238,  0.78325138,  0.6552651 ,  0.6737319 ,
        0.87374036,  0.96600713,  0.89855416,  0.94189821,  0.69224102,
        0.96609076,  0.99957909,  0.34230717,  0.34230717,  0.5919067 ,
        0.44098844,  0.9395881 ,  0.80824383,  0.9395881 ,  0.68858031,
       -0.86319869,  0.9395881 ,  0.96010216,  0.9395881 ,  0.9395881 ,
        0.65924074,  0.99944627,  0.9395881 ,  0.9395881 ,  0.9395881 ,
        0.94384306,  0.34230717,  0.99954416,  0.9395881 ,  0.34230717,
        0.34230717,  0.34230717,  0.46866025,  0.96918236,  0.43120721,
        0.9395881 ,  0.34230717,  0.81375066,  0.9395881 ,  0.9395881 ,
        0.34230717,  0.9395881 ,  0.96647397,  0.78746847,  0.66039593,
        0.67875041,  0.87703355,  0.96774581,  0.90152299,  0.9395881 ,
        0.68730873,  0.96782731,  0.99957909,  0.34230717,  0.34230717,
        0.5919067 ,  0.44098844,  0.9395881 ,  0.80824383,  0.9395881 ,
        0.68858031, -0.86319869,  0.9395881 ,  0.96010216,  0.9395881 ,
        0.9395881 ,  0.65924074,  0.99944627,  0.9395881 ,  0.9395881 ,
        0.9395881 ,  0.94384306,  0.34230717,  0.99954416,  0.9395881 ,
        0.34230717,  0.34230717,  0.34230717,  0.46866025,  0.96918236,
        0.43120721,  0.9395881 ,  0.34230717,  0.81375066,  0.9395881 ,
        0.9395881 ,  0.34230717,  0.9395881 ,  0.96647397,  0.78746847,
        0.66039593,  0.67875041,  0.87703355,  0.96774581,  0.90152299,
        0.9395881 ,  0.68730873,  0.96782731,  0.99957909,  0.34230717,
        0.34230717,  0.5919067 ,  0.44098844,  0.9395881 ,  0.80824383,
        0.9395881 ,  0.68858031, -0.86319869,  0.9395881 ,  0.96010216,
        0.9395881 ,  0.9395881 ,  0.65924074,  0.99944627,  0.9395881 ,
        0.9395881 ,  0.9395881 ,  0.94384306,  0.34230717,  0.99954416,
        0.9395881 ,  0.34230717,  0.34230717,  0.34230717,  0.46866025,
        0.96918236,  0.43120721,  0.9395881 ,  0.34230717,  0.81375066,
        0.9395881 ,  0.9395881 ,  0.34230717,  0.9395881 ,  0.96647397,
        0.78746847,  0.66039593,  0.67875041,  0.87703355,  0.96774581,
        0.90152299,  0.9395881 ,  0.68730873,  0.96782731,  0.96900536,
        0.12840693,  0.12840693,  0.75393487,  0.23390368,  0.99172156,
        0.65980162,  0.99172156,  0.51320988, -0.95275487,  0.99172156,
        0.87563955,  0.99172156,  0.99172156,  0.47872305,  0.967943  ,
        0.99172156,  0.99172156,  0.99172156,  0.84867036,  0.12840693,
        0.96871314,  0.99172156,  0.12840693,  0.12840693,  0.12840693,
        0.26399777,  0.89178762,  0.22332226,  0.99172156,  0.12840693,
        0.66684545,  0.99172156,  0.99172156,  0.12840693,  0.99172156,
        0.8868647 ,  0.63351058,  0.48007221,  0.50160308,  0.7506488 ,
        0.88916392,  0.78498216,  0.99172156,  0.82959741,  0.889312  ,
        0.99957909,  0.34230717,  0.34230717,  0.5919067 ,  0.44098844,
        0.9395881 ,  0.80824383,  0.9395881 ,  0.68858031, -0.86319869,
        0.9395881 ,  0.96010216,  0.9395881 ,  0.9395881 ,  0.65924074,
        0.99944627,  0.9395881 ,  0.9395881 ,  0.9395881 ,  0.94384306,
        0.34230717,  0.99954416,  0.9395881 ,  0.34230717,  0.34230717,
        0.34230717,  0.46866025,  0.96918236,  0.43120721,  0.9395881 ,
        0.34230717,  0.81375066,  0.9395881 ,  0.9395881 ,  0.34230717,
        0.9395881 ,  0.96647397,  0.78746847,  0.66039593,  0.67875041,
        0.87703355,  0.96774581,  0.90152299,  0.9395881 ,  0.68730873,
        0.96782731,  0.99737987,  0.30122881,  0.30122881,  0.62631144,
        0.40164229,  0.95355189,  0.78194086,  0.95355189,  0.6564772 ,
       -0.88428556,  0.95355189,  0.94706825,  0.95355189,  0.95355189,
        0.62600397,  0.99706228,  0.95355189,  0.95355189,  0.95355189,
        0.92862332,  0.30122881,  0.9972938 ,  0.95355189,  0.30122881,
        0.30122881,  0.30122881,  0.42990116,  0.95758456,  0.39166466,
        0.95355189,  0.30122881,  0.78777363,  0.95355189,  0.95355189,
        0.30122881,  0.95355189,  0.95442674,  0.75999189,  0.62720208,
        0.64625711,  0.85536868,  0.95590714,  0.88190405,  0.95355189,
        0.71816912,  0.95600215, -0.33845812, -0.99945196, -0.99945196,
        0.58193926, -0.99015491,  0.03310264, -0.81102105,  0.03310264,
       -0.90335536, -0.21135609,  0.03310264, -0.56438795,  0.03310264,
        0.03310264, -0.91967435, -0.34246805,  0.03310264,  0.03310264,
        0.03310264, -0.60755902, -0.99945196, -0.33956838,  0.03310264,
       -0.99945196, -0.99945196, -0.99945196, -0.98532802, -0.53555873,
       -0.99161784,  0.03310264, -0.99945196, -0.80547856,  0.03310264,
        0.03310264, -0.99945196,  0.03310264, -0.54461978, -0.83070233,
       -0.91906958, -0.90905032, -0.72937408, -0.54041925, -0.69165838,
        0.03310264,  0.47657641, -0.54014685, -0.33056473, -0.99913964,
       -0.99913964,  0.58873036, -0.98894776,  0.04147274, -0.80609259,
        0.04147274, -0.89973133, -0.21953532,  0.04147274, -0.5574538 ,
        0.04147274,  0.04147274, -0.91635305, -0.33458668,  0.04147274,
        0.04147274,  0.04147274, -0.60088498, -0.99913964, -0.33167829,
        0.04147274, -0.99913964, -0.99913964, -0.99913964, -0.98386394,
       -0.52846655, -0.99050085,  0.04147274, -0.99913964, -0.80048656,
        0.04147274,  0.04147274, -0.99913964,  0.04147274, -0.53757599,
       -0.82601022, -0.91573646, -0.90552833, -0.72361844, -0.53335289,
       -0.68558487,  0.04147274,  0.48392318, -0.53307903, -0.37799506,
       -0.99995728, -0.99995728,  0.54698579, -0.99519351, -0.00924308,
       -0.83506298, -0.00924308, -0.92070432, -0.16978381, -0.00924308,
       -0.59883331, -0.00924308, -0.00924308, -0.93547553, -0.38193992,
       -0.00924308, -0.00924308, -0.00924308, -0.6406432 , -0.99995728,
       -0.37908738, -0.00924308, -0.99995728, -0.99995728, -0.99995728,
       -0.99167056, -0.57083389, -0.99619911, -0.00924308, -0.99995728,
       -0.82984757, -0.00924308, -0.00924308, -0.99995728, -0.00924308,
       -0.57964062, -0.85352841, -0.93493114, -0.92587735, -0.75768496,
       -0.57555872, -0.72161662, -0.00924308,  0.43892723, -0.57529396,
       -0.29390247, -0.99679535, -0.99679535,  0.61948152, -0.98249049,
        0.07999389, -0.78265797,  0.07999389, -0.88222102, -0.25701575,
        0.07999389, -0.52500491,  0.07999389,  0.07999389, -0.90022224,
       -0.29797615,  0.07999389,  0.07999389,  0.07999389, -0.56959497,
       -0.99679535, -0.29503028,  0.07999389, -0.99679535, -0.99679535,
       -0.99679535, -0.97622767, -0.49531592, -0.98445751,  0.07999389,
       -0.99679535, -0.77676437,  0.07999389,  0.07999389, -0.99679535,
        0.07999389, -0.50464001, -0.80364634, -0.89955177, -0.88848283,
       -0.69644804, -0.50031674, -0.65698456,  0.07999389,  0.51732914,
       -0.50003643, -0.36216461, -0.9999696 , -0.9999696 ,  0.56117111,
       -0.99338034,  0.00779733, -0.82556722,  0.00779733, -0.91392067,
       -0.18655156,  0.00779733, -0.58509964,  0.00779733,  0.00779733,
       -0.92931799, -0.3661365 ,  0.00779733,  0.00779733,  0.00779733,
       -0.62746637, -0.9999696 , -0.36326438,  0.00779733, -0.9999696 ,
       -0.9999696 , -0.9999696 , -0.98933185, -0.55676022, -0.99457022,
        0.00779733, -0.9999696 , -0.82021904,  0.00779733,  0.00779733,
       -0.9999696 ,  0.00779733, -0.5656712 , -0.84452596, -0.92874918,
       -0.91930487, -0.74645443, -0.56154068, -0.70971532,  0.00779733,
        0.45417415, -0.56127279])


来源:https://stackoverflow.com/questions/62755267/cosine-distance-of-arrays-of-different-shapes-in-python

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!