How to load and retrain a TFLearn model

Submitted by 主宰稳场 on 2019-12-07 04:59:25

The following builds the network in its own graph, trains it and saves it:

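import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_1d, global_max_pool
from tflearn.layers.merge_ops import merge
from tflearn.layers.estimator import regression

# Build the network in its own graph so it does not collide with other models
graph1 = tf.Graph()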
with graph1.as_default():
    network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
    network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
    branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
    branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
    branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
    network = merge([branch1, branch2, branch3], mode='concat', axis=1)
    network = tf.expand_dims(network, 2)
    network = global_max_pool(network)
    network = dropout(network, 0.5)
    network = fully_connected(network, 2, activation='softmax')
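    network = regression(network, optimizer='adam', learning_rate=0.001, loss='categorical_crossentropy', name='target')
    model = tflearn.DNN(network, tensorboard_verbose=0)
    # classify_DNN is a user-defined helper (not part of TFLearn, not shown here)
    # that fits and evaluates the model and returns it together with its metrics
    clf, acc, roc_auc, fpr, tpr = classify_DNN(data, clas, model)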
    clf.save(model_path)

To reload the saved model for prediction or further training, rebuild the same network in a new graph and load the weights into it:

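import tensorflow as tf
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_1d, global_max_pool
from tflearn.layers.merge_ops import merge
from tflearn.layers.estimator import regression

MODEL = None
# Rebuild the identical network in a new, empty graph before loading the weights
with tf.Graph().as_default():
    ## Building deep neural network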
    network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
    network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
    branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
    branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
    branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
    network = merge([branch1, branch2, branch3], mode='concat', axis=1)
    network = tf.expand_dims(network, 2)
    network = global_max_pool(network)
    network = dropout(network, 0.5)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=0.001, loss='categorical_crossentropy', name='target')
    new_model = tflearn.DNN(network, tensorboard_verbose=3)
    # Restore the saved weights into the freshly built network
    new_model.load(model_path)
    MODEL = new_model

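Use MODEL for prediction or retraining. What made it work for me were the first line (MODEL = None) and the with tf.Graph().as_default(): block, which rebuilds exactly the same network in a fresh graph so that new_model.load() can restore the saved weights into it. Posting this for anyone who might need help.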
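Once loaded, MODEL behaves like any other tflearn.DNN: predict() returns the softmax probabilities and fit() continues training from the restored weights. A minimal sketch, assuming X is already padded to MAX_DOCUMENT_LENGTH word ids and the labels are 0/1 integers; the helper names predict_labels and retrain and the new_path argument are hypothetical.

```python
import numpy as np


def predict_labels(model, X):
    """Return the predicted class index (0 or 1) for each row of X.

    model is a loaded tflearn.DNN; predict() gives one row of softmax
    probabilities per input row, so the arg-max is the predicted class.
    """
    probs = np.asarray(model.predict(X))
    return probs.argmax(axis=1)


def retrain(model, X, labels, new_path, n_epoch=5, batch_size=32):
    """Continue training a loaded tflearn.DNN on new data, then save it.

    The network ends in a 2-unit softmax trained with categorical
    cross-entropy, so integer labels are one-hot encoded to shape (n, 2).
    """
    Y = np.eye(2)[np.asarray(labels, dtype=int)]
    # validation_set=0.1 holds out 10% of the new data for validation
    model.fit(X, Y, n_epoch=n_epoch, batch_size=batch_size,
              validation_set=0.1, show_metric=True)
    model.save(new_path)
    return model
```

For example, predict_labels(MODEL, X_test) returns an array of 0/1 labels, and retrain(MODEL, X_new, y_new, model_path) fine-tunes the model and overwrites the saved checkpoint; passing a different path keeps the original weights intact.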
