Using LSH for fast semantic retrieval with word vectors

Submitted by 主宰稳场 on 2020-04-29 13:46:24
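The script below indexes pre-trained word2vec vectors with random-hyperplane LSH: each vector is reduced to a short bit signature (the signs of its dot products with a handful of random hyperplanes), vectors that share a signature land in the same bucket, and a query then only has to be compared against bucket members instead of the whole vocabulary. As a quick orientation, here is a minimal, self-contained sketch of that bucketing step (the names and the 300-dimension assumption are illustrative, not taken from the script):

import numpy as np

def bucket_index(vec, hyperplanes):
    # one bit per hyperplane: 1 if the vector lies on its positive side, else 0
    bits = ''.join('1' if np.dot(vec, h) > 0 else '0' for h in hyperplanes)
    return int(bits, 2)  # read the bit string as a bucket number

planes = np.random.uniform(-1, 1, (5, 300))   # 5 random hyperplanes over 300-d vectors
query = np.random.rand(300)
print(bucket_index(query, planes))            # some bucket id in [0, 31]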
"""
    test
"""
import os
import gensim
import pickle
import time
import numpy as np

DIR_PATH = os.path.dirname(os.path.abspath(__file__))
HASHTABLES = os.path.join(DIR_PATH, 'resource', 'hashtables.pkl')
WORD2VEC = os.path.join(DIR_PATH, 'resource', 'sgns.weibo.word')
RESOURCES = os.path.join(DIR_PATH, 'resource', 'resources.pkl')



class MyClass(object):
    """Builds random-hyperplane LSH tables over word2vec vectors and queries them for similar words."""

    def __init__(self, Table_num=5, Hashcode_fun=5):
        self.hashtables = HASHTABLES      # pickle file holding the hash tables and random hyperplanes
        self.word2vec = WORD2VEC          # pre-trained word2vec file (text format)
        self.resources = RESOURCES        # pickle file holding (vectors, words)
        self.table_num = Table_num        # number of independent hash tables
        self.Hashcode_fun = Hashcode_fun  # number of hash bits (random hyperplanes) per table

    def load_traindata(self):
        """Load the pre-trained word vectors and serialize (vectors, words) to disk."""
        model = gensim.models.KeyedVectors.load_word2vec_format(self.word2vec, unicode_errors='ignore')
        data = []
        features = []

        # index2word is aligned with model.vectors, so word i maps to vector i
        for word, vector in zip(model.index2word, model.vectors):
            features.append(vector)
            data.append(word)
        self.features = np.array(features)
        self.data = data
        with open(self.resources, 'wb') as fw:
            pickle.dump((self.features, self.data), fw)
        print('Word vectors serialized; vocabulary size: {}'.format(len(self.data)))


    def create_hashtables(self):
        """Hash every vector into each table using random-hyperplane (sign) LSH."""
        with open(self.resources, 'rb') as fr:
            features, _ = pickle.load(fr)
        print('Features loaded; vocabulary size: {}'.format(len(features)))

        users_size, items_size = features.shape
        # one list of 2**Hashcode_fun buckets per hash table
        hashtables = [[[] for _ in range(2 ** self.Hashcode_fun)] for _ in range(self.table_num)]

        # one set of Hashcode_fun random hyperplanes per table
        random_matrixes = [np.random.uniform(-1, 1, (self.Hashcode_fun, items_size))
                           for _ in range(self.table_num)]

        for i, user_vec in enumerate(features):
            for j in range(self.table_num):
                v = random_matrixes[j]
                # the sign of each dot product gives one bit of the bucket index
                index = ''
                for k in range(self.Hashcode_fun):
                    index += '1' if np.dot(user_vec, v[k]) > 0 else '0'
                t_index = int(index, 2)
                hashtables[j][t_index].append(i)

        with open(self.hashtables, 'wb') as fw:
            pickle.dump((hashtables, random_matrixes), fw)
        print('Hash tables saved')


    def cal_similarity(self):
        """Query the LSH tables for a word and rank the candidate bucket members by cosine similarity."""
        with open(self.resources, 'rb') as fr:
            _, data = pickle.load(fr)

        with open(self.hashtables, 'rb') as fr:
            hashtables, random_matrixes = pickle.load(fr)

        model = gensim.models.KeyedVectors.load_word2vec_format(self.word2vec, unicode_errors='ignore')
        search_data = '中国'  # query word; other candidates: '莱雅', '真材实料', '触网', '@Sophia', '汕尾'

        search_feature_vec = np.array(model.get_vector(search_data))
        sim = model.most_similar(search_data)
        print('Similar words found by word2vec: {}'.format(sim))
        print('similarity({}, 莱雅) = {}'.format(search_data, model.similarity(search_data, '莱雅')))
        print('similarity({}, 触网) = {}'.format(search_data, model.similarity(search_data, '触网')))

        # collect candidates from the matching bucket of every hash table
        similar_users = set()
        t1 = time.time()
        for i, hashtable in enumerate(hashtables):
            index = ''
            for j in range(self.Hashcode_fun):
                index += '1' if np.dot(search_feature_vec, random_matrixes[i][j]) > 0 else '0'
            target_index = int(index, 2)
            similar_users |= set(hashtable[target_index])
        t2 = time.time()
        print('Candidate lookup took {:.4f}s'.format(t2 - t1))

        t3 = time.time()
        res = {}
        for i in similar_users:
            res[data[i]] = cosine_similarity2(search_feature_vec, model.get_vector(data[i]))
        a = sorted(res.items(), key=lambda x: x[1], reverse=True)
        t4 = time.time()
        print('Cosine similarity and sorting took {:.4f}s'.format(t4 - t3))
        print(a[:20])


def cosine_similarity(x, y):
    # naive element-wise implementation, kept for reference (unused)
    res = np.array([[x[i] * y[i], x[i] * x[i], y[i] * y[i]] for i in range(len(x))])
    cos = sum(res[:, 0]) / (np.sqrt(sum(res[:, 1])) * np.sqrt(sum(res[:, 2])))
    return cos


def cosine_similarity2(x, y):
    # vectorized cosine similarity
    num = x.dot(y.T)
    denom = np.linalg.norm(x) * np.linalg.norm(y)
    return num / denom


if __name__ == '__main__':
    ir = MyClass()
    # run these two once to build resources.pkl and hashtables.pkl
    # ir.load_traindata()
    # ir.create_hashtables()
    ir.cal_similarity()

With the hash tables in place, a set of similar candidates can be retrieved very quickly.
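On a first run the pickled resources do not exist yet, so the two preprocessing steps have to run before any query; a minimal driver (assuming resource/sgns.weibo.word is already in place) looks like:

ir = MyClass()
ir.load_traindata()      # parse sgns.weibo.word and pickle (vectors, words) to resources.pkl
ir.create_hashtables()   # build the LSH tables and pickle them to hashtables.pkl
ir.cal_similarity()      # hash the query word, gather bucket candidates, rank by cosine similarity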
