文本分类模型的几种方法介绍及比较

若如初见. 提交于 2020-01-18 00:08:33

文本分类模型

一、fastText
https://fasttext.cc/docs/en/unsupervised-tutorial.html
fastText模型架构:
其中x1,x2,…,xN−1,xN表示一个文本中的n-gram向量,每个特征是词向量的平均值。这和前文中提到的cbow相似,cbow用上下文去预测中心词,而此处用全部的n-gram去预测指定类别
在这里插入图片描述
代码如下,只能在linux环境运行:

#!/usr/bin/python
# -*- coding: UTF-8 -*-

# -*- coding:utf-8 -*-
import pandas as pd
import random
import fasttext
import jieba
from sklearn.model_selection import train_test_split
import os


"""
函数说明:加载数据
"""
def loadData():

    #利用pandas把数据读进来
    df_military = pd.read_csv("./data/junshi.csv",encoding ="utf-8")
    df_military=df_military.dropna()

    df_sports = pd.read_csv("./data/sports.csv",encoding ="utf-8")
    df_sports=df_sports.dropna()

    military=df_military.values.tolist()[:20000]
    sports=df_sports.values.tolist()[:20000]

    return military,sports

"""
函数说明:停用词
"""
def getStopWords(datapath):
    stopwords=pd.read_csv(datapath,index_col=False,quoting=3,sep="\t",names=['stopword'], encoding='utf-8')
    stopwords=stopwords["stopword"].values
    return stopwords

"""
函数说明:数据准备
"""
def preprocess_text(content_line,sentences,category,stopwords):
    for line in content_line:
        try:
            segs=jieba.lcut(str(line))    #利用结巴分词进行中文分词
            segs=filter(lambda x:len(x)>1,segs)    #去掉长度小于1的词
            segs=filter(lambda x:x not in stopwords,segs)    #去掉停用词
            sentences.append("__label__"+str(category)+" , "+" ".join(segs))    #把当前的文本和对应的类别拼接起来,组合成fasttext的文本格式
        except Exception as e:
            print (line)
            continue

"""
函数说明:把处理好的写入到文件中,备用
"""
def writeData(sentences,fileName):
    print("writing data to fasttext format...")
    out=open(fileName,'w',encoding ="utf-8")
    for sentence in sentences:
        out.write(sentence+"\n")
    print("done!")

"""
函数说明:数据处理
"""
def preprocessData(stopwords,saveDataFile):
    military,sports=loadData()

    # 去停用词,生成数据集
    sentences=[]
    # preprocess_text(technology,sentences,cate_dic["technology"],stopwords)
    # preprocess_text(car,sentences,cate_dic["car"],stopwords)
    preprocess_text(military,sentences,cate_dic["military"],stopwords)
    preprocess_text(sports,sentences,cate_dic["sports"],stopwords)

    random.shuffle(sentences)    #做乱序处理,使得同类别的样本不至于扎堆

    writeData(sentences,saveDataFile)

if __name__=="__main__":
    # 类别标签
    cate_dic = {'military': 1, 'sports': 2}

    # 数据处理
    stopwordsFile = r"./data/stopwords.txt"
    stopwords = getStopWords(stopwordsFile)
    saveDataFile = r'train_data.txt'
    preprocessData(stopwords,saveDataFile)

    # 训练模型
    # fasttext.supervised():有监督的学习
    filename = 'classifier.model.bin'
    # 判断是否已经存在模型,如果存在则加载,不存在则进行训练
    if (os.path.exists(filename)):
        classifier = fasttext.load_model(filename, label_prefix='__label__')
    else:
        # 训练模型
        classifier = fasttext.supervised(saveDataFile, 'classifier.model', label_prefix='__label__')
    # classifier=fasttext.supervised(saveDataFile,'classifier.model',label_prefix='__label__')
    result = classifier.test(saveDataFile)
    print("P@1:",result.precision)    # 准确率
    print("R@2:",result.recall)    # 召回率
    print("Number of examples:",result.nexamples)    # 预测错的例子

    # 实际预测
    label_to_cate={1:'military',2:'sports'}

    texts=['中新网 日电 2018 预赛 亚洲区 强赛 中国队 韩国队 较量 比赛 上半场 分钟 主场 作战 中国队 率先 打破 场上 僵局 利用 角球 机会 大宝 前点 攻门 得手 中国队 领先']
    labels = classifier.predict(texts)
    print(labels)
    print(label_to_cate[int(labels[0][0])])

    # 可以得到类别+概率
    labels=classifier.predict_proba(texts)
    print(labels)

    # 可以得到前k个类别
    labels=classifier.predict(texts,k=3)
    print(labels)

    # 可以得到前k个类别+概率
    labels=classifier.predict_proba(texts,k=3)
    print(labels)

二、BERT文本分类
https://github.com/xmxoxo/BERT-train2deploy

export BERT_BASE_DIR=/home/chenz/BERT/chinese_L-12_H-768_A-12
export GLUE_DIR=/home/chenz/BERT-train2deploy/dat
export TRAINED_CLASSIFIER=/home/chenz/BERT-train2deploy/output

# 训练
python3 run_mobile.py \
  --task_name=setiment \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=32 \
  --train_batch_size=16 \
  --learning_rate=2e-5 \
  --num_train_epochs=5.0 \
  --output_dir=$TRAINED_CLASSIFIER
  
# 测试
python3 run_mobile.py \
  --task_name=setiment \
  --do_predict=true \
  --data_dir=$GLUE_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=32 \
  --output_dir=$TRAINED_CLASSIFIER


model_dir 就是训练好的.ckpt文件所在的目录
max_seq_len 要与原来一致;
num_labels 是分类标签的个数,本例中是3个
# 模型文件转化
python3 freeze_graph.py \
    -bert_model_dir $BERT_BASE_DIR \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -max_seq_len 32 \
    -num_labels 18


# 启动bert
bert-base-serving-start \
     -model_dir $TRAINED_CLASSIFIER \
     -bert_model_dir $BERT_BASE_DIR \
     -model_pb_dir $TRAINED_CLASSIFIER \
     -mode CLASS \
     -max_seq_len 32 \
     -http_port 8091 \
     -port 65331 \
    -port_out 65332 \
     -device_map 1


# 调用
from bert_base.client import BertClient
classifier = BertClient("192.168.1.158", 65331, 65332, show_server_config=False, check_version=False, check_length=False, mode='CLASS')
classifier.encode(["农村人的歌,我家在农村,非常喜欢的一首歌,这次我把一整首都唱完了,你也听听看"])

当然BERT分类效果更好。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!