An Online CTR Prediction Algorithm Based on FTRL

Posted by 蓝咒 on 2019-12-07 21:56:36

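In programmatic ad delivery, a good CTR (click-through rate) prediction algorithm benefits advertisers, the ad exchange (Adx), and users alike. In 2013 Google published the paper "Ad click prediction: a view from the trenches" (KDD 2013), which describes an online CTR prediction algorithm based on FTRL (Follow The Regularized Leader). This post explains the main ideas of that algorithm and then walks through a Java implementation.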

What is Online Learning

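A traditional batch algorithm computes over the entire training set in every iteration (for example, to obtain the global gradient). Its accuracy and convergence are reasonable, but it cannot handle large datasets efficiently (the global gradient becomes too expensive to compute) and it cannot be applied to a data stream for online learning. An online learning algorithm, by contrast, updates the model once for every incoming training sample, using only the loss and gradient produced by that sample. Training proceeds one sample at a time, so it copes with both large-scale training and online training. Strictly speaking, Online Learning is not a model but a way of training a model: it adjusts the model quickly, in real time, from online feedback, so that the model tracks changes in production and online prediction accuracy improves. The Online Learning workflow is: show the model's predictions to users, collect their feedback, and feed that feedback back into training, forming a closed loop, as shown in the figure below: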

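[Figure: the closed loop of Online Learning (predict, show to users, collect feedback, retrain); image not available]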
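To make the "one update per sample" idea concrete, here is a minimal sketch of online learning with plain stochastic gradient descent on logistic regression. It is not FTRL and not taken from the paper; the class name `OnlineSgdSketch`, the fixed learning rate, and the toy data stream are all assumptions made for illustration.

```java
/** Hypothetical illustration: plain online SGD for logistic regression (not FTRL). */
final class OnlineSgdSketch {
    private final double[] w;          // model weights, all zero at the start
    private final double learningRate; // fixed step size (an assumption of this sketch)

    OnlineSgdSketch(int dim, double learningRate) {
        this.w = new double[dim];
        this.learningRate = learningRate;
    }

    /** Predicted probability p(y=1|x) under the current weights. */
    double predict(double[] x) {
        double a = 0.0;
        for (int i = 0; i < x.length; i++) {
            a += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-a));
    }

    /** One online step: predict, then update the weights from this single sample.
     *  Returns the prediction that was made before the update. */
    double learn(double[] x, double y) {
        double p = predict(x);
        for (int i = 0; i < x.length; i++) {
            w[i] -= learningRate * (p - y) * x[i]; // gradient of the log loss is (p - y) * x
        }
        return p;
    }

    public static void main(String[] args) {
        OnlineSgdSketch model = new OnlineSgdSketch(2, 0.5);
        double[][] stream = {{1, 0}, {0, 1}};
        double[] labels = {1, 0};
        // Samples arrive one at a time; the model is updated after each one.
        for (int round = 0; round < 100; round++) {
            for (int j = 0; j < stream.length; j++) {
                model.learn(stream[j], labels[j]);
            }
        }
        System.out.println(model.predict(stream[0])); // moves toward 1
        System.out.println(model.predict(stream[1])); // moves toward 0
    }
}
```

The model never sees the whole dataset at once, which is what lets the same loop run over a stream of feedback.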

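The FTRL-based online CTR prediction algorithm proposed in the paper is one such Online Learning algorithm. For each training sample it first makes a prediction, then measures the error with a loss function, and finally uses that error to update the parameters. It stops once every sample has been visited. The key questions are therefore how to choose the prediction method, the evaluation metric, and the update rule. These three parts of the paper are described below: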

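  1. Prediction: in round $t$, given the feature vector $x_t \in \mathbb{R}^d$ and the model parameters $w_t$ from the previous update (in the first round, the given initial value), we predict the label probability of the sample as $p_t = \sigma(w_t \cdot x_t)$, where $\sigma(a) = 1/(1+\exp(-a))$ is the sigmoid function.

  2. Loss function: for a feature vector $x_t$ with label $y_t \in \{0, 1\}$, we use LogLoss (logistic loss) as the loss function: $l_t(w_t) = -y_t \log p_t - (1-y_t)\log(1-p_t)$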

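  3. Update rule: the goal is to make the loss as small as possible, which amounts to estimating the parameters by maximum likelihood. First compute the gradient $g_t = \nabla l_t(w) = (\sigma(w \cdot x_t) - y_t)\,x_t = (p_t - y_t)\,x_t$, then iterate with FTRL:
    $$w_{t+1} = \arg\min_w \left( g_{1:t} \cdot w + \frac{1}{2}\sum_{s=1}^{t} \sigma_s \lVert w - w_s \rVert_2^2 + \lambda_1 \lVert w \rVert_1 \right)$$
    Here $\sigma_s$ encodes the learning-rate schedule through $\sigma_{1:t} = 1/\eta_t$, $g_{1:t} = \sum_{s=1}^{t} g_s$, and $\lambda_1$ is the regularization parameter. The optimization problem simplifies to:
    $$w_{t+1} = \arg\min_w \left( \Big(g_{1:t} - \sum_{s=1}^{t} \sigma_s w_s\Big) \cdot w + \frac{1}{2\eta_t} \lVert w \rVert_2^2 + \lambda_1 \lVert w \rVert_1 + \text{const} \right)$$
    So if we store $z_{t-1} = g_{1:t-1} - \sum_{s=1}^{t-1} \sigma_s w_s$, then in round $t$ we set $z_t = z_{t-1} + g_t - \left(\frac{1}{\eta_t} - \frac{1}{\eta_{t-1}}\right) w_t$ (this differs from the formula printed in the paper; I believe the last term should be subtracted rather than added, and the authors also use subtraction in their pseudocode later on, so the plus sign in the paper is probably a typo)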
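    The sign of the last term can be checked directly from the definitions. Assuming $z_t$ is defined in every round the same way as $z_{t-1}$, that is $z_t = g_{1:t} - \sum_{s=1}^{t}\sigma_s w_s$, and that $\sigma_{1:t} = 1/\eta_t$, a short derivation gives:

```latex
\begin{aligned}
z_t &= g_{1:t} - \sum_{s=1}^{t} \sigma_s w_s \\
    &= \Big( g_{1:t-1} - \sum_{s=1}^{t-1} \sigma_s w_s \Big) + g_t - \sigma_t w_t \\
    &= z_{t-1} + g_t - \sigma_t w_t, \\
\sigma_t &= \sigma_{1:t} - \sigma_{1:t-1} = \frac{1}{\eta_t} - \frac{1}{\eta_{t-1}} .
\end{aligned}
```

    Substituting the second identity into the first yields the update with a minus sign, which is also what the Java `updateModel` method further down computes (`z += grad - sigma * w`).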

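Setting the (sub)gradient to zero then gives the closed-form solution of this optimization problem, coordinate by coordinate:
$$w_{t+1,i} = \begin{cases} 0 & \text{if } |z_{t,i}| \le \lambda_1 \\ -\eta_t \left( z_{t,i} - \operatorname{sgn}(z_{t,i})\,\lambda_1 \right) & \text{otherwise} \end{cases}$$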
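As a small numeric illustration of the closed-form solution, the sketch below computes one coordinate's weight from its accumulator `z`. Two things are assumptions borrowed from the Java implementation later in this post rather than from the text above: the per-coordinate learning rate $\eta = \alpha / (\beta + \sqrt{n})$, where `n` is the sum of squared gradients, and the extra L2 term `l2` in the denominator. The class name `FtrlWeightSketch` is hypothetical.

```java
/** Hypothetical illustration of the per-coordinate closed-form FTRL weight. */
final class FtrlWeightSketch {

    /**
     * z: accumulated (adjusted) gradient, n: sum of squared gradients,
     * alpha/beta: learning-rate parameters, l1/l2: regularization strengths.
     */
    static double weight(double z, double n, double alpha, double beta, double l1, double l2) {
        if (Math.abs(z) <= l1) {
            return 0.0; // L1 clips small accumulators to an exactly zero weight
        }
        double sign = z > 0 ? 1.0 : -1.0;
        // 1/eta = (beta + sqrt(n)) / alpha; the L2 term adds to the denominator
        return -(z - sign * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
    }

    public static void main(String[] args) {
        // |z| = 0.5 is below l1 = 1, so the weight is exactly zero (sparse model)
        System.out.println(weight(0.5, 4.0, 0.1, 1.0, 1.0, 0.0));
        // |z| = 3 exceeds l1, so the weight is non-zero and has the opposite sign of z
        System.out.println(weight(3.0, 4.0, 0.1, 1.0, 1.0, 0.0));
        System.out.println(weight(-3.0, 4.0, 0.1, 1.0, 1.0, 0.0));
    }
}
```

The first case is why FTRL with L1 regularization produces sparse models: any coordinate whose accumulator stays within $\pm\lambda_1$ contributes nothing to the prediction.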

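That wraps up the theory. I suspect most readers are not that interested in this part, so let's go straight to the pseudocode and the Java implementation (the procedure matches the theory above; if you want to go deeper, the theory is still worth studying):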

[Figure: pseudocode of the FTRL algorithm from the paper; image not available]

Java implementation of the FTRL-based online CTR prediction algorithm

Model parameter class

package DataClass;

public class FTRLParameters {
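    public double alpha; // learning-rate parameter
    public double beta; // smoothing parameter; 1 works well and usually needs no tuning
    public double L1_lambda; // L1 regularization strength
    public double L2_lambda; // L2 regularization strength
    public int dataDimensions; // number of feature dimensions
    public int testDataSize; // number of samples handled per batch when the test set is processed in chunks
    public int interval; // print progress once every `interval` samples
    public String modelPath; // path where the trained model parameters are stored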

    public FTRLParameters(double alpha, double beta,
                          double L1, double L2, int dataDimensions,int testDataSize, int interval,String modelPath) {
        this.alpha = alpha;
        this.beta = beta;
        this.L1_lambda = L1;
        this.L2_lambda = L2;
        this.dataDimensions = dataDimensions;
        this.testDataSize = testDataSize;
        this.interval = interval;
        this.modelPath = modelPath;
    }
}

Model training class

package model;
import DataPreprocessing.FileOperation;
import evaluate.LogLossEvalutor;
import java.io.*;
import java.util.Map;
import java.util.TreeMap;

public class FTRLLocalTrain {
    private FTRLProximal learner;
    private FTRLModelLoad mload;
    private LogLossEvalutor evalutor;
    private int printInterval;

    public FTRLLocalTrain(FTRLModelLoad mload, FTRLProximal learner, LogLossEvalutor evalutor, int interval) {
        this.mload = mload;
        this.learner = learner;
        this.evalutor = evalutor;
        this.printInterval = interval;
    }

    /**
     * Training method: one online pass over the samples in X with labels Y.
     */
    public void train(String modelPath,double[][] X,double[] Y) throws IOException {
        int trainedNum = 0;
        double totalLoss = 0.0; // accumulated log loss
        long startTime = System.currentTimeMillis();
        // Resume from a previously saved model if the file exists; load it once,
        // not once per line, and do not leave a reader open
        if (new File(modelPath).exists()) {
            learner.loadModel(modelPath);
        }
        for(int j=0;j<X.length;j++){
            Map<Integer, Double> x = new TreeMap<Integer, Double>();
            for (int i = 0; i < X[0].length; i++) {
                x.put(i, X[j][i]);
            }
            double y = ((int)Y[j] == 1) ? 1. : 0.;
            double p = learner.predict(x);
            learner.updateModel(x, p, y);
            double loss = LogLossEvalutor.calLogLoss(p, y);
            evalutor.addLogLoss(loss);
            totalLoss += loss;
            trainedNum += 1;
            if (trainedNum % printInterval == 0) {
                long currentTime = System.currentTimeMillis();
                double minutes = (double) (currentTime - startTime) / 60000;
                System.out.printf("%.3f, %.5f\n", minutes, evalutor.getAverageLogLoss());
            }
        }
        learner.saveModel(modelPath);
        System.out.printf("global average loss: %.5f\n", totalLoss / trainedNum);
    }
}

Model update class

package model;
import DataClass.FTRLParameters;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

public class FTRLProximal {

    // Hyperparameters: alpha, beta, L1, L2 and the feature dimension
    private FTRLParameters parameters;
    // n[i]: running sum of squared gradients for coordinate i
    public double[] n;
    // z[i]: FTRL accumulator (summed gradients minus the sigma * w corrections)
    public double[] z;
    // w: lazily computed weights, rebuilt by predict() or restored by loadModel()
    public Map<Integer, Double> w;

    public FTRLProximal(FTRLParameters parameters) {
        this.parameters = parameters;
        this.n = new double[parameters.dataDimensions];
        this.z = new double[parameters.dataDimensions];
        this.w = new HashMap<Integer, Double>();
    }

    /** Closed-form weight of coordinate i, computed from the current z[i] and n[i]. */
    private double weight(int i) {
        double sgn = sign(z[i]);
        if (sgn * z[i] <= parameters.L1_lambda) {
            return 0.0; // |z_i| <= lambda_1: L1 clips the weight to exactly zero
        }
        return (sgn * parameters.L1_lambda - z[i])
                / ((parameters.beta + Math.sqrt(n[i]))
                        / parameters.alpha + parameters.L2_lambda);
    }

    /** Returns p(y=1|x; w). Rebuilds w for the features in x; n and z are not changed. */
    public double predict(Map<Integer, Double> x) {
        w = new HashMap<Integer, Double>();
        double decisionValue = 0.0;
        for (Entry<Integer, Double> e : x.entrySet()) {
            double weight = weight(e.getKey());
            w.put(e.getKey(), weight);
            decisionValue += e.getValue() * weight;
        }
        // Clamp the score so Math.exp cannot overflow
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }

    /** Input: sample x, probability p returned by predict(x), label y (1 is positive; 0 or -1 is negative).
     *  Uses: w (as built by the preceding predict call).
     *  Updates: n, z. */
    public void updateModel(Map<Integer, Double> x, double p, double y) {
        double label = (y == 1.0) ? 1.0 : 0.0;
        for (Entry<Integer, Double> e : x.entrySet()) {
            int i = e.getKey();
            double grad = (p - label) * e.getValue();
            double sigma = (Math.sqrt(n[i] + grad * grad) - Math.sqrt(n[i])) / parameters.alpha;
            z[i] += grad - sigma * w.get(i);
            n[i] += grad * grad;
        }
    }

    /**
     * Saves the model parameters n, z and w as three space-separated lines.
     */
    public void saveModel(String filePath) throws IOException {
        StringBuilder nLine = new StringBuilder();
        StringBuilder zLine = new StringBuilder();
        StringBuilder wLine = new StringBuilder();
        for (int i = 0; i < n.length; i++) {
            String sep = (i == 0) ? "" : " ";
            nLine.append(sep).append(n[i]);
            zLine.append(sep).append(z[i]);
            // Recompute each weight from z and n, so coordinates missing from the
            // last sample are not written out as "null"
            wLine.append(sep).append(weight(i));
        }
        // Write as UTF-8 to match loadModel; try-with-resources closes the file
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(filePath), "UTF-8"))) {
            writer.write(nLine + "\r\n");
            writer.write(zLine + "\r\n");
            writer.write(wLine.toString());
        }
        System.out.println("Done");
    }

    /** Restores n, z and w from a file written by saveModel. */
    public void loadModel(String filePath) throws IOException {
        String[][] rows = new String[3][];
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), "UTF-8"))) {
            for (int i = 0; i < 3; i++) {
                String line = br.readLine();
                if (line == null) {
                    throw new IOException("Model file needs 3 lines (n, z, w): " + filePath);
                }
                rows[i] = line.trim().split(" ");
            }
        }
        n = new double[n.length];
        z = new double[z.length];
        w = new HashMap<Integer, Double>();
        for (int j = 0; j < n.length; j++) {
            n[j] = Double.valueOf(rows[0][j]);
            z[j] = Double.valueOf(rows[1][j]);
            w.put(j, Double.valueOf(rows[2][j]));
        }
    }

    /** Predicts with the stored weights w (for example after loadModel); n, z and w are not changed. */
    public double predict_(Map<Integer, Double> x) {
        double decisionValue = 0.0;
        for (Entry<Integer, Double> e : x.entrySet()) {
            Double weight = w.get(e.getKey());
            if (weight != null) {
                decisionValue += e.getValue() * weight;
            }
        }
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }

    private double sign(double x) {
        if (x > 0) {
            return 1.0;
        } else if (x < 0) {
            return -1.0;
        } else {
            return 0.0;
        }
    }
}
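The model file used by `saveModel` and `loadModel` is just three lines of space-separated numbers: `n`, `z`, and `w`. The following stand-alone sketch round-trips that layout through a temporary file using `java.nio.file.Files`. The class `ModelFileSketch` and its helper names are hypothetical, and it works on plain arrays rather than on the classes above.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

/** Hypothetical illustration of the three-line model file layout (n, z, w). */
final class ModelFileSketch {

    /** Joins one parameter vector into a single space-separated line. */
    static String toLine(double[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) sb.append(' ');
            sb.append(values[i]);
        }
        return sb.toString();
    }

    /** Parses a space-separated line back into a vector. */
    static double[] fromLine(String line) {
        String[] parts = line.trim().split(" ");
        double[] values = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            values[i] = Double.parseDouble(parts[i]);
        }
        return values;
    }

    static void save(Path path, double[] n, double[] z, double[] w) throws IOException {
        Files.write(path, Arrays.asList(toLine(n), toLine(z), toLine(w)), StandardCharsets.UTF_8);
    }

    /** Returns {n, z, w} in that order. */
    static double[][] load(Path path) throws IOException {
        List<String> lines = Files.readAllLines(path, StandardCharsets.UTF_8);
        if (lines.size() < 3) {
            throw new IOException("expected 3 lines (n, z, w) but found " + lines.size());
        }
        return new double[][] {fromLine(lines.get(0)), fromLine(lines.get(1)), fromLine(lines.get(2))};
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("ftrl-model", ".txt");
        save(tmp, new double[]{4.0, 0.25}, new double[]{3.0, -0.5}, new double[]{-0.0625, 0.0});
        double[][] model = load(tmp);
        System.out.println(Arrays.toString(model[0]));
        System.out.println(Arrays.toString(model[1]));
        System.out.println(Arrays.toString(model[2]));
        Files.delete(tmp);
    }
}
```

Because `Double.toString` and `Double.parseDouble` round-trip exactly, the text format loses no precision. Its main weakness is that it stores every coordinate, including zero weights.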

Model prediction class

package model;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/**
 * Model loading and prediction.
 * n, z and w are the model parameters read from the model file.
 */
public class FTRLModelLoad {

    public double[] n;
    public double[] z;
    public Map<Integer, Double> w;

    /**
     * Loads a model.
     * Input: path of the model file.
     * Effect: refreshes the n, z and w fields and returns w.
     * */
    public Map<Integer, Double> loadModel(String filePath) throws IOException {
        String[][] rows = new String[3][];
        // try-with-resources closes the reader; only the three expected lines are read
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), "UTF-8"))) {
            for (int i = 0; i < 3; i++) {
                String line = br.readLine();
                if (line == null) {
                    throw new IOException("Model file needs 3 lines (n, z, w): " + filePath);
                }
                rows[i] = line.trim().split(" ");
            }
        }
        n = new double[rows[0].length];
        z = new double[rows[0].length];
        w = new HashMap<Integer, Double>();
        for(int j=0;j<rows[0].length;j++){
            n[j] = Double.valueOf(rows[0][j]);
            z[j] = Double.valueOf(rows[1][j]);
            w.put(j,Double.valueOf(rows[2][j]));
        }
        return w;
    }

    /**
     * Prediction function: scores the dense feature vector x_ with the weights w.
     * */
    public double predict_(double[] x_,Map<Integer,Double> w) {
        double decisionValue = 0.0;
        for (int i = 0; i < x_.length; i++) {
            Double weight = w.get(i);
            if (weight != null) {
                decisionValue += x_[i] * weight;
            }
        }
        decisionValue = Math.max(Math.min(decisionValue, 35.), -35.);
        return 1. / (1. + Math.exp(-decisionValue));
    }
}

Loss function class

package evaluate;

public class LogLossEvalutor {

    private int testDataSize;
    private double[] logloss;
    private int position;
    private double totalLoss;
    private boolean enoughData;

    public LogLossEvalutor(int testDataSize) {
        this.testDataSize = testDataSize;
        logloss = new double[testDataSize];
        position = 0;
        totalLoss = 0.0;
    }

    public void addLogLoss(double loss) {
        totalLoss = totalLoss + loss - logloss[position];
        logloss[position] = loss;
        position += 1;
        if(position >= testDataSize) {
            position = 0;
            enoughData = true;
        }
    }

    public double getAverageLogLoss() {
        if(enoughData) {
            return totalLoss / testDataSize;
        } else if (position > 0) {
            return totalLoss / position;
        } else {
            return 0.0; // nothing recorded yet; avoids NaN from 0.0 / 0
        }
    }

    /** prob: p(y=1|x;w), y: 1 or 0(-1) */
    public static double calLogLoss(double prob, double y) {
        // Clamp the prediction away from 0 and 1 so the logarithm stays finite
        double p = Math.max(Math.min(prob,  1-1e-15), 1e-15);
        return y == 1.? -Math.log(p) : -Math.log(1. - p);
    }

    public static void main(String[] args) {
        LogLossEvalutor evalutor = new LogLossEvalutor(4);
        double[] losses = {3, 2, 1, 0.7, 0.5, 0.2};
        for(int i=0; i<losses.length; i++) {
            evalutor.addLogLoss(losses[i]);
            System.out.println(evalutor.getAverageLogLoss());
        }
    }
}
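The evaluator above keeps a sliding-window average: the newest loss overwrites the oldest slot of a ring buffer, and the running total is adjusted by the difference. The stand-alone sketch below isolates just that technique (the class name `SlidingAverageSketch` is hypothetical) and feeds it the same six values as the `main` method above with a window of 4.

```java
/** Hypothetical illustration of a ring-buffer sliding-window average. */
final class SlidingAverageSketch {
    private final double[] window; // the last `size` values, overwritten in a cycle
    private int position = 0;      // slot the next value goes into
    private int count = 0;         // how many slots hold real values (at most size)
    private double total = 0.0;    // sum of the values currently in the window

    SlidingAverageSketch(int size) {
        window = new double[size];
    }

    void add(double value) {
        total += value - window[position]; // remove the value being overwritten
        window[position] = value;
        position = (position + 1) % window.length;
        if (count < window.length) {
            count++;
        }
    }

    double average() {
        return count == 0 ? 0.0 : total / count;
    }

    public static void main(String[] args) {
        SlidingAverageSketch avg = new SlidingAverageSketch(4);
        double[] losses = {3, 2, 1, 0.7, 0.5, 0.2};
        for (double loss : losses) {
            avg.add(loss);
            System.out.println(avg.average());
        }
    }
}
```

Worked by hand, the averages are approximately 3, 2.5, 2, 1.675, 1.05 and 0.6: the first four grow the window, and from the fifth value on the oldest loss (3, then 2) drops out. The printed values may differ in the last digits because of floating-point rounding.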

Below is a screenshot of one test run of the program:

[Figure: output of a test run; image not available]

References
(1) H. Brendan McMahan, Gary Holt, D. Sculley, et al. Ad Click Prediction: a View from the Trenches. KDD 2013.
(2) Meituan Tech Team (美团技术团队), "Online Learning算法理论与实践" (Online Learning Algorithms: Theory and Practice).
