Question
I have trained a network using MXNet, but I am not sure how to save and load the parameters for later use. First, I define and train the network:
import mxnet as mx
dataIn = mx.sym.var('data')
fc1 = mx.symbol.FullyConnected(data=dataIn, num_hidden=100)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=50)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=25)
act3 = mx.sym.Activation(data=fc3, act_type="relu")
fc4 = mx.symbol.FullyConnected(data=act3, num_hidden=10)
act4 = mx.sym.Activation(data=fc4, act_type="relu")
fc5 = mx.symbol.FullyConnected(data=act4, num_hidden=2)
lenet = mx.sym.SoftmaxOutput(data=fc5, name='softmax', normalization='batch')
# create iterators around the training and validation data
train_iter = mx.io.NDArrayIter(data=data[:ntrain], label=phen[:ntrain], batch_size=batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data=data[ntrain:], label=phen[ntrain:], batch_size=batch_size)
# create a trainable module on GPU 0
lenet_model = mx.mod.Module(symbol=lenet, context=mx.gpu())
# train on train_iter, evaluating on val_iter after each epoch
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='adam',
                optimizer_params={'learning_rate': 0.00001},
                eval_metric='f1',
                batch_end_callback=mx.callback.Speedometer(batch_size, 10),
                num_epoch=1000)
This model performs well on the test set, so I want to keep it. Next, I save the network definition (the symbol) and the trained parameters:
lenet.save('./testNet_symbol.mxnet')
lenet_model.save_params('./testNet_module.mxnet')
All the documentation I can find on loading a network seems to call the save function inside the training routine, saving the parameters at the end of each epoch. I didn't set up such checkpoints during training. Other methods use the mx.model.FeedForward class, which doesn't seem appropriate here. Still others load the network from a .json file, which I don't have as a result of my save calls. How can I save and load a network after it has already finished training?
Answer 1:
You just have to save a checkpoint instead:
num_epoch = 1000  # the value passed as num_epoch to fit()
lenet_model.save_checkpoint('lenet', num_epoch, save_optimizer_states=True)
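This creates three files if save_optimizer_states is True, otherwise only the first two:
lenet-symbol.json (the symbol), lenet-NNNN.params (the weights) and lenet-NNNN.states (the optimizer state), where NNNN is the epoch number zero-padded to four digits.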
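Module.load(prefix, epoch) looks for files named the same way, so the prefix and epoch passed to it must match the save_checkpoint call. The naming pattern, as I understand it from MXNet 1.x, can be sketched with a hypothetical pure-Python helper (checkpoint_files is not part of MXNet):

```python
def checkpoint_files(prefix, epoch, save_optimizer_states=False):
    """Hypothetical helper: names of the files Module.save_checkpoint writes."""
    files = ['%s-symbol.json' % prefix,           # network graph (JSON)
             '%s-%04d.params' % (prefix, epoch)]  # weights, epoch zero-padded
    if save_optimizer_states:
        files.append('%s-%04d.states' % (prefix, epoch))  # optimizer state
    return files


print(checkpoint_files('lenet', 1000, save_optimizer_states=True))
# ['lenet-symbol.json', 'lenet-1000.params', 'lenet-1000.states']
```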
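And this to load:
prefix, epoch = 'lenet', 1000  # must match the arguments given to save_checkpoint
lenet_model = mx.mod.Module.load(prefix, epoch)
# bind for inference; the data shape is (batch size,) + the shape of one input sample
lenet_model.bind(for_training=False, data_shapes=[('data', (1,) + data.shape[1:])])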
Source: https://stackoverflow.com/questions/47190614/how-to-load-a-trained-mxnet-model