I am trying to train a classifier on the MNIST dataset using the dbn.dnn.train function from the deepnet package. I am using the following command:
dbn.deepnet <- dbn.dnn.train(train.image.data,train.image.labels,hidden=c(5,5))
The problems I am facing are:
1) For classification the labels should be a factor vector, but when I pass them as a factor the function gives the error "y should be a matrix or vector", so I am passing them as numeric instead. How should I proceed for a classification task?
2) Which function makes predictions for a model trained with dbn.dnn.train? I am using nn.predict, but its documentation says the input should be a neural network trained by nn.train (dbn.dnn.train is not mentioned). The output is 0.9986 for every record:
nn.predict(dbn.deepnet,train.image.data)
I don't know if you are still working on it or have already found the solution, but:
1/ Try this:
train.image.labels <- data.matrix(train.image.labels)
2/ I use nn.predict, even when the neural network is trained by dbn.dnn.train.
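A likely explanation for the constant 0.9986 in the question: as far as I can tell, when y is a plain numeric vector, dbn.dnn.train builds a network with a single output unit, and with the default sigmoid output that unit can only produce values between 0 and 1, so targets of 0 to 9 push it towards 1 for every record (a factor is neither a plain vector nor a matrix, which would explain the "y should be a matrix or vector" error). For a 10-class problem the usual approach is a one-hot label matrix with one column per digit, which is what the mnist$train$yy matrix in the MNIST example appears to be. Below is a minimal sketch; to_one_hot is a hypothetical helper, not part of deepnet, and the digit classes 0 to 9 are an assumption about the labels.

```r
# Hypothetical helper (not part of deepnet): turn a vector of digit labels
# (factor or numeric) into a one-hot matrix with one column per class.
to_one_hot <- function(labels, classes = 0:9) {
  # as.character() keeps the digit values even when labels is a factor
  idx <- match(as.character(labels), as.character(classes))
  if (anyNA(idx)) stop("labels contain values not listed in 'classes'")
  y <- matrix(0, nrow = length(labels), ncol = length(classes))
  # set a single 1 per row, in the column of that row's class
  y[cbind(seq_along(labels), idx)] <- 1
  colnames(y) <- classes
  y
}

labels <- factor(c(7, 2, 1, 0))
y <- to_one_hot(labels)
print(y)
```

With such a matrix the call from the question would become dbn.dnn.train(train.image.data, to_one_hot(train.image.labels), hidden = c(5, 5)), and nn.predict should then return one score per class (10 columns) for each record instead of a single value.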
As you know, inputs to a neural network are best scaled to lie between 0 and 1. In the deepnet package, unlike with the nn.train function, with dbn.dnn.train you need to normalize the input yourself. Here is the complete code to load, train, and test.
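# loading deepnet and MNIST
library(deepnet)
setwd("path/to/MNIST/")
# load.mnist() is not part of deepnet; it is assumed to be a helper you define
# yourself that returns mnist$train / mnist$test lists holding x (one image per
# row, pixel values 0-255), y (digit labels) and yy (one-hot label matrix)
mnist <- load.mnist(".")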
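# the function to normalize the input values from 0-255 into [0, 1]
normalize <- function(x) {
  return(x / 255)
}
# normalization (min-max scaling, not standardization); x / 255 is vectorized,
# so it can be applied to the whole matrix at once instead of cell by cell with apply()
train_x_n <- normalize(mnist$train$x)
test_x_n <- normalize(mnist$test$x)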
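# training and prediction
dnn <- dbn.dnn.train(train_x_n, mnist$train$yy, hidden = c(100, 70, 80), numepochs = 3, cd = 3)
# nn.test returns the error rate on the test set
err.dnn <- nn.test(dnn, test_x_n, mnist$test$yy)
# nn.predict returns one row per image with one output score per class (10 columns)
dnn_predict <- nn.predict(dnn, test_x_n)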
# test the outputs
print(err.dnn)
print(dnn_predict[1,])
print(mnist$test$y[1])
Output:
> err.dnn
[1] 0.0829
> dnn_predict[1,]
[1] 7.549055e-04 1.111647e-03 1.946491e-03 7.417489e-03 3.221340e-04 7.306264e-04 4.088365e-05 9.944441e-01 8.953903e-05
[10] 9.085863e-03
> mnist$test$y[1]
[1] 7
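The printed row can be decoded by hand. Each column of dnn_predict holds the score for one class; assuming the columns follow the digit order 0 to 9 (as the one-hot mnist$test$yy labels presumably do, and as the output above is consistent with), the predicted digit is the index of the largest score minus one. The scores below are copied from the output above.

```r
# The first row of dnn_predict, as printed above
p <- c(7.549055e-04, 1.111647e-03, 1.946491e-03, 7.417489e-03, 3.221340e-04,
       7.306264e-04, 4.088365e-05, 9.944441e-01, 8.953903e-05, 9.085863e-03)

# Column k is assumed to hold the score for digit k - 1,
# so subtract 1 from the index of the largest score
predicted_digit <- which.max(p) - 1
print(predicted_digit)  # 7, matching mnist$test$y[1]

# The error rate reported by nn.test, expressed as accuracy
err <- 0.0829
print(1 - err)  # 0.9171, i.e. about 91.7% of test images classified correctly
```

For the whole test set, max.col(dnn_predict, ties.method = "first") - 1 gives one predicted digit per row, and comparing it with mnist$test$y gives an error rate that should be close to the one reported by nn.test.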
Source: https://stackoverflow.com/questions/28623533/r-package-deepnet-training-and-testing-the-mnist-dataset