Question
I am using MXNet to train a CNN in R, and I can train the model without any errors using the following code:
model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
                                     )
Because training is time-consuming, I run it overnight on a server, and I want to save the model so that I can use it once training has finished.
I tried:
save(list = ls(), file="mymodel.RData")
and
mx.model.save("mymodel", 10)
But neither of them saves the model! For example, when I load "mymodel.RData", I cannot predict the labels for the test set.
Similarly, when I load "mymodel.RData" and try to plot the network with the following code:
graph.viz(model$symbol$as.json())
I get the following error:
Error in model$symbol$as.json() : external pointer is not valid
Can anybody suggest a way to save this model and then load it again for later use?
Thanks
Answer 1:
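You can save the model by passing mx.callback.save.checkpoint as the epoch.end.callback argument, which writes a checkpoint after every epoch: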
model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     # save model_prefix-symbol.json and model_prefix-NNNN.params after each epoch
                                     epoch.end.callback=mx.callback.save.checkpoint("model_prefix"),
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
                                     )
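The answer above only covers the saving side. A minimal, self-contained sketch of the full round trip follows, assuming the MXNet R package is installed and using a hypothetical toy data set and one-layer network in place of the CNN from the question. mx.callback.save.checkpoint(prefix) writes prefix-symbol.json and prefix-NNNN.params (NNNN is the epoch number) after each epoch, and mx.model.load(prefix, iteration) reads them back into a model that predict() accepts. The same files can also be written by hand with mx.model.save(model, prefix, iteration); note that the model object is its first argument.

```r
library(mxnet)

# Hypothetical toy data: 100 samples x 4 features, binary label
set.seed(42)
X <- matrix(rnorm(400), nrow = 100, ncol = 4)
y <- as.numeric(X[, 1] > 0)

# Hypothetical one-layer network standing in for the CNN
data <- mx.symbol.Variable("data")
fc <- mx.symbol.FullyConnected(data = data, num_hidden = 2)
net <- mx.symbol.SoftmaxOutput(data = fc)

mx.set.seed(0)
model <- mx.model.FeedForward.create(
  symbol = net, X = X, y = y,
  ctx = mx.cpu(),
  num.round = 3,
  array.batch.size = 20,
  array.layout = "rowmajor",   # rows of X are samples
  learning.rate = 0.1,
  eval.metric = mx.metric.accuracy,
  # writes toy-symbol.json and toy-0001.params ... toy-0003.params
  epoch.end.callback = mx.callback.save.checkpoint("toy")
)

# Later, e.g. in a new R session: rebuild the model from the files on disk
restored <- mx.model.load("toy", iteration = 3)

# Predictions come back as a (classes x samples) matrix of probabilities
probs <- predict(restored, X, array.batch.size = 20, array.layout = "rowmajor")
labels <- max.col(t(probs)) - 1
```

In the question's setup, the equivalent call after training would be mx.model.load("model_prefix", iteration = 20).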
Answer 2:
The best practice for saving snapshots of your training progress is to use save_checkpoint (http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint) in a callback after every training epoch. In R, the equivalent command is probably mx.callback.save.checkpoint, but I don't use R and I'm not sure about its usage.
Using these snapshots also lets you take advantage of the low-cost AWS Spot market (https://aws.amazon.com/ec2/spot/pricing/), which, at the time of writing, offers for example an instance with 16 K80 GPUs for $3.8/hour compared to the on-demand price of $14.4/hour (about 74% off). Discounts of 80% to 90% are common in the spot market and can improve both the speed and the cost of your training, as long as you use the snapshots correctly: spot instances can be reclaimed at short notice, so training must be able to resume from the latest snapshot.
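Resuming from a snapshot is what makes an interrupted spot instance cheap rather than costly. A minimal sketch in R, assuming the MXNet R package and using a hypothetical toy data set, network, and helper function train(): the first call trains epochs 1 to 2 and checkpoints each one; the second call reloads epoch 2 and passes its weights through arg.params and aux.params, with begin.round = 3 so that epoch numbering and checkpoint file names continue where they stopped. This assumes num.round is treated as the index of the last epoch rather than a number of additional epochs; check this against your installed package version.

```r
library(mxnet)

# Hypothetical toy data and network (stand-ins for the real training setup)
set.seed(1)
X <- matrix(rnorm(400), nrow = 100, ncol = 4)
y <- as.numeric(X[, 2] > 0)
net <- mx.symbol.SoftmaxOutput(
  data = mx.symbol.FullyConnected(data = mx.symbol.Variable("data"), num_hidden = 2))

# Hypothetical helper: shared settings, only the epoch range and weights vary
train <- function(begin.round, num.round, arg.params = NULL, aux.params = NULL) {
  mx.model.FeedForward.create(
    symbol = net, X = X, y = y, ctx = mx.cpu(),
    begin.round = begin.round, num.round = num.round,
    arg.params = arg.params, aux.params = aux.params,
    array.batch.size = 20, array.layout = "rowmajor",
    learning.rate = 0.1, eval.metric = mx.metric.accuracy,
    # checkpoint after every epoch: snap-symbol.json, snap-NNNN.params
    epoch.end.callback = mx.callback.save.checkpoint("snap"))
}

# First run: epochs 1-2, then imagine the spot instance is terminated
train(begin.round = 1, num.round = 2)

# New instance: reload the last snapshot and continue with epochs 3-4
last <- mx.model.load("snap", iteration = 2)
resumed <- train(begin.round = 3, num.round = 4,
                 arg.params = last$arg.params, aux.params = last$aux.params)
```

The .params file holds only the network weights, so any optimizer state such as momentum starts fresh when training resumes.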
Answer 3:
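An MXNet model is an R list, but its first component, the network symbol, is not an R object: it is a pointer to a C++ object, so it can't be saved and reloaded like ordinary R data. That is why model$symbol$as.json() fails with "external pointer is not valid" after the .RData file is loaded. The model therefore needs to be serialized so that it behaves as an actual R object. The serialized object is also a list, but its first element is text containing the model definition.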
To save a model:
modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")
To retrieve it and use it again:
load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)
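A minimal end-to-end check of this approach, assuming a version of the MXNet R package that provides mx.serialize and mx.unserialize, and using a hypothetical toy model in place of the trained CNN. The serialized model is written with base R's saveRDS (save() as above works the same way), read back, and unserialized, and its predictions are then compared with those of the original model.

```r
library(mxnet)

# Hypothetical toy model standing in for the trained CNN
set.seed(7)
X <- matrix(rnorm(400), nrow = 100, ncol = 4)
y <- as.numeric(rowSums(X) > 0)
net <- mx.symbol.SoftmaxOutput(
  data = mx.symbol.FullyConnected(data = mx.symbol.Variable("data"), num_hidden = 2))
mx.set.seed(0)
model <- mx.model.FeedForward.create(
  symbol = net, X = X, y = y, ctx = mx.cpu(), num.round = 2,
  array.batch.size = 20, array.layout = "rowmajor", learning.rate = 0.1)

# Convert the pointer-based model into plain R data and write it to disk
path <- file.path(tempdir(), "model1.rds")
saveRDS(mx.serialize(model), path)

# Later, possibly in a fresh session: read it back and rebuild the model
restored <- mx.unserialize(readRDS(path))

# The restored model should predict exactly like the original one
p.orig <- predict(model, X, array.batch.size = 20, array.layout = "rowmajor")
p.restored <- predict(restored, X, array.batch.size = 20, array.layout = "rowmajor")
```

After unserializing, restored$symbol is a valid symbol again, so calls such as graph.viz from the question should work on it.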
Source: https://stackoverflow.com/questions/43517960/how-to-save-a-model-when-using-mxnet