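In programmatic advertising, a good CTR prediction algorithm benefits advertisers, the ad exchange (Adx), and users alike. In 2013 Google published the paper "Ad click prediction: a view from the trenches" (KDD 2013), which describes an online CTR prediction algorithm based on FTRL. This post explains the main ideas of that algorithm and walks through a Java implementation.
What is Online Learning
A traditional batch algorithm computes over the entire training set in every iteration (for example, to obtain the global gradient). Its accuracy and convergence are decent, but it cannot handle large datasets efficiently, because the global gradient becomes too expensive to compute, and it cannot be applied to a data stream for online learning. An online learning algorithm, by contrast, updates the model once for every training sample that arrives, using the loss and gradient produced by that sample alone. Because it trains one sample at a time, it can handle both large-scale training and online training. Strictly speaking, Online Learning is not a model but a way of training a model. It adjusts the model quickly from live feedback, so the model tracks changes in production and online prediction accuracy improves. The Online Learning workflow shows the model's predictions to users, collects their feedback, and uses that feedback to train the model again, forming a closed-loop system (shown as a figure in the original post).
The FTRL-based online CTR prediction algorithm proposed in the paper is one such Online Learning algorithm. For each training sample it first makes a prediction, then evaluates the error with a loss function, and finally uses that error to update the parameters. It stops once every sample has been processed. The key choices are therefore the prediction method, the evaluation metric (the loss function), and the update rule. The paper's treatment of these three parts is described below: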
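The predict, evaluate, update cycle described above can be sketched with plain online logistic regression. This is a minimal illustration only: it uses an ordinary SGD step with a fixed step size rather than FTRL, and the class name `OnlineLoopSketch`, the step size, and the toy data are all assumptions made for the example.

```java
import java.util.Arrays;

/** Minimal online logistic regression: one plain SGD step per incoming sample (not FTRL). */
final class OnlineLoopSketch {
    final double[] w;          // model weights, all zero at the start
    final double learningRate; // hypothetical fixed step size

    OnlineLoopSketch(int dim, double learningRate) {
        this.w = new double[dim];
        this.learningRate = learningRate;
    }

    /** Step 1: predict p = sigmoid(w . x). */
    double predict(double[] x) {
        double a = 0.0;
        for (int i = 0; i < x.length; i++) {
            a += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-a));
    }

    /** Step 2: log loss of prediction p against a label y in {0, 1}. */
    static double logLoss(double p, double y) {
        double q = Math.max(Math.min(p, 1 - 1e-15), 1e-15);
        return y == 1.0 ? -Math.log(q) : -Math.log(1.0 - q);
    }

    /** Step 3: move the weights against the gradient (p - y) * x. */
    void update(double[] x, double p, double y) {
        for (int i = 0; i < x.length; i++) {
            w[i] -= learningRate * (p - y) * x[i];
        }
    }

    /** Processes one sample and returns the loss measured before the update. */
    double step(double[] x, double y) {
        double p = predict(x);
        double loss = logLoss(p, y);
        update(x, p, y);
        return loss;
    }

    public static void main(String[] args) {
        OnlineLoopSketch model = new OnlineLoopSketch(2, 0.5);
        double[][] stream = {{1, 0}, {0, 1}, {1, 0}, {0, 1}}; // toy feature stream
        double[] labels = {1, 0, 1, 0};
        for (int t = 0; t < stream.length; t++) {
            System.out.printf("t=%d loss=%.4f%n", t, model.step(stream[t], labels[t]));
        }
        System.out.println(Arrays.toString(model.w));
    }
}
```

Each sample is seen once and then discarded, which is what makes the scheme usable on a stream. FTRL keeps the same loop and replaces only the update step.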
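Prediction method: in each round $t$, given the feature vector $x_t \in \mathbb{R}^d$ and the model parameters $w_t$ from the previous iteration (in the first round, a given initial value), we predict the label of the sample as $p_t = \sigma(w_t \cdot x_t)$, where $\sigma(a) = 1/(1+\exp(-a))$ is the sigmoid function.
Loss function: for a feature vector $x_t$ with label $y_t \in \{0, 1\}$, LogLoss (logistic loss) is used as the loss function:
$$\ell_t(w_t) = -y_t \log p_t - (1 - y_t)\log(1 - p_t)$$
Update rule: the goal is to make the loss as small as possible, which amounts to estimating the parameters by maximum likelihood. First compute the gradient
$$g_t = \nabla \ell_t(w_t) = (\sigma(w_t \cdot x_t) - y_t)\,x_t = (p_t - y_t)\,x_t$$
and then iterate with FTRL:
$$w_{t+1} = \arg\min_w \left( g_{1:t} \cdot w + \frac{1}{2}\sum_{s=1}^{t} \sigma_s \lVert w - w_s \rVert_2^2 + \lambda_1 \lVert w \rVert_1 \right)$$
Here $\sigma_s$ encodes the learning-rate schedule, with $\sigma_{1:t} = \sum_{s=1}^{t}\sigma_s = 1/\eta_t$, $g_{1:t} = \sum_{s=1}^{t} g_s$, and $\lambda_1$ is the regularization parameter. Expanding the squared norm, this optimization problem simplifies to:
$$w_{t+1} = \arg\min_w \left( \Big(g_{1:t} - \sum_{s=1}^{t} \sigma_s w_s\Big) \cdot w + \frac{1}{2\eta_t}\lVert w \rVert_2^2 + \lambda_1 \lVert w \rVert_1 + \text{const} \right)$$
So if we store $z_{t-1} = g_{1:t-1} - \sum_{s=1}^{t-1} \sigma_s w_s$, then at the beginning of round $t$ we set $z_t = z_{t-1} + g_t - \left(\frac{1}{\eta_t} - \frac{1}{\eta_{t-1}}\right) w_t$. (This differs from the formula in the paper. I believe the last term should be subtracted rather than added, and the authors also use subtraction in their pseudocode later on, so the plus sign is probably a typo in the paper.)
Setting the (sub)gradient to zero then gives the closed-form solution of the problem, coordinate by coordinate:
$$w_{t+1,i} = \begin{cases} 0 & \text{if } |z_{t,i}| \le \lambda_1 \\ -\eta_t \left( z_{t,i} - \operatorname{sgn}(z_{t,i})\,\lambda_1 \right) & \text{otherwise} \end{cases}$$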
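To make the update concrete, here is a minimal sketch of FTRL-Proximal for a single coordinate. It keeps the state $z$ and $n$ (the sum of squared gradients) and derives the weight in closed form: $w = 0$ when $|z| \le \lambda_1$, otherwise $w = -(z - \operatorname{sgn}(z)\lambda_1) / ((\beta + \sqrt{n})/\alpha + \lambda_2)$. The per-coordinate learning rate $\alpha/(\beta + \sqrt{n})$ and the L2 term $\lambda_2$ are taken from the Java implementation later in this post, not from the formulas above. The class name `FtrlCoordinateSketch` and the parameter values are assumptions for illustration.

```java
/** One-coordinate FTRL-Proximal sketch: state (z, n) and the closed-form weight. */
final class FtrlCoordinateSketch {
    double z = 0.0; // accumulated gradients minus the sigma-weighted past weights
    double n = 0.0; // accumulated squared gradients

    final double alpha;
    final double beta;
    final double l1;
    final double l2;

    FtrlCoordinateSketch(double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
    }

    /** Closed-form weight: exactly 0 while |z| <= l1, otherwise shrunk towards 0 by l1. */
    double weight() {
        if (Math.abs(z) <= l1) {
            return 0.0;
        }
        return -(z - Math.signum(z) * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
    }

    /** Applies one gradient g: sigma = (sqrt(n + g^2) - sqrt(n)) / alpha, z += g - sigma * w, n += g^2. */
    void update(double g) {
        double w = weight();
        double sigma = (Math.sqrt(n + g * g) - Math.sqrt(n)) / alpha;
        z += g - sigma * w;
        n += g * g;
    }

    public static void main(String[] args) {
        // Example values only, not tuned.
        FtrlCoordinateSketch c = new FtrlCoordinateSketch(0.1, 1.0, 1.0, 0.0);
        for (int t = 1; t <= 3; t++) {
            c.update(-0.5); // gradient of a clicked sample with p = 0.5 and x = 1
            System.out.printf("t=%d z=%.3f n=%.3f w=%.6f%n", t, c.z, c.n, c.weight());
        }
    }
}
```

With these values the weight stays at exactly 0 for the first two updates, because $|z|$ has not yet exceeded $\lambda_1 = 1$, and becomes positive only on the third. This thresholding is how the L1 term produces sparse models.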
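That concludes the theory. I suspect most readers are not very interested in this part, so let's go straight to the pseudocode and the Java implementation. The procedure matches the theory exactly, and readers who want to go deeper should still study the theory:
Java implementation of the FTRL-based online CTR prediction algorithm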
Model parameter class
package DataClass;

public class FTRLParameters {
    public double alpha;       // learning-rate parameter
    public double beta;        // smoothing parameter; 1 works well and needs no tuning
    public double L1_lambda;   // L1 regularization parameter
    public double L2_lambda;   // L2 regularization parameter
    public int dataDimensions; // number of feature dimensions
    public int testDataSize;   // number of test samples handled per batch
    public int interval;       // print progress every `interval` samples
    public String modelPath;   // path where the trained model parameters are stored

    public FTRLParameters(double alpha, double beta,
                          double L1, double L2, int dataDimensions, int testDataSize, int interval, String modelPath) {
        this.alpha = alpha;
        this.beta = beta;
        this.L1_lambda = L1;
        this.L2_lambda = L2;
        this.dataDimensions = dataDimensions;
        this.testDataSize = testDataSize;
        this.interval = interval;
        this.modelPath = modelPath;
    }
}
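To give a feel for `alpha` and `beta`: in the update class further down they appear together as the per-coordinate learning rate `alpha / (beta + sqrt(n))`, where `n` is the sum of squared gradients seen so far for that coordinate. The sketch below just tabulates that rate. The class name `LearningRateSketch` and the values 0.1 and 1.0 are assumptions for illustration.

```java
/** Tabulates the per-coordinate learning rate alpha / (beta + sqrt(n)). */
final class LearningRateSketch {
    /** n is the sum of squared gradients observed so far for one coordinate. */
    static double rate(double alpha, double beta, double n) {
        return alpha / (beta + Math.sqrt(n));
    }

    public static void main(String[] args) {
        double alpha = 0.1; // example value, not tuned
        double beta = 1.0;  // example value, not tuned
        for (double n : new double[]{0, 1, 100, 10000}) {
            System.out.printf("n=%.0f rate=%.6f%n", n, rate(alpha, beta, n));
        }
    }
}
```

Frequently updated coordinates accumulate a large `n` and take smaller steps, while rarely seen features keep a rate close to `alpha / beta`.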
Model training class
package model;

import evaluate.LogLossEvalutor;

import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.TreeMap;

public class FTRLLocalTrain {
    private FTRLProximal learner;
    private FTRLModelLoad mload;
    private LogLossEvalutor evalutor;
    private int printInterval;

    public FTRLLocalTrain(FTRLModelLoad mload, FTRLProximal learner, LogLossEvalutor evalutor, int interval) {
        this.mload = mload;
        this.learner = learner;
        this.evalutor = evalutor;
        this.printInterval = interval;
    }

    /**
     * Training method: one pass over the samples in X with labels Y.
     */
    public void train(String modelPath, double[][] X, double[] Y) throws IOException {
        int trainedNum = 0;
        double totalLoss = 0.0; // accumulated loss
        long startTime = System.currentTimeMillis();
        // Warm start: load a previously saved model if the file exists and is not empty.
        File modelFile = new File(modelPath);
        if (modelFile.exists() && modelFile.length() > 0) {
            learner.loadModel(modelPath);
        }
        for (int j = 0; j < X.length; j++) {
            // Copy the dense row into an index -> value map.
            Map<Integer, Double> x = new TreeMap<Integer, Double>();
            for (int i = 0; i < X[j].length; i++) {
                x.put(i, X[j][i]);
            }
            double y = ((int) Y[j] == 1) ? 1. : 0.;
            double p = learner.predict(x);  // predict with the current model
            learner.updateModel(x, p, y);   // then update n and z with this sample
            double loss = LogLossEvalutor.calLogLoss(p, y);
            evalutor.addLogLoss(loss);
            totalLoss += loss;
            trainedNum += 1;
            if (trainedNum % printInterval == 0) {
                // Print elapsed minutes and the sliding-window average loss.
                long currentTime = System.currentTimeMillis();
                double minutes = (double) (currentTime - startTime) / 60000;
                System.out.printf("%.3f, %.5f\n", minutes, evalutor.getAverageLogLoss());
            }
        }
        learner.saveModel(modelPath);
        System.out.printf("global average loss: %.5f\n", totalLoss / trainedNum);
    }
}
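The training loop copies every entry of a dense row into the feature map, zeros included. A zero feature contributes nothing to the prediction and produces a zero gradient, so `n` and `z` are left unchanged for it. For sparse CTR data the map could therefore hold only the non-zero entries. The helper below is a hypothetical sketch of that conversion (`SparseRowSketch` and `toSparse` are not part of the original code).

```java
import java.util.Map;
import java.util.TreeMap;

/** Converts a dense feature row into a sparse index -> value map. */
final class SparseRowSketch {
    /** Copies only the non-zero entries of the row. */
    static Map<Integer, Double> toSparse(double[] row) {
        Map<Integer, Double> x = new TreeMap<>();
        for (int i = 0; i < row.length; i++) {
            if (row[i] != 0.0) {
                x.put(i, row[i]);
            }
        }
        return x;
    }

    public static void main(String[] args) {
        System.out.println(toSparse(new double[]{0.0, 1.0, 0.0, 2.5})); // prints {1=1.0, 3=2.5}
    }
}
```

One caveat with this approach: the weight map `w` in the learner would then contain only the features of the most recent sample, so any code that reads `w` for every index, such as the code that saves the model, has to tolerate missing entries.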
Model update class
package model;

import DataClass.FTRLParameters;

import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

public class FTRLProximal {
    // parameters -> alpha, beta, l1, l2, dimensions
    private FTRLParameters parameters;
    // n -> per-coordinate sum of squared past gradients
    public double[] n;
    // z -> per-coordinate accumulated (adjusted) gradients, from which the weights are derived
    public double[] z;
    // w -> lazily computed weights, rebuilt by every predict() call
    public Map<Integer, Double> w;
    // n_, z_, w_ -> spare copies; this class never assigns them
    public double[] n_;
    public double[] z_;
    public Map<Integer, Double> w_;

    public FTRLProximal(FTRLParameters parameters) {
        this.parameters = parameters;
        this.n = new double[parameters.dataDimensions];
        this.z = new double[parameters.dataDimensions];
        this.w = null;
    }
    /**
     * Computes p(y=1|x; w). Rebuilds the lazy weights w for the features in x
     * from z and n; n and z themselves are not changed.
     */
    public double predict(Map<Integer, Double> x) {
        w = new HashMap<Integer, Double>();
        double decisionValue = 0.0;
        for (Entry<Integer, Double> e : x.entrySet()) {
            double sgn = sign(z[e.getKey()]);
            double weight = 0.0;
            if (sgn * z[e.getKey()] <= parameters.L1_lambda) {
                // |z| <= L1: the weight is exactly zero
                w.put(e.getKey(), weight);
            } else {
                // closed-form weight with per-coordinate learning rate alpha / (beta + sqrt(n))
                weight = (sgn * parameters.L1_lambda - z[e.getKey()])
                        / ((parameters.beta + Math.sqrt(n[e.getKey()]))
                        / parameters.alpha + parameters.L2_lambda);
                w.put(e.getKey(), weight);
            }
            decisionValue += e.getValue() * weight;
        }
        // clamp the decision value to [-35, 35] before applying the sigmoid
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }

    /**
     * input: sample x, probability p, label y (1 for a click, 0 or -1 otherwise)
     * used: w (as computed by the preceding predict(x) call)
     * update: n, z
     */
    public void updateModel(Map<Integer, Double> x, double p, double y) {
        for (Entry<Integer, Double> e : x.entrySet()) {
            // negative labels (0 or -1) give gradient p * x; for y = 1 it is (p - 1) * x
            double grad = p * e.getValue();
            if (y == 1.0) {
                grad = (p - y) * e.getValue();
            }
            double sigma = (Math.sqrt(n[e.getKey()] + grad * grad) -
                    Math.sqrt(n[e.getKey()])) / parameters.alpha;
            z[e.getKey()] += (grad - sigma * w.get(e.getKey()));
            n[e.getKey()] += grad * grad;
        }
    }
    /**
     * Saves the model parameters n, z and w, one space-separated line each.
     */
    public void saveModel(String filePath) throws IOException {
        StringBuilder nLine = new StringBuilder();
        StringBuilder zLine = new StringBuilder();
        StringBuilder wLine = new StringBuilder();
        for (int i = 0; i < n.length; i++) {
            if (i > 0) {
                nLine.append(' ');
                zLine.append(' ');
                wLine.append(' ');
            }
            nLine.append(n[i]);
            zLine.append(z[i]);
            // w holds the weights from the most recent predict() call; a missing entry is written as 0.0
            Double wi = (w == null) ? null : w.get(i);
            wLine.append(wi == null ? 0.0 : wi);
        }
        // try-with-resources closes the writer even if a write fails
        try (BufferedWriter bufferWriter = new BufferedWriter(new FileWriter(filePath))) {
            bufferWriter.write(nLine + "\r\n");
            bufferWriter.write(zLine + "\r\n");
            bufferWriter.write(wLine.toString());
        }
        System.out.print("Done");
    }

    /**
     * Loads n, z and w from a file written by saveModel.
     */
    public void loadModel(String filePath) throws IOException {
        String[][] Str = new String[3][];
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line = null;
            int i = 0;
            // the file holds three lines: n, z and w
            while (i < 3 && (line = br.readLine()) != null) {
                Str[i] = line.split(" ");
                i++;
            }
        }
        n = new double[n.length];
        z = new double[z.length];
        w = new HashMap<Integer, Double>();
        for (int j = 0; j < n.length; j++) {
            n[j] = Double.valueOf(Str[0][j]);
            z[j] = Double.valueOf(Str[1][j]);
            w.put(j, Double.valueOf(Str[2][j]));
        }
    }
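    /**
     * Predicts with the fixed weights in w_ without touching n, z or w.
     * w_ is not initialized in this class, so the caller must assign it first.
     */
    public double predict_(Map<Integer, Double> x) {
        double decisionValue = 0.0;
        for (Entry<Integer, Double> e : x.entrySet()) {
            decisionValue += e.getValue() * w_.get(e.getKey());
        }
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }

    /** Returns 1.0, -1.0 or 0.0 according to the sign of x. */
    private double sign(double x) {
        if (x > 0) {
            return 1.0;
        } else if (x < 0) {
            return -1.0;
        } else {
            return 0.0;
        }
    }
}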
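Both prediction methods clamp the decision value to [-35, 35] before applying the sigmoid. The post does not say why, but a likely reason is that in double precision the unclamped sigmoid saturates to exactly 0 or 1 for large inputs, which would make the log loss infinite. The sketch below compares the two variants. The class and method names are assumptions for illustration.

```java
/** Compares the raw sigmoid with the variant clamped to [-35, 35]. */
final class SigmoidClampSketch {
    static double rawSigmoid(double a) {
        return 1.0 / (1.0 + Math.exp(-a));
    }

    static double clampedSigmoid(double a) {
        double clamped = Math.max(Math.min(a, 35.0), -35.0);
        return 1.0 / (1.0 + Math.exp(-clamped));
    }

    public static void main(String[] args) {
        for (double a : new double[]{-1000, -35, 0, 35, 1000}) {
            System.out.println("a=" + a + " raw=" + rawSigmoid(a) + " clamped=" + clampedSigmoid(a));
        }
    }
}
```

For an input of 1000 the raw version returns exactly 1.0, while the clamped version stays slightly below 1, so the logarithm of `1 - p` remains finite.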
Model prediction class
package model;

import java.io.*;
import java.util.HashMap;
import java.util.Map;

/**
 * Model loading and prediction.
 * n, z and w are the model parameters to load.
 */
public class FTRLModelLoad {
    public double[] n;
    public double[] z;
    public Map<Integer, Double> w;

    /**
     * Loads the model.
     * Input: path of the model file.
     * Effect: replaces the fields n, z and w with the values stored in the file.
     */
    public Map<Integer, Double> loadModel(String filePath) throws IOException {
        String[][] Str = new String[3][];
        try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(new File(filePath)), "UTF-8"))) {
            String line = null;
            int i = 0;
            // the file holds three lines: n, z and w
            while (i < 3 && (line = br.readLine()) != null) {
                Str[i] = line.split(" ");
                i++;
            }
        }
        n = new double[Str[0].length];
        z = new double[Str[0].length];
        w = new HashMap<Integer, Double>();
        for (int j = 0; j < Str[0].length; j++) {
            n[j] = Double.valueOf(Str[0][j]);
            z[j] = Double.valueOf(Str[1][j]);
            w.put(j, Double.valueOf(Str[2][j]));
        }
        return w;
    }

    /**
     * Prediction: sigmoid of the dot product of the dense features x_ and the weights w.
     */
    public double predict_(double[] x_, Map<Integer, Double> w) {
        double decisionValue = 0.0;
        for (int i = 0; i < x_.length; i++) {
            decisionValue += x_[i] * w.get(i);
        }
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }
}
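The model file read above is plain text with three lines holding `n`, `z` and `w`, each a list of numbers separated by single spaces. As a rough illustration of that layout, the following sketch writes three arrays to a temporary file and parses them back. It uses `java.nio.file.Files` instead of the readers and writers in the post, and the class name, method names, and temp-file prefix are assumptions.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/** Round-trips the three-line, space-separated model layout through a temp file. */
final class ModelFileSketch {
    /** Joins the values with single spaces, one array per line. */
    static String joinLine(double[] values) {
        return Arrays.stream(values)
                .mapToObj(v -> Double.toString(v))
                .collect(Collectors.joining(" "));
    }

    /** Parses one space-separated line back into a double array. */
    static double[] parseLine(String line) {
        return Arrays.stream(line.trim().split(" "))
                .mapToDouble(Double::parseDouble)
                .toArray();
    }

    /** Writes n, z, w as three lines, reads them back, and deletes the temp file. */
    static double[][] roundTrip(double[] n, double[] z, double[] w) {
        try {
            Path file = Files.createTempFile("ftrl-model", ".txt"); // hypothetical file name
            try {
                Files.write(file, Arrays.asList(joinLine(n), joinLine(z), joinLine(w)),
                        StandardCharsets.UTF_8);
                List<String> lines = Files.readAllLines(file, StandardCharsets.UTF_8);
                return new double[][]{
                        parseLine(lines.get(0)), parseLine(lines.get(1)), parseLine(lines.get(2))};
            } finally {
                Files.deleteIfExists(file);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        double[][] back = roundTrip(
                new double[]{0.25, 1.0}, new double[]{-1.5, 0.0}, new double[]{0.02, 0.0});
        System.out.println(Arrays.deepToString(back));
    }
}
```

Because `Double.toString` and `Double.parseDouble` round-trip exactly, the arrays read back are identical to the ones written.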
Loss function class
package evaluate;

/**
 * Keeps the most recent testDataSize losses in a circular buffer
 * and reports their average.
 */
public class LogLossEvalutor {
    private int testDataSize;
    private double[] logloss;
    private int position;
    private double totalLoss;
    private boolean enoughData;

    public LogLossEvalutor(int testDataSize) {
        this.testDataSize = testDataSize;
        logloss = new double[testDataSize];
        position = 0;
        totalLoss = 0.0;
    }

    public void addLogLoss(double loss) {
        // replace the oldest loss in the window with the new one
        totalLoss = totalLoss + loss - logloss[position];
        logloss[position] = loss;
        position += 1;
        if (position >= testDataSize) {
            position = 0;
            enoughData = true;
        }
    }

    public double getAverageLogLoss() {
        if (enoughData) {
            return totalLoss / testDataSize;
        } else {
            // window not full yet: average over the losses added so far
            return totalLoss / position;
        }
    }

    /** prob: p(y=1|x;w), y: 1 or 0(-1) */
    public static double calLogLoss(double prob, double y) {
        // clip the prediction into [1e-15, 1 - 1e-15] so the logarithm stays finite
        double p = Math.max(Math.min(prob, 1 - 1e-15), 1e-15);
        return y == 1. ? -Math.log(p) : -Math.log(1. - p);
    }

    public static void main(String[] args) {
        LogLossEvalutor evalutor = new LogLossEvalutor(4);
        double[] losses = {3, 2, 1, 0.7, 0.5, 0.2};
        for (int i = 0; i < losses.length; i++) {
            evalutor.addLogLoss(losses[i]);
            System.out.println(evalutor.getAverageLogLoss());
        }
    }
}
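The clipping inside `calLogLoss` matters when a prediction is exactly 0 or 1. As a small sketch (the class and method names are hypothetical), the following compares the clipped loss with an unclipped one for a clicked sample predicted at probability 0:

```java
/** Compares log loss with and without clipping the prediction to [1e-15, 1 - 1e-15]. */
final class LogLossClipSketch {
    static double clippedLogLoss(double prob, double y) {
        double p = Math.max(Math.min(prob, 1 - 1e-15), 1e-15);
        return y == 1.0 ? -Math.log(p) : -Math.log(1.0 - p);
    }

    static double rawLogLoss(double prob, double y) {
        return y == 1.0 ? -Math.log(prob) : -Math.log(1.0 - prob);
    }

    public static void main(String[] args) {
        System.out.println("raw:     " + rawLogLoss(0.0, 1.0));     // Infinity
        System.out.println("clipped: " + clippedLogLoss(0.0, 1.0)); // about 34.54
    }
}
```

Without clipping, a single such sample would make every running average infinite. With clipping, the worst-case loss per sample is bounded at roughly 34.5.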
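One test run of the program is shown as a figure in the original post.
References
(1) H. Brendan McMahan, Gary Holt, D. Sculley et al. Ad Click Prediction: a View from the Trenches.
(2) Meituan technical team, "Online Learning算法理论与实践" (Online Learning: Algorithm Theory and Practice).
Source: CSDN
Author: yz930618
Link: https://blog.csdn.net/yz930618/article/details/75270869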