The study material follows Mu Li's gluon lecture notes.
The differences between building a network with raw mxnet (ndarray) and with gluon show up in four aspects. Below they are compared one by one, using a simple dropout model as the example.
1 Building dropout
ndarray:
from mxnet import nd

def dropout(X, drop_prob):
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    if keep_prob == 0:  # in this case every element is dropped
        return X.zeros_like()
    # keep each element with probability keep_prob, then rescale by
    # 1 / keep_prob so the expected value of the output equals the input
    mask = nd.random.uniform(0, 1, X.shape) < keep_prob
    return mask * X / keep_prob
gluon:
from mxnet.gluon import nn
drop_prob = 0.2
nn.Dropout(drop_prob)
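As a quick sanity check (my own sketch, not from the original notes), both versions can be applied to the same small NDArray; note that the gluon layer only drops elements in training mode, e.g. inside autograd.record():
from mxnet import autograd, nd
from mxnet.gluon import nn

X = nd.arange(16).reshape((2, 8))
print(dropout(X, 0.5))  # hand-written version: roughly half the entries zeroed, the rest scaled by 2

layer = nn.Dropout(0.5)
layer.initialize()
print(layer(X))         # outside record(): inference mode, X is returned unchanged
with autograd.record():
    print(layer(X))     # training mode: dropout is active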
2 Building the network
ndarray:
The input dimension must be given explicitly (i.e. num_inputs), and the layer dimensions live in the parameters (see section 3); the net function itself just spells out the linear computations plus the activation functions:
from mxnet import autograd, nd

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
drop_prob1, drop_prob2 = 0.2, 0.5

def net(X):
    X = X.reshape((-1, num_inputs))
    H1 = (nd.dot(X, W1) + b1).relu()
    if autograd.is_training():  # apply dropout only while training
        H1 = dropout(H1, drop_prob1)  # dropout after the first fully connected layer
    H2 = (nd.dot(H1, W2) + b2).relu()
    if autograd.is_training():
        H2 = dropout(H2, drop_prob2)  # dropout after the second fully connected layer
    return nd.dot(H2, W3) + b3
gluon:
from mxnet.gluon import nn

net = nn.Sequential()
net.add(nn.Dense(256, activation="relu"),
        nn.Dropout(drop_prob1),  # dropout after the first fully connected layer
        nn.Dense(256, activation="relu"),
        nn.Dropout(drop_prob2),  # dropout after the second fully connected layer
        nn.Dense(10))
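Worth noting (my addition): autograd.is_training() is False by default and only flips to True inside autograd.record() (whose train_mode argument defaults to True), which is exactly why the ndarray net above applies dropout only during training:
from mxnet import autograd

print(autograd.is_training())      # False: prediction mode by default
with autograd.record():            # train_mode=True by default
    print(autograd.is_training())  # True: the dropout branches are taken
with autograd.record(train_mode=False):
    print(autograd.is_training())  # False: gradients recorded, but in test mode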
3 Building the parameters
ndarray:
Every layer's W and b must be initialized by hand, and attach_grad() must be called on each so that gradients can be stored:
from mxnet import nd

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens1))
b1 = nd.zeros(num_hiddens1)
W2 = nd.random.normal(scale=0.01, shape=(num_hiddens1, num_hiddens2))
b2 = nd.zeros(num_hiddens2)
W3 = nd.random.normal(scale=0.01, shape=(num_hiddens2, num_outputs))
b3 = nd.zeros(num_outputs)

params = [W1, b1, W2, b2, W3, b3]
for param in params:
    param.attach_grad()  # allocate gradient buffers for autograd
gluon:
from mxnet import init

net.initialize(init.Normal(sigma=0.01))
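This is also why gluon never asked for num_inputs in section 2: parameter initialization is deferred. net.initialize() only registers the initializer, and the actual shapes are inferred from the first batch that flows through the network. A minimal sketch of this behaviour, reusing the net built above:
from mxnet import nd

X = nd.random.uniform(shape=(2, 784))
net(X)                             # first forward pass triggers shape inference
print(net[0].weight.data().shape)  # (256, 784): the input dim 784 was inferred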
4 Updating the parameters
ndarray:
The update has to be written by hand, e.g. sgd(params, lr, batch_size):
def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent."""
    for param in params:
        # in-place update; dividing by batch_size averages the summed batch gradient
        param[:] = param - lr * param.grad / batch_size
gluon:
from mxnet import gluon

lr = 0.1
batch_size = 256
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
# then, inside the data-feeding loop:
trainer.step(batch_size)
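Putting the four gluon pieces together, a minimal training loop might look like the sketch below (my own assembly; train_iter is a hypothetical gluon.data.DataLoader yielding (X, y) batches):
from mxnet import autograd, gluon

loss = gluon.loss.SoftmaxCrossEntropyLoss()
num_epochs = 5
for epoch in range(num_epochs):
    for X, y in train_iter:       # train_iter: hypothetical DataLoader
        with autograd.record():   # training mode: the Dropout layers are active
            l = loss(net(X), y)
        l.backward()              # gradients accumulate in net.collect_params()
        trainer.step(batch_size)  # sgd update, averaged over the batch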
From: mxnet和gluon学习笔记二——mxnet,gluon构建网络的区别 (MXNet and Gluon study notes II: how building a network differs between mxnet and gluon)
Source: CSDN
Author: tony2278
Link: https://blog.csdn.net/tony2278/article/details/104768263