动手学习深度学习 3-2 Softmax-regression

天大地大妈咪最大 提交于 2020-02-05 18:58:56

Softmax Regression

1. softmax 回归

softmax 回归主要是解决分类问题,输出是离散值,区别于线性回归,输出单元从一个变成了多个。

注:
需要在notebook文件地址下打开jupyter notebook,如果在子文件夹 04 chapter-deep-learning-basics 下打开,则无法访问 notebook/img 中的图片

1.1 softmax回归模型

softmax回归将输入特征与权重做线性叠加,输出值个数等于标签里的类别数。假设有4种特征和3种输出动物类别,所以权重包含12个标量(带下标的\(w\))、偏差包含3个标量(带下标的\(b\)),且对每个输入计算\(o_1, o_2, o_3\)这3个输出:

\[ \begin{aligned} o_1 &= x_1 w_{11} + x_2 w_{21} + x_3 w_{31} + x_4 w_{41} + b_1,\\ o_2 &= x_1 w_{12} + x_2 w_{22} + x_3 w_{32} + x_4 w_{42} + b_2,\\ o_3 &= x_1 w_{13} + x_2 w_{23} + x_3 w_{33} + x_4 w_{43} + b_3. \end{aligned} \]

softmax回归是一个单层神经网络,其输出层是一个全连接层。

1.2 softmax运算

softmax运算符(softmax operator)通过下式将输出值变换成值为正且和为1的概率分布:

\[\hat{y}_1, \hat{y}_2, \hat{y}_3 = \text{softmax}(o_1, o_2, o_3),\]

其中

\[ \hat{y}_1 = \frac{ \exp(o_1)}{\sum_{i=1}^3 \exp(o_i)},\quad \hat{y}_2 = \frac{ \exp(o_2)}{\sum_{i=1}^3 \exp(o_i)},\quad \hat{y}_3 = \frac{ \exp(o_3)}{\sum_{i=1}^3 \exp(o_i)}. \]

容易看出\(\hat{y}_1 + \hat{y}_2 + \hat{y}_3 = 1\)\(0 \leq \hat{y}_1, \hat{y}_2, \hat{y}_3 \leq 1\),因此\(\hat{y}_1, \hat{y}_2, \hat{y}_3\)是一个合法的概率分布。这时候,如果\(\hat{y}_2=0.8\),不管\(\hat{y}_1\)\(\hat{y}_3\)的值是多少,我们都知道图像类别为猫的概率是80%。

1.3 交叉熵损失函数

真实标签也可以用类别分布表达:对于样本\(i\),我们构造向量\(\boldsymbol{y}^{(i)}\in \mathbb{R}^{q}\) ,使其第\(y^{(i)}\)(样本\(i\)类别的离散数值)个元素为1,其余为0。这样我们的训练目标可以设为使预测概率分布\(\boldsymbol{\hat y}^{(i)}\)尽可能接近真实的标签概率分布\(\boldsymbol{y}^{(i)}\)

对于分类问题,常使用交叉熵衡量真实概率分布和预测概率分布的差异:

\[H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ) = -\sum_{j=1}^q y_j^{(i)} \log \hat y_j^{(i)},\]

其中带下标的\(y_j^{(i)}\)是向量\(\boldsymbol y^{(i)}\)中非0即1的元素,需要注意将它与样本\(i\)类别的离散数值,即不带下标的\(y^{(i)}\)区分。在上式中,我们知道向量\(\boldsymbol y^{(i)}\)中只有第\(y^{(i)}\)个元素\(y^{(i)}_{y^{(i)}}\)为1,其余全为0,于是\(H(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}) = -\log \hat y_{y^{(i)}}^{(i)}\)。也就是说,交叉熵只关心对正确类别的预测概率,因为只要其值足够大,就可以确保分类结果正确。当然,遇到一个样本有多个标签时,例如图像里含有不止一个物体时,我们并不能做这一步简化。但即便对于这种情况,交叉熵同样只关心对图像中出现的物体类别的预测概率。

假设训练数据集的样本数为\(n\),交叉熵损失函数定义为
\[ \ell(\boldsymbol{\Theta}) = \frac{1}{n} \sum_{i=1}^n H\left(\boldsymbol y^{(i)}, \boldsymbol {\hat y}^{(i)}\right ),\]

其中\(\boldsymbol{\Theta}\)代表模型参数。同样地,如果每个样本只有一个标签,那么交叉熵损失可以简写成\(\ell(\boldsymbol{\Theta}) = -(1/n) \sum_{i=1}^n \log \hat y_{y^{(i)}}^{(i)}\)。从另一个角度来看,我们知道最小化\(\ell(\boldsymbol{\Theta})\)等价于最大化\(\exp(-n\ell(\boldsymbol{\Theta}))=\prod_{i=1}^n \hat y_{y^{(i)}}^{(i)}\),即最小化交叉熵损失函数等价于最大化训练数据集所有标签类别的联合预测概率。

1.3 模型预测及评价

在训练好softmax回归模型后,给定任一样本特征,就可以预测每个输出类别的概率。通常,我们把预测概率最大的类别作为输出类别。

2. softmax-regression-scratch

基本步骤:
获取数据-初始化模型参数-定义softmax运算-定义模型-定义损失函数-计算分类准确度-训练模型-预测

获取数据

%matplotlib inline
import d2lzh as d2l
from mxnet import autograd, nd

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

num_inputs = 784  # 每个样本输入是高和宽均为28像素的图像
num_outputs = 10  # 一共有10个类别

W = nd.random.normal(scale=0.01, shape=(num_inputs, num_outputs))
b = nd.zeros(num_outputs)
W.attach_grad()
b.attach_grad()

softmax运算

def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(axis=1, keepdims=True)
    return X_exp / partition  # 这里应用了广播机制

定义模型

def net(X):
    return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)

定义损失函数

def cross_entropy(y_hat, y):
    return -nd.pick(y_hat, y).log()

计算分类准确率

def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        y = y.astype('float32')
        acc_sum += (net(X).argmax(axis=1) == y).sum().asscalar()
        n += y.size
    return acc_sum / n

训练模型

num_epochs, lr = 5, 0.1

def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, trainer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            with autograd.record():
                y_hat = net(X)
                l = loss(y_hat, y).sum()
            l.backward()
            if trainer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                trainer.step(batch_size)  # “softmax回归的简洁实现”一节将用到
            y = y.astype('float32')
            train_l_sum += l.asscalar()
            train_acc_sum += (y_hat.argmax(axis=1) == y).sum().asscalar()
            n += y.size
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size,
          [W, b], lr)

训练输出

epoch 1, loss 0.7893, train acc 0.745, test acc 0.805
epoch 2, loss 0.5742, train acc 0.810, test acc 0.823
epoch 3, loss 0.5288, train acc 0.823, test acc 0.831
epoch 4, loss 0.5046, train acc 0.830, test acc 0.833
epoch 5, loss 0.4894, train acc 0.835, test acc 0.841

注:训练误差总是高于测试误差是因为,注意训练误差train_acc的计算是在每个epoch的每个batch中计算然后求平均。 而测试误差是在每个epoch结束后再计算的,显然,测试误差计算的时候 模型都效果都是好于train_acc每次计算,因为经历完了一个epoch所有样本的更新。

预测

true_labels = d2l.get_fashion_mnist_labels(y.asnumpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

d2l.show_fashion_mnist(X[0:9], titles[0:9])

预测输出
prediction

3. softmax-regression-gluon

下面利用Gluon来更简洁地实现softmax回归模型。

%matplotlib inline
import d2lzh as d2l
from mxnet import gluon, init
from mxnet.gluon import loss as gloss, nn

# 获取和读取数据
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 定义和初始化模型
net = nn.Sequential()
net.add(nn.Dense(10))
net.initialize(init.Normal(sigma=0.01))

# 损失函数
loss = gloss.SoftmaxCrossEntropyLoss()

# 定义优化算法
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

# 训练模型
num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None,
              None, trainer)

输出

epoch 1, loss 0.7897, train acc 0.745, test acc 0.807
epoch 2, loss 0.5731, train acc 0.812, test acc 0.822
epoch 3, loss 0.5295, train acc 0.824, test acc 0.835
epoch 4, loss 0.5053, train acc 0.830, test acc 0.832
epoch 5, loss 0.4887, train acc 0.834, test acc 0.842
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!