文本分类模型
一、fastText
https://fasttext.cc/docs/en/unsupervised-tutorial.html
fastText模型架构:
其中x1,x2,…,xN−1,xN表示一个文本中的n-gram向量,每个特征是词向量的平均值。这和前文中提到的cbow相似,cbow用上下文去预测中心词,而此处用全部的n-gram去预测指定类别
代码如下,只能在linux环境运行:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# -*- coding:utf-8 -*-
import pandas as pd
import random
import fasttext
import jieba
from sklearn.model_selection import train_test_split
import os
"""
函数说明:加载数据
"""
def loadData():
#利用pandas把数据读进来
df_military = pd.read_csv("./data/junshi.csv",encoding ="utf-8")
df_military=df_military.dropna()
df_sports = pd.read_csv("./data/sports.csv",encoding ="utf-8")
df_sports=df_sports.dropna()
military=df_military.values.tolist()[:20000]
sports=df_sports.values.tolist()[:20000]
return military,sports
"""
函数说明:停用词
"""
def getStopWords(datapath):
stopwords=pd.read_csv(datapath,index_col=False,quoting=3,sep="\t",names=['stopword'], encoding='utf-8')
stopwords=stopwords["stopword"].values
return stopwords
"""
函数说明:数据准备
"""
def preprocess_text(content_line,sentences,category,stopwords):
for line in content_line:
try:
segs=jieba.lcut(str(line)) #利用结巴分词进行中文分词
segs=filter(lambda x:len(x)>1,segs) #去掉长度小于1的词
segs=filter(lambda x:x not in stopwords,segs) #去掉停用词
sentences.append("__label__"+str(category)+" , "+" ".join(segs)) #把当前的文本和对应的类别拼接起来,组合成fasttext的文本格式
except Exception as e:
print (line)
continue
"""
函数说明:把处理好的写入到文件中,备用
"""
def writeData(sentences,fileName):
print("writing data to fasttext format...")
out=open(fileName,'w',encoding ="utf-8")
for sentence in sentences:
out.write(sentence+"\n")
print("done!")
"""
函数说明:数据处理
"""
def preprocessData(stopwords,saveDataFile):
military,sports=loadData()
# 去停用词,生成数据集
sentences=[]
# preprocess_text(technology,sentences,cate_dic["technology"],stopwords)
# preprocess_text(car,sentences,cate_dic["car"],stopwords)
preprocess_text(military,sentences,cate_dic["military"],stopwords)
preprocess_text(sports,sentences,cate_dic["sports"],stopwords)
random.shuffle(sentences) #做乱序处理,使得同类别的样本不至于扎堆
writeData(sentences,saveDataFile)
if __name__=="__main__":
# 类别标签
cate_dic = {'military': 1, 'sports': 2}
# 数据处理
stopwordsFile = r"./data/stopwords.txt"
stopwords = getStopWords(stopwordsFile)
saveDataFile = r'train_data.txt'
preprocessData(stopwords,saveDataFile)
# 训练模型
# fasttext.supervised():有监督的学习
filename = 'classifier.model.bin'
# 判断是否已经存在模型,如果存在则加载,不存在则进行训练
if (os.path.exists(filename)):
classifier = fasttext.load_model(filename, label_prefix='__label__')
else:
# 训练模型
classifier = fasttext.supervised(saveDataFile, 'classifier.model', label_prefix='__label__')
# classifier=fasttext.supervised(saveDataFile,'classifier.model',label_prefix='__label__')
result = classifier.test(saveDataFile)
print("P@1:",result.precision) # 准确率
print("R@2:",result.recall) # 召回率
print("Number of examples:",result.nexamples) # 预测错的例子
# 实际预测
label_to_cate={1:'military',2:'sports'}
texts=['中新网 日电 2018 预赛 亚洲区 强赛 中国队 韩国队 较量 比赛 上半场 分钟 主场 作战 中国队 率先 打破 场上 僵局 利用 角球 机会 大宝 前点 攻门 得手 中国队 领先']
labels = classifier.predict(texts)
print(labels)
print(label_to_cate[int(labels[0][0])])
# 可以得到类别+概率
labels=classifier.predict_proba(texts)
print(labels)
# 可以得到前k个类别
labels=classifier.predict(texts,k=3)
print(labels)
# 可以得到前k个类别+概率
labels=classifier.predict_proba(texts,k=3)
print(labels)
二、BERT文本分类
https://github.com/xmxoxo/BERT-train2deploy
export BERT_BASE_DIR=/home/chenz/BERT/chinese_L-12_H-768_A-12
export GLUE_DIR=/home/chenz/BERT-train2deploy/dat
export TRAINED_CLASSIFIER=/home/chenz/BERT-train2deploy/output
# 训练
python3 run_mobile.py \
--task_name=setiment \
--do_train=true \
--do_eval=true \
--data_dir=$GLUE_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=32 \
--train_batch_size=16 \
--learning_rate=2e-5 \
--num_train_epochs=5.0 \
--output_dir=$TRAINED_CLASSIFIER
# 测试
python3 run_mobile.py \
--task_name=setiment \
--do_predict=true \
--data_dir=$GLUE_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=32 \
--output_dir=$TRAINED_CLASSIFIER
model_dir 就是训练好的.ckpt文件所在的目录
max_seq_len 要与原来一致;
num_labels 是分类标签的个数,本例中是3个
# 模型文件转化
python3 freeze_graph.py \
-bert_model_dir $BERT_BASE_DIR \
-model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-max_seq_len 32 \
-num_labels 18
# 启动bert
bert-base-serving-start \
-model_dir $TRAINED_CLASSIFIER \
-bert_model_dir $BERT_BASE_DIR \
-model_pb_dir $TRAINED_CLASSIFIER \
-mode CLASS \
-max_seq_len 32 \
-http_port 8091 \
-port 65331 \
-port_out 65332 \
-device_map 1
# 调用
from bert_base.client import BertClient
classifier = BertClient("192.168.1.158", 65331, 65332, show_server_config=False, check_version=False, check_length=False, mode='CLASS')
classifier.encode(["农村人的歌,我家在农村,非常喜欢的一首歌,这次我把一整首都唱完了,你也听听看"])
当然BERT分类效果更好。
来源:CSDN
作者:卓玛cug
链接:https://blog.csdn.net/qq_29153321/article/details/104022840