最近一直在训练LDA模型,将LDA模型封装在一个脚本中,可以直接在终端传入参数进行LDA的训练和预测.
需要在同目录下准备一个stopwords(停用词典)
# coding=utf-8
import codecs
import os
import re
from os import mkdir
from os.path import exists, isdir, abspath, join
import gensim
import jieba.posseg as pseg
import yaml
from gensim import corpora
# 1. Data-processing utility
def data_util(data, cut_flag=False, stopwords_file='./stopwords'):
    """Convert one line of raw corpus text into a list of tokens.

    :param data: a single line of corpus text
    :param cut_flag: when True, segment with jieba and drop function words
                     and stopwords; otherwise just split on whitespace
    :param stopwords_file: path to the stopword file (one word per line)
    :return: list of tokens (possibly empty)
    """
    if cut_flag:
        # Load the stopword set once per file path instead of re-reading
        # the file for every corpus line (the original re-read it each call).
        stopwords = _load_stopwords(abspath(stopwords_file))
        # POS tags to discard: adverbs, prepositions, conjunctions, particles, ...
        function_words = {"d", "p", "c", "u", "o", "e", "m", "q", "r", "t", "z"}
        # Keep only CJK characters; everything else becomes a separator.
        cleaned = re.sub(r'[^\u4e00-\u9fa5]+', ' ', data)
        return [word for word, flag in pseg.cut(cleaned)
                if flag not in function_words and word not in stopwords and word.strip()]
    # No segmentation requested: plain whitespace split.
    return data.strip().split()


# Cache: stopword file path -> frozenset of stopwords. A frozenset also makes
# the per-word membership test O(1) instead of O(n) over a list.
_STOPWORDS_CACHE = {}


def _load_stopwords(path):
    """Read the stopword file once and cache the result as a frozenset."""
    if path not in _STOPWORDS_CACHE:
        with open(path, 'r', encoding='utf8') as f:
            _STOPWORDS_CACHE[path] = frozenset(line.strip() for line in f)
    return _STOPWORDS_CACHE[path]
# 2. Data loading
def get_train_data(file_name, cut_flag=False):
    """Load training data from a single file or from every file in a directory.

    :param file_name: path to a text file, or to a directory of text files
    :param cut_flag: forwarded to data_util (True => jieba segmentation)
    :return: list of token lists, capped at MAX_DOCS documents
    :raises ValueError: when the path does not exist or the directory is empty
    """
    MAX_DOCS = 2000000  # hard cap on the number of loaded documents

    if not exists(file_name):
        raise ValueError('文件不存在: {}'.format(file_name))

    if isdir(file_name):
        file_list = os.listdir(file_name)
        if not file_list:
            raise ValueError('文件夹为空: {}'.format(file_name))
        sources = [join(file_name, name) for name in file_list]
    else:
        sources = [file_name]

    print('#####训练数据读取中#####')
    ret = []
    n = 0  # number of non-empty documents actually kept
    for path in sources:
        with open(path, 'r', encoding='utf8') as f:
            for line in f:
                data = data_util(line, cut_flag)
                if data:
                    ret.append(data)
                    n += 1
                    if n % 10000 == 0:
                        print(f'已加载{n}条数据')
                    if n >= MAX_DOCS:
                        break
        # The original `break` only left the inner loop, so the cap never
        # applied across multiple files; enforce it here as well.
        if n >= MAX_DOCS:
            break
    print(f'加载完成,共{n}条数据')
    return ret
# 3. Parameter class
class LDAArgs(dict):
    """Dict subclass whose keys are also readable as attributes
    (args.num_topics == args['num_topics']); carries LDA parameters."""

    def __init__(self, params=None, *args, **kwargs):
        super(LDAArgs, self).__init__(*args, **kwargs)
        # Guard: the original crashed with dict.update(None) when no
        # params were supplied.
        if params:
            self.update(params)
        # Alias the attribute dict to the mapping itself so keys are
        # accessible (and assignable) as attributes.
        self.__dict__ = self

    def save(self, args_file):
        """Serialize the parameters to *args_file* as YAML."""
        to_dump_dict = dict(self.__dict__)
        # Store the training-data path as an absolute path; infer-mode
        # parameter sets have no 'documents' key, so guard the lookup.
        if 'documents' in to_dump_dict:
            to_dump_dict['documents'] = abspath(to_dump_dict['documents'])
        with open(args_file, 'w', encoding='utf-8') as f:
            yaml.dump(to_dump_dict, f, default_flow_style=False)

    @staticmethod
    def load(args_file):
        """Load parameters from a YAML file and return an LDAArgs object."""
        with open(args_file, 'r', encoding='utf-8') as f:
            # safe_load: never executes arbitrary YAML tags; plain
            # yaml.load without a Loader is deprecated since PyYAML 5.1.
            params = yaml.safe_load(f)
        return LDAArgs(params=params)
# 4. Model training
def train(args):
    """Train an LDA model, save it, print topic words, persist the config.

    :param args: LDAArgs carrying documents, cut_flag, num_topics,
                 iterations, chunksize, eta, eval_every, model_name, topn
    :return: None
    """
    corpus = get_train_data(args.documents, args.cut_flag)
    # Build the token dictionary from the whole corpus.
    dictionary = corpora.Dictionary(corpus)

    class CorpusWrapper:
        """Lazily yields the bag-of-words form of each document, so the
        BoW corpus is never materialized in memory."""
        def __init__(self, dictionary):
            self._dictionary = dictionary

        def __iter__(self):
            for tokens in corpus:
                yield self._dictionary.doc2bow(tokens)

    mm_corpus = CorpusWrapper(dictionary)
    print('##### start train #####')
    lda_model = gensim.models.LdaModel(mm_corpus,
                                       id2word=dictionary,
                                       num_topics=args.num_topics,
                                       iterations=args.iterations,
                                       chunksize=args.chunksize,
                                       eta=args.eta,
                                       eval_every=args.eval_every)
    print('##### saving model #####')
    lda_model.save(args.model_name)
    print('##### 输出主题词 #####')
    for topic_id in range(args.num_topics):
        # The original shadowed the loop variable `i` inside the
        # comprehension; distinct names make the intent explicit.
        words = [lda_model.id2word[term_id]
                 for term_id, _ in lda_model.get_topic_terms(topic_id, args.topn)]
        print(f'topic:{topic_id}', words)
    # Persist the parameters used for this run; makedirs(exist_ok=True)
    # avoids the exists()+mkdir() race of the original.
    os.makedirs('./config/', exist_ok=True)
    args.save(f'./config/config_topic{args.num_topics}.yml')
# 5. Inference
class TopicInferer:
    """Loads a trained LDA model and assigns topics to new text."""

    def __init__(self, args):
        """
        :param args: LDAArgs with model_name, num_topics, topn, cut_flag,
                     test_data and output
        :raises NameError: when the model file does not exist. (Kept for
                 backward compatibility with existing callers;
                 FileNotFoundError would be the conventional type.)
        """
        self.args = args
        if exists(args.model_name):
            # Load the trained LDA model from disk.
            self._ldamodel = gensim.models.LdaModel.load(args.model_name)
        else:
            raise NameError('模型不存在')

    def _init_words_per_topics(self):
        """Return {topic_id: top-n topic words} for every topic."""
        topic_word_dict = {}
        for topic_id in range(self.args.num_topics):
            topic_word_dict[topic_id] = [
                self._ldamodel.id2word[term_id]
                for term_id, _ in self._ldamodel.get_topic_terms(topic_id, self.args.topn)
            ]
        return topic_word_dict

    def _best_topic(self, text):
        """Return the most probable topic id for *text*, or None when the
        model assigns no topic (e.g. all tokens are out of vocabulary)."""
        tokens = data_util(text, self.args.cut_flag)
        bow = self._ldamodel.id2word.doc2bow(tokens)
        topic_ids = self._ldamodel.get_document_topics(bow)
        if not topic_ids:
            return None
        # Equivalent to sorting by probability descending and taking [0].
        return max(topic_ids, key=lambda pair: pair[1])[0]

    def predict(self):
        """Predict topics for args.test_data.

        If test_data names an existing file, predict line by line and append
        the results to args.output; otherwise treat test_data as one raw
        string and print the result.
        """
        data = self.args.test_data
        topic_word_dict = self._init_words_per_topics()
        if exists(data):
            # Batch mode. `with` closes both handles even on error —
            # the original opened the output file and never closed it.
            with open(self.args.output, 'a', encoding='utf8') as out, \
                 open(data, 'r', encoding='utf8') as f:
                for raw in f:
                    line = raw.strip()
                    t_id = self._best_topic(line)
                    if t_id is not None:
                        out.write(line + '\t' + ' '.join(topic_word_dict[t_id]) + '\n')
                    else:
                        out.write('Null' + '\n')
        else:
            # Single-string mode.
            if not isinstance(data, str):
                raise ValueError('请输入字符串格式的数据')
            t_id = self._best_topic(data)
            if t_id is not None:
                print(t_id, topic_word_dict[t_id])
            else:
                print('Null')
# 6. Entry point
def main():
    """Parse command-line arguments and dispatch to training or inference."""
    import argparse

    def str2bool(value):
        # argparse's type=bool is broken for CLI use: bool('False') is True
        # because any non-empty string is truthy. Parse common spellings.
        return str(value).strip().lower() not in ('false', '0', 'no', 'n', '')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_topics', type=int, default=200, help='话题数 默认200')
    parser.add_argument('--cut_flag', type=str2bool, default=True, help='数据是否需要分词处理')
    parser.add_argument('--model_name', type=str, required=True, help='模型的路径')
    parser.add_argument('--data', type=str, help='训练数据的路径')
    parser.add_argument('--mode', type=str, default="infer", choices=("train", "infer"),
                        help='预测模式或者训练模式')
    parser.add_argument('--topn', type=int, default=10, help='每个topic中输出的主题词的个数')
    parser.add_argument('--test_data', type=str, help='需要预测的数据,测试数据路径 or 单条数据')
    parser.add_argument('--output', type=str, default='./result.txt', help='预测的结果的存储路径')
    args = parser.parse_args()

    if args.mode == 'train':
        if not args.data:
            raise ValueError('请传入训练数据')
        print(args.data)
        _params = {
            "model_name": args.model_name,
            "num_topics": args.num_topics,
            "documents": args.data,
            'cut_flag': args.cut_flag,
            "eval_every": 10,
            "chunksize": 2000,
            "iterations": 10,
            "words_per_topic": 100,
            'eta': None,
            'topn': args.topn,
        }
        train(LDAArgs(_params))
    elif args.mode == 'infer':
        if not args.test_data:
            raise ValueError('请传入测试数据')
        _params = {
            "model_name": args.model_name,
            "num_topics": args.num_topics,
            'cut_flag': args.cut_flag,
            'topn': args.topn,
            'test_data': args.test_data,
            'output': args.output,
        }
        TopicInferer(LDAArgs(_params)).predict()


if __name__ == '__main__':
    main()
#训练示例
#python LDA.py --model_name=./model/LDA_topic10.model --mode=train --data=./train --num_topics=10
#预测示例
#python LDA.py --model_name=./model/LDA_topic10.model --num_topics=10 --test_data=./test.txt
# 或者
#python LDA.py --model_name=./model/LDA_topic2.model --num_topics=2 --test_data=做客孩子临走时带走几只玩具,我的孩子抗拒并一直哭,要怎么开导
来源:CSDN
作者:hylalalala
链接:https://blog.csdn.net/hylalalala/article/details/103914152