Mxnet (15): 网络中的网络(NiN)

梦想的初衷 提交于 2020-09-30 06:43:51

1. 网络中的网络(NiN)

LeNet、AlexNet和VGG在设计上的共同之处是就是都是先通过卷积提取特征,然后通过全链接层进行分类。NiN使用了另外一种思路,将简单的卷积和全链接结构协程Block,然后将这些block串联形成网络。

1.1 NiN块

卷积层的输入和输出通常是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。这里通过 1 × 1 1×1 1×1卷积层作为全链接层,每一个 1 × 1 1×1 1×1相当于一个样本,用到相当于特征。

下图通过VGG和NiN的比较展示其结构。

在这里插入图片描述

NiN块包含一个卷积层,然后是两个卷积层 1 × 1 1×1 1×1卷积层充当具有ReLU激活的按像素的全连接层。第一层的卷积窗口形状通常由用户设置。随后的窗口形状固定为 1 × 1 1×1 1×1

from d2l import mxnet as d2l
from mxnet import np, npx, init, gluon, autograd
from mxnet.gluon import nn
import plotly.graph_objs as go
npx.set_np()

ctx = npx.gpu() if npx.num_gpus() > 0 else npx.cpu()

def nin_block(num_channels, kernel_size, strides=1, padding=0):
    blk = nn.Sequential()
    blk.add(
        nn.Conv2D(num_channels, kernel_size, strides, padding, activation='relu'),
        nn.Conv2D(num_channels, kernel_size=1, activation='relu'),
        nn.Conv2D(num_channels, kernel_size=1, activation='relu')
    )
    return blk

1.2 NiN 网络

  • NiN使用具有以下形状的卷积层 11 × 11 11×11 11×11 , 5 × 5 5×5 5×5 3 × 3 3×3 3×3 ,相应的输出通道数与AlexNet中的相同。每个NiN块后面是步幅为2且窗口形状为的最大池化层 3 × 3 3×3 3×3
  • NiN使用NiN块来取代全链接层,其输出通道的数量等于标签类的数量,然后是全局平均池层,从而产生logits向量。
  • NiN设计的优点之一是减少了模型需要的参数数量,但是取而代之的是训练时间上的增加。
NinNet = nn.Sequential()
NinNet.add(
    nin_block(96, kernel_size=11, strides=4),
    nn.MaxPool2D(pool_size=3, strides=2),
    nin_block(256, kernel_size=5, padding=2),
    nn.MaxPool2D(pool_size=3, strides=2),
    nin_block(384, kernel_size=3, padding=1),
    nn.MaxPool2D(pool_size=3, strides=2),
    nn.Dropout(0.5),
    # 因为分类种类为10,这里输出10个channels
    nin_block(10, kernel_size=3, padding=1),
    # 全局平均池化层将窗口形状自动设置成输入的高和宽
    nn.GlobalAvgPool2D(),
    # 将四维转为二维(batch size, 10)
    nn.Flatten()
)

我们创建一个数据示例以查看每个块的输出形状。

X = np.random.uniform(size=(1, 1, 224, 224))
NinNet.initialize()
for layer in NinNet:
    X = layer(X)
    print(layer.name, 'output shape:\t', X.shape)

在这里插入图片描述

1.3 训练

def get_workers(num):
    # windows系统不能使用多线程转换
    return 0 if __import__('sys').platform.startswith('win') else num

def loader(data, batch_size, shuffle=True, workers = 6):
    return gluon.data.DataLoader(data,batch_size, shuffle=shuffle,
                                   num_workers=get_workers(workers))

def load_data(batch_size, resize=None):
    
    dataset = gluon.data.vision
    trans = [dataset.transforms.Resize(resize)] if resize else []
    trans.append(dataset.transforms.ToTensor())
    trans = dataset.transforms.Compose(trans)
    mnist_train = dataset.FashionMNIST(train=True).transform_first(trans)
    mnist_test = dataset.FashionMNIST(train=False).transform_first(trans)
    return loader(mnist_train, batch_size), loader(mnist_test, batch_size, False)    


def accuracy(y_hat, y): 
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.astype(y.dtype) == y
    return float(cmp.sum())

def train_epoch(net, train_iter, loss, updater):
    
    l_sum = acc_rate = total = 0
    
    if isinstance(updater, gluon.Trainer):
        updater = updater.step
        
    for X,y in train_iter:
        X = X.as_in_ctx(ctx)
        y = y.as_in_ctx(ctx)
        with autograd.record():
            pre_y = net(X)
            l = loss(pre_y, y)
        l.backward()
        updater(y.size)
        l_sum += float(l.sum())
        acc_rate += accuracy(pre_y, y)
        total += y.size
    return l_sum/total, acc_rate/total

def evaluate_accuracy(net, data_iter):  

    match_num = total_num = 0
    for X, y in data_iter:
        X = X.as_in_ctx(ctx)
        y = y.as_in_ctx(ctx)
        match_num += accuracy(net(X), y)
        total_num += y.size
    return match_num / total_num

import time
def train(net, train_iter, test_iter, epochs, lr):
    
    net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',  {
   
   'learning_rate': lr})
    l_lst, acc_lst, test_acc_lst = [], [], []
    timer = 0
    print("----------------start------------------")
    for epoch in range(epochs):
        start = time.time()
        l, acc = train_epoch(net, train_iter, loss, trainer)
        timer += time.time()-start
        test_acc = evaluate_accuracy(net, test_iter)
        print(f'[epoch {epoch+1}] loss {l:.3f}, train acc {acc:.3f}, ' f'test acc {test_acc:.3f}')
        l_lst.append(l)
        acc_lst.append(acc)
        test_acc_lst.append(test_acc)
    print(f'loss {l:.3f}, train acc {acc:.3f}, test acc {test_acc:.3f}')
    print(f'{timer:.1f} sec, on {str(ctx)}')
    draw_graph([l_lst, acc_lst, test_acc_lst])
    

def draw_graph(result):
    data = []
    colors = ['aquamarine', 'orange', 'hotpink']
    names = ['train loss', 'train acc', 'test acc']
    symbols = ['circle-open', 'cross-open', 'triangle-up-open']
    for i, info in enumerate(result):
        trace = go.Scatter(
            x = list(range(1, num_epochs+1)),
            y = info,
            mode = 'lines+markers',
            name = names[i],
            marker = {
   
   
                'color':colors[i],
                'symbol':symbols[i],
            },
        )
        data.append(trace)
    fig = go.Figure(data = data)
    fig.update_layout(xaxis_title='epochs', width=800, height=480)
    fig.show()

和以前一样,我们使用Fashion-MNIST训练模型。

lr, num_epochs, batch_size = 0.1, 10, 64
train_iter, test_iter = load_data(batch_size, resize=224)
train(NinNet, train_iter, test_iter, num_epochs, lr)

在这里插入图片描述

在这里插入图片描述

1.4 预测

训练完成的模型通过输入一些数据进行预测,试试效果

import plotly.express as px
from plotly.subplots import make_subplots
def get_fashion_mnist_labels(labels): 
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None): 
    colorscales = px.colors.named_colorscales()
    fig = make_subplots(num_rows, num_cols, subplot_titles=titles)
    for i, img in enumerate(imgs):
        fig.add_trace(go.Heatmap(z=img.asnumpy()[::-1], showscale=False, colorscale=colorscales[i+3]), 1, i+1)
        fig.update_xaxes(visible=False,row=1, col=i+1)
        fig.update_yaxes(visible=False, row=1, col=i+1)
    fig.update_layout(height=280)
    fig.show()

def predict(net, test_iter, stop, n=8):
    for i,(X,y) in enumerate(test_iter):
        if (i==stop) :
            break
    X,y = X.as_in_ctx(ctx), y.as_in_ctx(ctx)
    trues = get_fashion_mnist_labels(y)
    preds = get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [f"true: {t} <br> pre: {p}" for t, p in zip(trues, preds)]
    show_images(X[:n].reshape((n, 224, 224)), 1, n, titles=titles[:n])

import random
stop = random.choice(range(10))
predict(NinNet, test_iter, stop)

在这里插入图片描述

2. 参考

https://d2l.ai/chapter_convolutional-modern/densenet.html

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!