问题
Yesterday I asked the question "How is the cosine distance calculated for two arrays with different shapes in Python?" and got a precise answer. However, I have not managed to use the same function to get the cosine distances between xx and yy below as a list, computed row-wise: the cosine distance of each row in xx to each row in yy. The two arrays have different shapes. What would be the solution? I must be missing something from the earlier example:
xx = np.array([[ 8.07105800e-01, 2.87828956e-01],
[ 8.97970426e-01, 3.27145176e-01],
[ 8.97970426e-01, 3.27145176e-01],
[ 8.97970426e-01, 3.27145176e-01],
[ 2.00071147e-01, 2.59049753e-02],
[ 8.97970426e-01, 3.27145176e-01],
[ 5.83740391e-01, 1.84404675e-01],
[ 1.97769348e-03, -5.97115379e-02],
[ 2.72919656e-03, -6.57503753e-02],
[-8.47442660e-04, -9.16800956e-02],
[ 1.23314809e-02, -1.53661279e-01],
[ 1.40920559e-03, -1.80723646e-01]])
yy = np.array([[8.44094478e-03, 3.35563887e-03, 6.07153217e-18, 2.44468631e-02,
3.03576608e-18, 5.76097014e-03, 2.60208521e-18, -1.73472348e-18,
4.04756763e-03, 3.75670839e-02, 2.64618673e-03, 2.60208521e-18,
3.15227439e-03, 4.68980598e-03, 6.99511952e-03, -8.67361738e-19,
3.95149202e-03, 9.08766600e-03, -1.90819582e-17, 3.46944695e-18,
4.64043525e-03, -5.63785130e-18, 5.38781640e-03, 3.95149202e-03,
6.16573618e-03, -2.60208521e-18, 2.87670232e-03, -2.16840434e-18,
5.46844915e-03, 1.40806489e-02, 6.71015253e-03, 2.70077046e-03,
3.92620082e-03, -1.30104261e-18, 7.35027324e-03, -5.20417043e-18,
1.96262984e-02, -3.46944695e-18, 4.17690832e-03, 3.42000191e-03,
-5.20417043e-18, 6.31808776e-03, 6.63353766e-03, 2.64618673e-03,
2.64618673e-03, -3.03576608e-18, 1.73472348e-18, 3.95149202e-03,
-5.63785130e-18, 3.08286809e-03, 1.30104261e-18, 6.32236500e-03,
4.50714698e-03, 3.23573661e-02, 4.15821958e-03, 2.83429102e-03,
3.08286809e-03, 3.18612020e-02, 4.10100115e-03, -1.30104261e-18,
8.67361738e-19, 3.33684303e-02, 3.95149202e-03, 5.76097014e-03,
5.74691221e-03, -5.20417043e-18, 4.29728958e-03, -2.60208521e-18,
-6.07153217e-18, 2.70077046e-03, 4.71137641e-03, -8.67361738e-19,
4.17690832e-03, 2.91338264e-03, 5.38781640e-03, 8.64510290e-03,
3.86036578e-03, 9.89490959e-03, 5.76097014e-03, 1.37527752e-02,
6.42464160e-03, 7.32098404e-03, 4.71137641e-03, 3.25127413e-03,
4.60430401e-03, 4.71137641e-03, 3.42000191e-03, -4.33680869e-18,
6.07153217e-18, -3.03576608e-18, 6.30454878e-03, 4.34769443e-03]])
I would be glad to see a solution.
回答1:
def cosine_similarity_vec(x, y):
    """Return the cosine similarity of every row of ``x`` with every row of ``y``.

    Parameters
    ----------
    x : array_like, shape (m, d) or (d,)
        Row vectors. A 1-D input is treated as a single row.
    y : array_like, shape (n, d) or (d,)
        Row vectors with the same number of columns as ``x``.

    Returns
    -------
    numpy.ndarray, shape (m * n,)
        Similarities flattened in row-major order: element ``i * n + j`` is
        the cosine similarity between ``x[i]`` and ``y[j]``.

    Raises
    ------
    ValueError
        If an input has more than two dimensions, or if the two inputs do
        not have the same number of columns.

    Notes
    -----
    A row consisting only of zeros has zero norm, so every similarity that
    involves it is ``nan`` (NumPy emits a ``RuntimeWarning``).
    """
    # Promote single vectors to one-row matrices so callers can pass 1-D input.
    x = np.atleast_2d(x)
    y = np.atleast_2d(y)
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError(
            "x and y must be 1-D or 2-D, got shapes "
            f"{x.shape} and {y.shape}"
        )
    # Without this check a (n, 1) array would silently broadcast against
    # (m, d) and yield numbers that are not cosine similarities.
    if x.shape[1] != y.shape[1]:
        raise ValueError(
            "x and y must have the same number of columns, got shapes "
            f"{x.shape} and {y.shape}"
        )

    # All pairwise dot products as an (m, n) matrix; the matrix product
    # avoids materialising an (m, n, d) broadcast intermediate.
    numerator_vec = (x @ y.T).reshape(-1)

    # Squared Euclidean norm of each row.
    x_sq = np.einsum('ij, ij->i', x, x)
    y_sq = np.einsum('ij, ij->i', y, y)

    # Outer product of squared norms -> (m, n) matrix of ||x_i|| * ||y_j||.
    denominator_vec = np.sqrt(x_sq[:, np.newaxis] * y_sq).reshape(-1)
    return numerator_vec / denominator_vec
Data
a = np.array(([7.26741212e-01, -9.80825232e-17], ))
b = np.array(([-3.82390578e-01, -1.48157964e-17],
[-3.82390578e-01, 7.87310307e-01],
[7.26741212e-01, -9.80825232e-17],
[7.26741212e-01, -9.80825232e-17],
[-3.82390578e-01, -2.06286905e-01],
[7.26741212e-01, -9.80825232e-17],
[-2.16887107e-01, 6.84509305e-17],
[-3.82390578e-01, -5.81023402e-01],
[-2.16887107e-01, 6.84509305e-17],
[-2.16887107e-01, 6.84509305e-17]))
Output:
cosine_similarity_vec(a, b)
array([[-1. , -0.43688798, 1. , 1. , -0.88010163,
1. , -1. , -0.5497553 , -1. , -1. ]])
Data:
xx = np.array([[8.07105800e-01, 2.87828956e-01],
[8.97970426e-01, 3.27145176e-01],
[8.97970426e-01, 3.27145176e-01],
[8.97970426e-01, 3.27145176e-01],
[2.00071147e-01, 2.59049753e-02],
[8.97970426e-01, 3.27145176e-01],
[5.83740391e-01, 1.84404675e-01],
[1.97769348e-03, -5.97115379e-02],
[2.72919656e-03, -6.57503753e-02],
[-8.47442660e-04, -9.16800956e-02],
[1.23314809e-02, -1.53661279e-01],
[1.40920559e-03, -1.80723646e-01]])
yy = np.array([[8.44094478e-03, 3.35563887e-03],
[6.07153217e-18, 2.44468631e-02],
[3.03576608e-18, 5.76097014e-03],
[2.60208521e-18, -1.73472348e-18],
[4.04756763e-03, 3.75670839e-02],
[2.64618673e-03, 2.60208521e-18],
[3.15227439e-03, 4.68980598e-03],
[6.99511952e-03, -8.67361738e-19],
[3.95149202e-03, 9.08766600e-03],
[-1.90819582e-17, 3.46944695e-18],
[4.64043525e-03, -5.63785130e-18],
[5.38781640e-03, 3.95149202e-03],
[6.16573618e-03, -2.60208521e-18],
[2.87670232e-03, -2.16840434e-18],
[5.46844915e-03, 1.40806489e-02],
[6.71015253e-03, 2.70077046e-03],
[3.92620082e-03, -1.30104261e-18],
[7.35027324e-03, -5.20417043e-18],
[1.96262984e-02, -3.46944695e-18],
[4.17690832e-03, 3.42000191e-03],
[-5.20417043e-18, 6.31808776e-03],
[6.63353766e-03, 2.64618673e-03],
[2.64618673e-03, -3.03576608e-18],
[1.73472348e-18, 3.95149202e-03],
[-5.63785130e-18, 3.08286809e-03],
[1.30104261e-18, 6.32236500e-03],
[4.50714698e-03, 3.23573661e-02],
[4.15821958e-03, 2.83429102e-03],
[3.08286809e-03, 3.18612020e-02],
[4.10100115e-03, -1.30104261e-18],
[8.67361738e-19, 3.33684303e-02],
[3.95149202e-03, 5.76097014e-03],
[5.74691221e-03, -5.20417043e-18],
[4.29728958e-03, -2.60208521e-18],
[-6.07153217e-18, 2.70077046e-03],
[4.71137641e-03, -8.67361738e-19],
[4.17690832e-03, 2.91338264e-03],
[5.38781640e-03, 8.64510290e-03],
[3.86036578e-03, 9.89490959e-03],
[5.76097014e-03, 1.37527752e-02],
[6.42464160e-03, 7.32098404e-03],
[4.71137641e-03, 3.25127413e-03],
[4.60430401e-03, 4.71137641e-03],
[3.42000191e-03, -4.33680869e-18],
[6.07153217e-18, -3.03576608e-18],
[6.30454878e-03, 4.34769443e-03]])
Output
cosine_similarity_vec(xx, yy)
array([ 0.99935826, 0.33589844, 0.33589844, 0.59738375, 0.43486405,
0.94189821, 0.80421365, 0.94189821, 0.6836243 , -0.86661797,
0.94189821, 0.95817482, 0.94189821, 0.94189821, 0.65410303,
0.99919641, 0.94189821, 0.94189821, 0.94189821, 0.9415704 ,
0.33589844, 0.99931529, 0.94189821, 0.33589844, 0.33589844,
0.33589844, 0.46263151, 0.96748169, 0.42505076, 0.94189821,
0.33589844, 0.80977235, 0.94189821, 0.94189821, 0.33589844,
0.94189821, 0.96470238, 0.78325138, 0.6552651 , 0.6737319 ,
0.87374036, 0.96600713, 0.89855416, 0.94189821, 0.69224102,
0.96609076, 0.99957909, 0.34230717, 0.34230717, 0.5919067 ,
0.44098844, 0.9395881 , 0.80824383, 0.9395881 , 0.68858031,
-0.86319869, 0.9395881 , 0.96010216, 0.9395881 , 0.9395881 ,
0.65924074, 0.99944627, 0.9395881 , 0.9395881 , 0.9395881 ,
0.94384306, 0.34230717, 0.99954416, 0.9395881 , 0.34230717,
0.34230717, 0.34230717, 0.46866025, 0.96918236, 0.43120721,
0.9395881 , 0.34230717, 0.81375066, 0.9395881 , 0.9395881 ,
0.34230717, 0.9395881 , 0.96647397, 0.78746847, 0.66039593,
0.67875041, 0.87703355, 0.96774581, 0.90152299, 0.9395881 ,
0.68730873, 0.96782731, 0.99957909, 0.34230717, 0.34230717,
0.5919067 , 0.44098844, 0.9395881 , 0.80824383, 0.9395881 ,
0.68858031, -0.86319869, 0.9395881 , 0.96010216, 0.9395881 ,
0.9395881 , 0.65924074, 0.99944627, 0.9395881 , 0.9395881 ,
0.9395881 , 0.94384306, 0.34230717, 0.99954416, 0.9395881 ,
0.34230717, 0.34230717, 0.34230717, 0.46866025, 0.96918236,
0.43120721, 0.9395881 , 0.34230717, 0.81375066, 0.9395881 ,
0.9395881 , 0.34230717, 0.9395881 , 0.96647397, 0.78746847,
0.66039593, 0.67875041, 0.87703355, 0.96774581, 0.90152299,
0.9395881 , 0.68730873, 0.96782731, 0.99957909, 0.34230717,
0.34230717, 0.5919067 , 0.44098844, 0.9395881 , 0.80824383,
0.9395881 , 0.68858031, -0.86319869, 0.9395881 , 0.96010216,
0.9395881 , 0.9395881 , 0.65924074, 0.99944627, 0.9395881 ,
0.9395881 , 0.9395881 , 0.94384306, 0.34230717, 0.99954416,
0.9395881 , 0.34230717, 0.34230717, 0.34230717, 0.46866025,
0.96918236, 0.43120721, 0.9395881 , 0.34230717, 0.81375066,
0.9395881 , 0.9395881 , 0.34230717, 0.9395881 , 0.96647397,
0.78746847, 0.66039593, 0.67875041, 0.87703355, 0.96774581,
0.90152299, 0.9395881 , 0.68730873, 0.96782731, 0.96900536,
0.12840693, 0.12840693, 0.75393487, 0.23390368, 0.99172156,
0.65980162, 0.99172156, 0.51320988, -0.95275487, 0.99172156,
0.87563955, 0.99172156, 0.99172156, 0.47872305, 0.967943 ,
0.99172156, 0.99172156, 0.99172156, 0.84867036, 0.12840693,
0.96871314, 0.99172156, 0.12840693, 0.12840693, 0.12840693,
0.26399777, 0.89178762, 0.22332226, 0.99172156, 0.12840693,
0.66684545, 0.99172156, 0.99172156, 0.12840693, 0.99172156,
0.8868647 , 0.63351058, 0.48007221, 0.50160308, 0.7506488 ,
0.88916392, 0.78498216, 0.99172156, 0.82959741, 0.889312 ,
0.99957909, 0.34230717, 0.34230717, 0.5919067 , 0.44098844,
0.9395881 , 0.80824383, 0.9395881 , 0.68858031, -0.86319869,
0.9395881 , 0.96010216, 0.9395881 , 0.9395881 , 0.65924074,
0.99944627, 0.9395881 , 0.9395881 , 0.9395881 , 0.94384306,
0.34230717, 0.99954416, 0.9395881 , 0.34230717, 0.34230717,
0.34230717, 0.46866025, 0.96918236, 0.43120721, 0.9395881 ,
0.34230717, 0.81375066, 0.9395881 , 0.9395881 , 0.34230717,
0.9395881 , 0.96647397, 0.78746847, 0.66039593, 0.67875041,
0.87703355, 0.96774581, 0.90152299, 0.9395881 , 0.68730873,
0.96782731, 0.99737987, 0.30122881, 0.30122881, 0.62631144,
0.40164229, 0.95355189, 0.78194086, 0.95355189, 0.6564772 ,
-0.88428556, 0.95355189, 0.94706825, 0.95355189, 0.95355189,
0.62600397, 0.99706228, 0.95355189, 0.95355189, 0.95355189,
0.92862332, 0.30122881, 0.9972938 , 0.95355189, 0.30122881,
0.30122881, 0.30122881, 0.42990116, 0.95758456, 0.39166466,
0.95355189, 0.30122881, 0.78777363, 0.95355189, 0.95355189,
0.30122881, 0.95355189, 0.95442674, 0.75999189, 0.62720208,
0.64625711, 0.85536868, 0.95590714, 0.88190405, 0.95355189,
0.71816912, 0.95600215, -0.33845812, -0.99945196, -0.99945196,
0.58193926, -0.99015491, 0.03310264, -0.81102105, 0.03310264,
-0.90335536, -0.21135609, 0.03310264, -0.56438795, 0.03310264,
0.03310264, -0.91967435, -0.34246805, 0.03310264, 0.03310264,
0.03310264, -0.60755902, -0.99945196, -0.33956838, 0.03310264,
-0.99945196, -0.99945196, -0.99945196, -0.98532802, -0.53555873,
-0.99161784, 0.03310264, -0.99945196, -0.80547856, 0.03310264,
0.03310264, -0.99945196, 0.03310264, -0.54461978, -0.83070233,
-0.91906958, -0.90905032, -0.72937408, -0.54041925, -0.69165838,
0.03310264, 0.47657641, -0.54014685, -0.33056473, -0.99913964,
-0.99913964, 0.58873036, -0.98894776, 0.04147274, -0.80609259,
0.04147274, -0.89973133, -0.21953532, 0.04147274, -0.5574538 ,
0.04147274, 0.04147274, -0.91635305, -0.33458668, 0.04147274,
0.04147274, 0.04147274, -0.60088498, -0.99913964, -0.33167829,
0.04147274, -0.99913964, -0.99913964, -0.99913964, -0.98386394,
-0.52846655, -0.99050085, 0.04147274, -0.99913964, -0.80048656,
0.04147274, 0.04147274, -0.99913964, 0.04147274, -0.53757599,
-0.82601022, -0.91573646, -0.90552833, -0.72361844, -0.53335289,
-0.68558487, 0.04147274, 0.48392318, -0.53307903, -0.37799506,
-0.99995728, -0.99995728, 0.54698579, -0.99519351, -0.00924308,
-0.83506298, -0.00924308, -0.92070432, -0.16978381, -0.00924308,
-0.59883331, -0.00924308, -0.00924308, -0.93547553, -0.38193992,
-0.00924308, -0.00924308, -0.00924308, -0.6406432 , -0.99995728,
-0.37908738, -0.00924308, -0.99995728, -0.99995728, -0.99995728,
-0.99167056, -0.57083389, -0.99619911, -0.00924308, -0.99995728,
-0.82984757, -0.00924308, -0.00924308, -0.99995728, -0.00924308,
-0.57964062, -0.85352841, -0.93493114, -0.92587735, -0.75768496,
-0.57555872, -0.72161662, -0.00924308, 0.43892723, -0.57529396,
-0.29390247, -0.99679535, -0.99679535, 0.61948152, -0.98249049,
0.07999389, -0.78265797, 0.07999389, -0.88222102, -0.25701575,
0.07999389, -0.52500491, 0.07999389, 0.07999389, -0.90022224,
-0.29797615, 0.07999389, 0.07999389, 0.07999389, -0.56959497,
-0.99679535, -0.29503028, 0.07999389, -0.99679535, -0.99679535,
-0.99679535, -0.97622767, -0.49531592, -0.98445751, 0.07999389,
-0.99679535, -0.77676437, 0.07999389, 0.07999389, -0.99679535,
0.07999389, -0.50464001, -0.80364634, -0.89955177, -0.88848283,
-0.69644804, -0.50031674, -0.65698456, 0.07999389, 0.51732914,
-0.50003643, -0.36216461, -0.9999696 , -0.9999696 , 0.56117111,
-0.99338034, 0.00779733, -0.82556722, 0.00779733, -0.91392067,
-0.18655156, 0.00779733, -0.58509964, 0.00779733, 0.00779733,
-0.92931799, -0.3661365 , 0.00779733, 0.00779733, 0.00779733,
-0.62746637, -0.9999696 , -0.36326438, 0.00779733, -0.9999696 ,
-0.9999696 , -0.9999696 , -0.98933185, -0.55676022, -0.99457022,
0.00779733, -0.9999696 , -0.82021904, 0.00779733, 0.00779733,
-0.9999696 , 0.00779733, -0.5656712 , -0.84452596, -0.92874918,
-0.91930487, -0.74645443, -0.56154068, -0.70971532, 0.00779733,
0.45417415, -0.56127279])
来源:https://stackoverflow.com/questions/62755267/cosine-distance-of-arrays-of-different-shapes-in-python