LDA模型的封装

醉酒当歌 提交于 2020-01-24 01:06:51

最近一直在训练LDA模型,将LDA模型封装在一个脚本中,可以直接在终端传入参数进行LDA的训练和预测.
需要在同目录下准备一个stopwords(停用词典)

# coding=utf-8
import codecs
import os
import re
from os import mkdir
from os.path import exists, isdir, abspath, join
import gensim
import jieba.posseg as pseg
import yaml
from gensim import corpora

#1.数据处理工具
# jieba POS tags treated as function words and filtered out when cutting
# (adverbs, prepositions, conjunctions, particles, numerals, pronouns, ...).
_FUNCTION_WORD_FLAGS = frozenset(["d", "p", "c", "u", "o", "e", "m", "q", "r", "t", "z"])

# Cache of stopword sets keyed by absolute file path, so the stopword file is
# read once per path instead of once per corpus line.
_STOPWORDS_CACHE = {}


def _load_stopwords(path):
    '''Read a stopword file (one word per line) into a set, caching per path.'''
    if path not in _STOPWORDS_CACHE:
        with open(path, 'r', encoding='utf8') as f:
            _STOPWORDS_CACHE[path] = {line.strip() for line in f}
    return _STOPWORDS_CACHE[path]


def data_util(data, cut_flag=False, stopwords_file='./stopwords'):
    '''
    Process one line of raw corpus text into a list of tokens.

    :param data: a single line of corpus text
    :param cut_flag: if True, segment with jieba POS tagging and drop
                     function words / stopwords; if False, split on whitespace
    :param stopwords_file: path to the stopword file (one word per line)
    :return: list of token strings (possibly empty)
    '''
    if not cut_flag:
        # Pre-tokenized corpus: tokens are whitespace-separated already.
        return data.strip().split()

    stopwords = _load_stopwords(abspath(stopwords_file))
    # Keep only CJK characters; everything else becomes a separator.
    text = re.sub(r'[^\u4e00-\u9fa5]+', ' ', data)
    tokens = []
    for word, flag in pseg.cut(text):
        # Drop function words, stopwords and whitespace-only tokens.
        if flag not in _FUNCTION_WORD_FLAGS and word not in stopwords and word.strip():
            tokens.append(word)
    return tokens

#2.读取数据函数
def get_train_data(file_name, cut_flag=False, max_docs=2000000):
    '''
    Load training data from a single file or from every file in a directory.

    :param file_name: path to a corpus file, or to a directory of corpus files
    :param cut_flag: passed through to data_util (segment with jieba or not)
    :param max_docs: stop loading once this many documents have been kept
    :return: list of token lists, one per non-empty input line
    :raises ValueError: if the path does not exist or the directory is empty
    '''
    if not exists(file_name):
        raise ValueError('文件不存在: {}'.format(file_name))

    if isdir(file_name):
        file_list = os.listdir(file_name)
        if not file_list:
            raise ValueError('文件夹为空: {}'.format(file_name))
        paths = [join(file_name, name) for name in file_list]
    else:
        paths = [file_name]

    print('#####训练数据读取中#####')
    ret = []
    for path in paths:
        with open(path, 'r', encoding='utf8') as f:
            for line in f:
                data = data_util(line, cut_flag)
                if not data:
                    continue
                ret.append(data)
                # Progress report every 10000 kept documents.
                if len(ret) % 10000 == 0:
                    print(f'已加载{len(ret)}条数据')
                # Hard cap to keep memory bounded on huge corpora.
                if len(ret) >= max_docs:
                    print(f'加载完成,共{len(ret)}条数据')
                    return ret
    print(f'加载完成,共{len(ret)}条数据')
    return ret


#3.参数类
class LDAArgs(dict):
    '''
    Parameter container: a dict whose entries are also readable/writable as
    attributes (``args.num_topics`` == ``args['num_topics']``).
    '''

    def __init__(self, params=None, *args, **kwargs):
        super(LDAArgs, self).__init__(*args, **kwargs)
        # Guard: the original unconditionally called update(params), which
        # raises TypeError when params is None (the documented default).
        if params:
            self.update(params)
        # Alias the attribute dict to the mapping itself so every key is
        # also an attribute and vice versa.
        self.__dict__ = self

    def save(self, args_file):
        '''Dump the parameters to a YAML file at args_file.'''
        to_dump = dict(self)
        # Store the corpus path as an absolute path so the saved config stays
        # valid when loaded from a different working directory. Inference-mode
        # configs have no 'documents' key, hence the guard.
        if 'documents' in to_dump:
            to_dump['documents'] = abspath(to_dump['documents'])
        with open(args_file, 'w', encoding='utf-8') as f:
            yaml.dump(to_dump, f, default_flow_style=False, allow_unicode=True)

    @staticmethod
    def load(args_file):
        '''Load parameters from a YAML file and return an LDAArgs object.'''
        with open(args_file, 'r', encoding='utf-8') as f:
            # safe_load: a plain config file must never trigger arbitrary
            # object construction (yaml.load without a Loader is deprecated).
            params = yaml.safe_load(f)
        return LDAArgs(params=params)



#4.训练模型
def train(args):
    '''
    Train an LDA model, save it, print the top words per topic, and persist
    the parameters used under ./config/.

    :param args: LDAArgs with documents, cut_flag, num_topics, iterations,
                 chunksize, eta, eval_every, model_name and topn set
    :return: None
    '''
    corpus = get_train_data(args.documents, args.cut_flag)
    # Build the token -> id dictionary over the full corpus.
    dictionary = corpora.Dictionary(corpus)

    class CorpusWrapper:
        '''Lazily converts token lists to bag-of-words vectors on iteration.'''
        def __init__(self, dictionary):
            self._dictionary = dictionary

        def __iter__(self):
            for tokens in corpus:
                yield self._dictionary.doc2bow(tokens)

    mm_corpus = CorpusWrapper(dictionary)

    print('##### start train #####')
    lda_model = gensim.models.LdaModel(mm_corpus,
                                       id2word=dictionary,
                                       num_topics=args.num_topics,
                                       iterations=args.iterations,
                                       chunksize=args.chunksize,
                                       eta=args.eta,
                                       eval_every=args.eval_every)
    print('##### saving model #####')
    lda_model.save(args.model_name)

    print('##### 输出主题词 #####')
    for topic_id in range(args.num_topics):
        # get_topic_terms yields (word_id, prob) pairs; map ids back to words.
        # (The original reused `i` for both the topic index and the term pair.)
        words = [lda_model.id2word[word_id]
                 for word_id, _ in lda_model.get_topic_terms(topic_id, args.topn)]
        print(f'topic:{topic_id}', words)

    # Persist the parameters used for this run (makedirs avoids the
    # exists-then-mkdir race of the original).
    os.makedirs('./config/', exist_ok=True)
    args.save(f'./config/config_topic{args.num_topics}.yml')

#5.预测函数
class TopicInferer:
    '''Loads a trained LDA model and predicts topics for new text.'''

    def __init__(self, args):
        '''
        :param args: LDAArgs with model_name, num_topics, topn, cut_flag,
                     test_data and output set
        :raises NameError: if the model file does not exist
                (NameError kept for backward compatibility with callers)
        '''
        self.args = args
        if not exists(args.model_name):
            raise NameError('模型不存在')
        self._ldamodel = gensim.models.LdaModel.load(args.model_name)

    def _init_words_per_topics(self):
        '''Return {topic_id: [top-n topic words]} for every topic.'''
        topic_word_dict = {}
        for topic_id in range(self.args.num_topics):
            topic_word_dict[topic_id] = [
                self._ldamodel.id2word[word_id]
                for word_id, _ in self._ldamodel.get_topic_terms(topic_id, self.args.topn)
            ]
        return topic_word_dict

    def _best_topic(self, text):
        '''Return the most probable topic id for text, or None if no topic.'''
        tokens = data_util(text, self.args.cut_flag)
        bow = self._ldamodel.id2word.doc2bow(tokens)
        topic_probs = self._ldamodel.get_document_topics(bow)
        if not topic_probs:
            return None
        # Highest-probability topic wins (single pass instead of a full sort).
        return max(topic_probs, key=lambda pair: pair[1])[0]

    def predict(self):
        '''
        Predict topics for args.test_data.

        If test_data is an existing file path, one prediction per line is
        appended to args.output ("<line>\\t<topic words>", or "Null" when the
        model yields no topic). Otherwise test_data is treated as a single
        string and the result is printed.
        '''
        data = self.args.test_data
        topic_word_dict = self._init_words_per_topics()
        if exists(data):
            # File mode. Both handles are context-managed: the original left
            # the output file open for the life of the process.
            with open(self.args.output, 'a', encoding='utf8') as out, \
                    open(data, 'r', encoding='utf8') as f:
                for raw in f:
                    line = raw.strip()
                    t_id = self._best_topic(line)
                    if t_id is not None:
                        out.write(line + '\t' + ' '.join(topic_word_dict[t_id]) + '\n')
                    else:
                        out.write('Null' + '\n')
        else:
            # Single-string mode.
            if not isinstance(data, str):
                raise ValueError('请输入字符串格式的数据')
            t_id = self._best_topic(data)
            if t_id is not None:
                print(t_id, topic_word_dict[t_id])
            else:
                print('Null')


#6.主函数
def main():
    '''Command-line entry point: train an LDA model or run inference.'''
    import argparse

    def str2bool(value):
        # argparse's type=bool treats EVERY non-empty string — including
        # "False" — as True; parse the common spellings explicitly instead.
        return str(value).lower() in ('1', 'true', 'yes', 'y', 't')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_topics', type=int, default=200, help='话题数 默认200')
    parser.add_argument('--cut_flag', type=str2bool, default=True, help='数据是否需要分词处理')
    parser.add_argument('--model_name', type=str, required=True, help='模型的路径')
    parser.add_argument('--data', type=str, help='训练数据的路径')
    parser.add_argument('--mode', type=str, default="infer", choices=("train", "infer"), help='预测模式或者训练模式')
    parser.add_argument('--topn', type=int, default=10, help='每个topic中输出的主题词的个数')
    parser.add_argument('--test_data', type=str, help='需要预测的数据,测试数据路径 or 单条数据')
    parser.add_argument('--output', type=str, default='./result.txt', help='预测的结果的存储路径')

    args = parser.parse_args()
    if args.mode == 'train':
        if not args.data:
            raise ValueError('请传入训练数据')
        print(args.data)
        _params = {
            "model_name": args.model_name,
            "num_topics": args.num_topics,
            "documents": args.data,
            'cut_flag': args.cut_flag,
            "eval_every": 10,
            "chunksize": 2000,
            "iterations": 10,
            "words_per_topic": 100,
            'eta': None,
            'topn': args.topn
        }
        train(LDAArgs(_params))
    elif args.mode == 'infer':
        if not args.test_data:
            raise ValueError('请传入测试数据')
        _params = {
            "model_name": args.model_name,
            "num_topics": args.num_topics,
            'cut_flag': args.cut_flag,
            'topn': args.topn,
            'test_data': args.test_data,
            'output': args.output
        }
        TopicInferer(LDAArgs(_params)).predict()
if __name__ == '__main__':
    main()
    # Training example:
    #python LDA.py --model_name=./model/LDA_topic10.model --mode=train --data=./train --num_topics=10
    # Inference example (file input):
    #python LDA.py --model_name=./model/LDA_topic10.model --num_topics=10 --test_data=./test.txt
    # or (single-string input):
    #python LDA.py --model_name=./model/LDA_topic2.model --num_topics=2 --test_data=做客孩子临走时带走几只玩具,我的孩子抗拒并一直哭,要怎么开导

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!