lightgbm用于排序

柔情痞子 提交于 2020-05-07 11:47:42

一. 

  LTR(learning to rank)经常用于搜索排序中,开源工具中比较有名的是微软的ranklib,但是这个好像是单机版的,也有好长时间没有更新了。所以打算想利用lightgbm进行排序,但网上关于lightgbm用于排序的代码很少,关于回归和分类的倒是一堆。这里我将贴上python版的lightgbm用于排序的代码,里面将包括训练、获取叶结点、ndcg评估、预测以及特征重要度等处理代码,有需要的朋友可以参考一下或进行修改。

  其实在使用时,本人也对比了ranklib中的lambdamart和lightgbm,令人映像最深刻的是lightgbm的训练速度非常快,快的起飞。可能lambdamart训练需要几个小时,而lightgbm只需要几分钟,但是后面的ndcg测试都差不多,不像论文中所说的lightgbm精度高一点。lightgbm的训练速度快,我想可能最大的原因要可能是:a.节点分裂用到了直方图,而不是预排序方法;b.基于梯度的单边采样,即行采样;c.互斥特征绑定,即列采样;d.其于leaf-wise决策树生长策略;e.类别特征的支持等

二.代码

第一部分代码块是主代码,后面三个代码块是用到的加载数据和ndcg。运行主代码使用命令如训练模型使用:python lgb.py -train等

完成代码和数据格式放在https://github.com/jiangnanboy/learning_to_rank上面,大家可以参考一下!!!!!

  1 import os
  2 import lightgbm as lgb
  3 from sklearn import datasets as ds
  4 import pandas as pd
  5 
  6 import numpy as np
  7 from datetime import datetime
  8 import sys
  9 from sklearn.preprocessing import OneHotEncoder
 10 
 11 def split_data_from_keyword(data_read, data_group, data_feats):
 12     '''
 13     利用pandas
 14     转为lightgbm需要的格式进行保存
 15     :param data_read:
 16     :param data_save:
 17     :return:
 18     '''
 19     with open(data_group, 'w', encoding='utf-8') as group_path:
 20         with open(data_feats, 'w', encoding='utf-8') as feats_path:
 21             dataframe = pd.read_csv(data_read,
 22                                     sep=' ',
 23                                     header=None,
 24                                     encoding="utf-8",
 25                                     engine='python')
 26             current_keyword = ''
 27             current_data = []
 28             group_size = 0
 29             for _, row in dataframe.iterrows():
 30                 feats_line = [str(row[0])]
 31                 for i in range(2, len(dataframe.columns) - 1):
 32                     feats_line.append(str(row[i]))
 33                 if current_keyword == '':
 34                     current_keyword = row[1]
 35                 if row[1] == current_keyword:
 36                     current_data.append(feats_line)
 37                     group_size += 1
 38                 else:
 39                     for line in current_data:
 40                         feats_path.write(' '.join(line))
 41                         feats_path.write('\n')
 42                     group_path.write(str(group_size) + '\n')
 43 
 44                     group_size = 1
 45                     current_data = []
 46                     current_keyword = row[1]
 47                     current_data.append(feats_line)
 48 
 49             for line in current_data:
 50                 feats_path.write(' '.join(line))
 51                 feats_path.write('\n')
 52             group_path.write(str(group_size) + '\n')
 53 
 54 def save_data(group_data, output_feature, output_group):
 55     '''
 56     group与features分别进行保存
 57     :param group_data:
 58     :param output_feature:
 59     :param output_group:
 60     :return:
 61     '''
 62     if len(group_data) == 0:
 63         return
 64     output_group.write(str(len(group_data)) + '\n')
 65     for data in group_data:
 66         # 只包含非零特征
 67         # feats = [p for p in data[2:] if float(p.split(":")[1]) != 0.0]
 68         feats = [p for p in data[2:]]
 69         output_feature.write(data[0] + ' ' + ' '.join(feats) + '\n') # data[0] => level ; data[2:] => feats
 70 
 71 def process_data_format(test_path, test_feats, test_group):
 72     '''
 73      转为lightgbm需要的格式进行保存
 74      '''
 75     with open(test_path, 'r', encoding='utf-8') as fi:
 76         with open(test_feats, 'w', encoding='utf-8') as output_feature:
 77             with open(test_group, 'w', encoding='utf-8') as output_group:
 78                 group_data = []
 79                 group = ''
 80                 for line in fi:
 81                     if not line:
 82                         break
 83                     if '#' in line:
 84                         line = line[:line.index('#')]
 85                     splits = line.strip().split()
 86                     if splits[1] != group: # qid => splits[1]
 87                         save_data(group_data, output_feature, output_group)
 88                         group_data = []
 89                     group = splits[1]
 90                     group_data.append(splits)
 91                 save_data(group_data, output_feature, output_group)
 92 
 93 def load_data(feats, group):
 94     '''
 95     加载数据
 96     分别加载feature,label,query
 97     '''
 98     x_train, y_train = ds.load_svmlight_file(feats)
 99     q_train = np.loadtxt(group)
100     return x_train, y_train, q_train
101 
102 def load_data_from_raw(raw_data):
103     with open(raw_data, 'r', encoding='utf-8') as testfile:
104         test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
105     return test_X, test_y, test_qids, comments
106 
107 def train(x_train, y_train, q_train, model_save_path):
108     '''
109     模型的训练和保存
110     '''
111     train_data = lgb.Dataset(x_train, label=y_train, group=q_train)
112     params = {
113         'task': 'train',  # 执行的任务类型
114         'boosting_type': 'gbrt',  # 基学习器
115         'objective': 'lambdarank',  # 排序任务(目标函数)
116         'metric': 'ndcg',  # 度量的指标(评估函数)
117         'max_position': 10,  # @NDCG 位置优化
118         'metric_freq': 1,  # 每隔多少次输出一次度量结果
119         'train_metric': True,  # 训练时就输出度量结果
120         'ndcg_at': [10],
121         'max_bin': 255,  # 一个整数,表示最大的桶的数量。默认值为 255。lightgbm 会根据它来自动压缩内存。如max_bin=255 时,则lightgbm 将使用uint8 来表示特征的每一个值。
122         'num_iterations': 500,  # 迭代次数
123         'learning_rate': 0.01,  # 学习率
124         'num_leaves': 31,  # 叶子数
125         # 'max_depth':6,
126         'tree_learner': 'serial',  # 用于并行学习,‘serial’: 单台机器的tree learner
127         'min_data_in_leaf': 30,  # 一个叶子节点上包含的最少样本数量
128         'verbose': 2  # 显示训练时的信息
129     }
130     gbm = lgb.train(params, train_data, valid_sets=[train_data])
131     gbm.save_model(model_save_path)
132 
133 def predict(x_test, comments, model_input_path):
134     '''
135     预测得分并排序
136     '''
137     gbm = lgb.Booster(model_file=model_input_path)  # 加载model
138 
139     ypred = gbm.predict(x_test)
140 
141     predicted_sorted_indexes = np.argsort(ypred)[::-1]  # 返回从大到小的索引
142 
143     t_results = comments[predicted_sorted_indexes]  # 返回对应的comments,从大到小的排序
144 
145     return t_results
146 
147 def test_data_ndcg(model_path, test_path):
148     '''
149     评估测试数据的ndcg
150     '''
151     with open(test_path, 'r', encoding='utf-8') as testfile:
152         test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
153 
154     gbm = lgb.Booster(model_file=model_path)
155     test_predict = gbm.predict(test_X)
156 
157     average_ndcg, _ = ndcg.validate(test_qids, test_y, test_predict, 60)
158     # 所有qid的平均ndcg
159     print("all qid average ndcg: ", average_ndcg)
160     print("job done!")
161 
162 def plot_print_feature_importance(model_path):
163     '''
164     打印特征的重要度
165     '''
166     #模型中的特征是Column_数字,这里打印重要度时可以映射到真实的特征名
167     feats_dict = {
168         'Column_0': '特征0名称',
169         'Column_1': '特征1名称',
170         'Column_2': '特征2名称',
171         'Column_3': '特征3名称',
172         'Column_4': '特征4名称',
173         'Column_5': '特征5名称',
174         'Column_6': '特征6名称',
175         'Column_7': '特征7名称',
176         'Column_8': '特征8名称',
177         'Column_9': '特征9名称',
178         'Column_10': '特征10名称',
179     }
180     if not os.path.exists(model_path):
181         print("file no exists! {}".format(model_path))
182         sys.exit(0)
183 
184     gbm = lgb.Booster(model_file=model_path)
185 
186     # 打印和保存特征重要度
187     importances = gbm.feature_importance(importance_type='split')
188     feature_names = gbm.feature_name()
189 
190     sum = 0.
191     for value in importances:
192         sum += value
193 
194     for feature_name, importance in zip(feature_names, importances):
195         if importance != 0:
196             feat_id = int(feature_name.split('_')[1]) + 1
197             print('{} : {} : {} : {}'.format(feat_id, feats_dict[feature_name], importance, importance / sum))
198 
199 def get_leaf_index(data, model_path):
200     '''
201     得到叶结点并进行one-hot编码
202     '''
203     gbm = lgb.Booster(model_file=model_path)
204     ypred = gbm.predict(data, pred_leaf=True)
205 
206     one_hot_encoder = OneHotEncoder()
207     x_one_hot = one_hot_encoder.fit_transform(ypred)
208     print(x_one_hot.toarray()[0])
209 
210 if __name__ == '__main__':
211     model_path = "保存模型的路径"
212 
213     if len(sys.argv) != 2:
214         print("Usage: python main.py [-process | -train | -predict | -ndcg | -feature | -leaf]")
215         sys.exit(0)
216 
217     if sys.argv[1] == '-process':
218         # 训练样本的格式与ranklib中的训练样本是一样的,但是这里需要处理成lightgbm中排序所需的格式
219         # lightgbm中是将样本特征和group分开保存为txt的,什么意思呢,看下面解释
220         '''
221         feats:
222         1 1:0.2 2:0.4 ...
223         2 1:0.2 2:0.4 ...
224         1 1:0.2 2:0.4 ...
225         3 1:0.2 2:0.4 ...
226         group:
227         2
228         4
229         这里group中2表示前2个是一个qid,4表示后两个是一个qid
230         '''
231         raw_data_path = '训练样本集路径'
232         data_feats = '特征保存路径'
233         data_group = 'group保存路径'
234         process_data_format(raw_data_path, data_feats, data_group)
235 
236     elif sys.argv[1] == '-train':
237         # train
238         train_start = datetime.now()
239         data_feats = '特征保存路径'
240         data_group = 'group保存路径'
241         x_train, y_train, q_train = load_data(data_feats, data_group)
242         train(x_train, y_train, q_train, model_path)
243         train_end = datetime.now()
244         consume_time = (train_end - train_start).seconds
245         print("consume time : {}".format(consume_time))
246 
247     elif sys.argv[1] == '-predict':
248         train_start = datetime.now()
249         raw_data_path = '需要预测的数据路径'#格式如ranklib中的数据格式
250         test_X, test_y, test_qids, comments = load_data_from_raw(raw_data_path)
251         t_results = predict(test_X, comments, model_path)
252         train_end = datetime.now()
253         consume_time = (train_end - train_start).seconds
254         print("consume time : {}".format(consume_time))
255 
256     elif sys.argv[1] == '-ndcg':
257         # ndcg
258         test_path = '测试的数据路径'#评估测试数据的平均ndcg
259         test_data_ndcg(model_path, test_path)
260 
261     elif sys.argv[1] == '-feature':
262         plot_print_feature_importance(model_path)
263 
264     elif sys.argv[1] == '-leaf':
265         #利用模型得到样本叶结点的one-hot表示
266         raw_data = '测试数据路径'#
267         with open(raw_data, 'r', encoding='utf-8') as testfile:
268             test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
269         get_leaf_index(test_X, model_path)

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!