fastText 另外两种安装方式
conda install 方式:速度慢
https://anaconda.org/conda-forge/fasttext
Windows 下可以下载 whl 文件后用 pip 安装(例如 fasttext-0.9.1-cp36-cp36m-win32.whl),下载地址:
https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext
fastText 训练
import fastText
import fastText
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix,precision_recall_fscore_support
# Train a supervised fastText text classifier.
#
# Data format of dtrain.txt / dtest.txt: one sample per line, the label first
# (with the "__label__" prefix), followed by space-separated segmented words:
#   __label__2 中新网 日电 日前 上海 国际
#   __label__0 两人 被捕 警方 指控 非法
#   __label__3 中旬 航渡 过程 美军 第一
#   __label__1 强强 联手 背后 品牌 用户 双赢
model = fastText.train_supervised(
    '../data/dtrain.txt',
    lr=0.1,
    dim=200,            # size of the word vectors
    epoch=50,
    neg=5,
    wordNgrams=2,       # use word bigrams as extra features
    label="__label__",  # prefix that marks label tokens in the input file
)

# Evaluate on the held-out test file. model.test returns a tuple
# (number of samples, precision@1, recall@1). Per-sample predictions
# (y_pred) are only built later, so report these aggregate numbers here.
result = model.test('../data/dtest.txt')
n_samples, precision_at_1, recall_at_1 = result
print('test samples = ', n_samples)
print('precision@1 = ', precision_at_1)
print('recall@1 = ', recall_at_1)

# Save the trained model. NOTE: save_model writes fastText's own binary
# format (not a pickle) despite the .pkl extension; the same path is passed
# to fastText.load_model later in this file.
model_path = '../model/fastText_model.pkl'
model.save_model(model_path)
# 计算分类的metrics
#绘制precision、recall、f1-score、support报告表
def eval_model(y_true, y_pred, labels):
    """Build a per-class precision / recall / F1 / support report.

    Args:
        y_true: ground-truth labels, one per sample.
        y_pred: predicted labels, same length as ``y_true``.
        labels: label values to report, in display order. Row ``i`` of the
            result holds the metrics for ``labels[i]``. Any iterable is
            accepted (e.g. ``dict.keys()``); it is converted to a list.

    Returns:
        A DataFrame with columns Label, Precision, Recall, F1, Support:
        one row per entry of ``labels``, plus a final support-weighted
        overall row labelled '总体' with index 999.
    """
    labels = list(labels)
    # Per-class Precision, Recall, F1 and support. Pass ``labels`` explicitly
    # so the metric arrays follow the same order as the Label column; without
    # it sklearn orders them by sorted label value, which mislabels the rows.
    p, r, f1, s = precision_recall_fscore_support(y_true, y_pred, labels=labels)
    # Overall scores: per-class metrics averaged with support as the weights.
    tot_p = np.average(p, weights=s)
    tot_r = np.average(r, weights=s)
    tot_f1 = np.average(f1, weights=s)
    tot_s = np.sum(s)
    res1 = pd.DataFrame({
        u'Label': labels,
        u'Precision': p,
        u'Recall': r,
        u'F1': f1,
        u'Support': s
    })
    res2 = pd.DataFrame({
        u'Label': ['总体'],
        u'Precision': [tot_p],
        u'Recall': [tot_r],
        u'F1': [tot_f1],
        u'Support': [tot_s]
    })
    # Give the summary row a distinctive index so it stands out after concat.
    res2.index = [999]
    res = pd.concat([res1, res2])
    return res[['Label', 'Precision', 'Recall', 'F1', 'Support']]
# Category name -> numeric id used in the fastText training labels.
cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
# Reverse mapping: fastText label string (e.g. '__label__0') -> category name.
dict_cate = dict(('__label__{}'.format(v), k) for k, v in cate_dic.items())

y_true = []
texts = []
with open('../data/dtest.txt', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. a trailing empty line at end of file).
            continue
        # Each line is "<label> <space-separated tokens>"; partition keeps
        # the rest of the line exactly as written.
        label, _, text = line.partition(" ")
        y_true.append(dict_cate[label])
        texts.append(text)

# Predict the whole test set in one call with `model`, the classifier
# trained above (`clf` is not defined until the online-prediction section).
# predict(list) returns (labels per text, probabilities per text).
predicted_labels = model.predict(texts)[0]
y_pred = [dict_cate[top[0]] for top in predicted_labels]

print("y_true = ", y_true[:5])
print("y_pred = ", y_pred[:5])
print('y_true length = ', len(y_true))
print('y_pred length = ', len(y_pred))
print('keys = ', list(cate_dic.keys()))
y_true = ['sports', 'car', 'car', 'technology', 'entertainment']
y_pred = ['sports', 'car', 'car', 'technology', 'entertainment']
y_true length = 87581
y_pred length = 87581
keys = ['entertainment', 'technology', 'sports', 'military', 'car']
# Per-class report, rows in the category order declared in cate_dic.
eval_model(y_true, y_pred, labels=list(cate_dic))
|   | Label | Precision | Recall | F1 | Support |
|---|---|---|---|---|---|
| 0 | entertainment | 0.934803 | 0.827857 | 0.878086 | 8400 |
| 1 | technology | 0.906027 | 0.923472 | 0.914666 | 26696 |
| 2 | sports | 0.881885 | 0.911727 | 0.896558 | 11555 |
| 3 | military | 0.943886 | 0.931749 | 0.937778 | 22476 |
| 4 | car | 0.857226 | 0.873252 | 0.865165 | 18454 |
| 999 | 总体 | 0.905035 | 0.904294 | 0.904270 | 87581 |

注:上表是在未向 precision_recall_fscore_support 传入 labels 参数时生成的。此时 sklearn 按排序后的标签顺序(car、entertainment、military、sports、technology)输出各行指标,因此 Label 列与各行指标是错位的;“总体”一行不受影响。
模拟在线预测
# Load the classifier saved by the training step.
model_path = '../model/fastText_model.pkl'
clf = fastText.load_model(model_path)

# Category name -> numeric id used in the fastText training labels.
cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
print(cate_dic)

# Reverse mapping: '__label__<id>' -> category name.
dict_cate = {'__label__{}'.format(v): k for k, v in cate_dic.items()}
print(dict_cate)
{'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
{'__label__0': 'entertainment', '__label__1': 'technology', '__label__2': 'sports', '__label__3': 'military', '__label__4': 'car'}
- 预测案例1-汽车类
摘自今日头条: https://www.toutiao.com/a6714271125473346055/
import jieba

text = "奥迪A3、宝马1系和奔驰A级一直纠缠不休的三个冤家"
# jieba.lcut already returns a list of tokens; no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# fastText expects a single line of space-separated tokens.
data = " ".join(words)
# predict: returns (labels per input, probabilities per input);
# take the top label of the single input.
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", dict_cate[y_pred])
words = ['奥迪', 'A3', '、', '宝马', '1', '系', '和', '奔驰', 'A', '级', '一直', '纠缠', '不休', '的', '三个', '冤家']
y_pred results = car
- 预测案例2-军事类
摘自今日头条新闻: https://www.toutiao.com/a6714188329937535496/
import jieba

text = "谁说文物只能躺在博物馆,想买一架梦想中的战斗机开着兜风吗?"
# jieba.lcut already returns a list of tokens; no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# fastText expects a single line of space-separated tokens.
data = " ".join(words)
# predict: returns (labels per input, probabilities per input);
# take the top label of the single input.
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", dict_cate[y_pred])
words = ['谁', '说', '文物', '只能', '躺', '在', '博物馆', ',', '想', '买', '一架', '梦想', '中', '的', '战斗机', '开着', '兜风', '吗', '?']
y_pred results = military
- 预测案例3-娱乐类
标题摘自今日头条(复制文章标题进行预测):https://www.toutiao.com/a6689675139333751299/
import jieba

text = "陈晓旭:从完美林黛玉到身家过亿后剃度出家,她戏里戏外都是传奇"
# jieba.lcut already returns a list of tokens; no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# fastText expects a single line of space-separated tokens.
data = " ".join(words)
# Top label of the single input: results[0] holds labels per input.
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", dict_cate[y_pred])
words = ['陈晓旭', ':', '从', '完美', '林黛玉', '到', '身家', '过', '亿后', '剃度', '出家', ',', '她', '戏里', '戏外', '都', '是', '传奇']
y_pred results = entertainment
- 预测案例4-体育类
摘自今日头条:https://www.toutiao.com/a6714266792253981192/
import jieba

text = "男女有别!国乒主力参加马来西亚T2联赛 男队站着吃自助女队吃桌餐"
# jieba.lcut already returns a list of tokens; no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# fastText expects a single line of space-separated tokens.
data = " ".join(words)
# Top label of the single input: results[0] holds labels per input.
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", dict_cate[y_pred])
words = ['男女有别', '!', '国乒', '主力', '参加', '马来西亚', 'T2', '联赛', ' ', '男队', '站', '着', '吃', '自助', '女队', '吃', '桌餐']
y_pred results = sports
- 预测案例5-科技类
import jieba

text = "摩托罗拉One Macro将是最新一款Android One智能手机"
# jieba.lcut already returns a list of tokens; no extra copy is needed.
words = jieba.lcut(text)
print('words = ', words)
# fastText expects a single line of space-separated tokens.
data = " ".join(words)
# Top label of the single input: results[0] holds labels per input.
results = clf.predict([data])
y_pred = results[0][0][0]
print("y_pred results = ", dict_cate[y_pred])
words = ['摩托罗拉', 'One', ' ', 'Macro', '将', '是', '最新', '一款', 'Android', ' ', 'One', '智能手机']
y_pred results = technology
Flask Web 服务在线预测
http://127.0.0.1:5000/v1/p?q=xxxxx
其中 q 是要预测的文本:GET 请求通过查询参数传入,POST 请求通过表单字段 q 传入。
# -*- coding: UTF-8 -*-
import jieba
import fastText
from flask import Flask
from flask import request
app = Flask(__name__)

# Load the classifier saved by the training step once, at import time.
model_path = '../model/fastText_model.pkl'
clf = fastText.load_model(model_path)

# Category name -> numeric id used in the fastText training labels.
cate_dic = {'entertainment': 0, 'technology': 1, 'sports': 2, 'military': 3, 'car': 4}
# Reverse mapping: '__label__<id>' -> category name, used to turn
# predictions into readable category names.
dict_cate = {'__label__{}'.format(v): k for k, v in cate_dic.items()}
print(dict_cate)
@app.route('/')
def hello_world():
    """Root endpoint: respond with a fixed greeting."""
    greeting = 'Hello World!'
    return greeting
@app.route('/v1/p', methods=['POST', 'GET'])
def predict():
    """Classify the text in parameter ``q`` and return its category name.

    ``q`` is read from the form body for POST and from the query string for
    GET (e.g. ``/v1/p?q=...``). The response body is the predicted category
    name (e.g. ``'car'``), or a 400 response when ``q`` is missing or yields
    no prediction.
    """
    if request.method == 'POST':
        # .get() so a missing field takes the same path as an empty one
        # (matching the GET branch) instead of raising KeyError.
        q = request.form.get('q', '')
    else:
        q = request.args.get('q', '')
    print('input data:', q)

    # Drop whitespace-only tokens. fastText's predict raises ValueError on
    # text containing '\n', and it splits on whitespace anyway, so bare
    # whitespace tokens from jieba carry no information.
    words = [word for word in jieba.lcut(q) if word.strip()]
    print('words = ', words)
    if not words:
        return 'missing or empty parameter q', 400

    data = " ".join(words)
    # predict: returns (labels per input, probabilities per input).
    results = clf.predict([data])
    top_labels = results[0][0]
    if not top_labels:
        # None of the tokens were usable by the model, so no label came back.
        return 'no prediction for input', 400
    return dict_cate[top_labels[0]]
if __name__ == '__main__':
    # Start Flask's development server. host/port left as None, so Flask
    # uses its defaults (127.0.0.1:5000 unless SERVER_NAME is configured).
    app.run(host=None, port=None)
来源:https://blog.csdn.net/shenfuli/article/details/98882655