CNTK Python API: How to get predictions from the trained model?

Submitted by 元气小坏坏 on 2019-12-08 05:15:49

Question


I have a trained model that I load with the CNTK.load_model() function. I used the MNIST tutorial in the CNTK Git repository as a reference for the model evaluation code. I created a data reader (a MinibatchSource object) and am trying to run model.eval(mb), where mb = minibatch_source.next_minibatch(...) (similar to this answer).

However, I get the following error message:

Traceback (most recent call last):
    File "LID_test.py", line 162, in <module>
        test_and_evaluate()
    File "LID_test.py", line 159, in test_and_evaluate
        predictions = model.eval(mb)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/ops/functions.py", line 228, in eval
        _, output_map = self.forward(arguments, self.outputs, device=device, as_numpy=as_numpy)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/utils/swig_helper.py", line 62, in wrapper
        result = f(*args, **kwds)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/ops/functions.py", line 354, in forward
        None, device)
    File "/home/t-asbahe/anaconda3/envs/cntk-py35/lib/python3.5/site-packages/cntk/utils/__init__.py", line 393, in sanitize_var_map
        if len(arguments) < len(op_arguments):
TypeError: object of type 'Variable' has no len()

I have no input_variable named 'Variable' in my model, and I don't see why this error would occur.

P.S.: My inputs are sparse (one-hot vectors).


Answer 1:


You have a few options:

  • Pass the data as a NumPy array (for instance, as in the CNTK 202 tutorial, where the one-hot data is passed in as a NumPy array):

    pred = model.eval({model.arguments[0]:[onehot]})

  • Read the minibatch data and pass it to the eval function:

    eval_input_map = { input : reader_eval.streams.features }
    eval_data = reader_eval.next_minibatch(eval_minibatch_size, input_map = eval_input_map)
    mydata = eval_data[input].value
    predicted = model.eval(mydata)
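
For the first option, the one-hot array has to be built before it is handed to eval, and the returned scores still have to be turned into class labels. The following is a minimal NumPy-only sketch of those two steps. The names one_hot, top_class, fake_eval and weights are hypothetical and exist only for illustration; fake_eval is a stand-in for the trained model and is not part of CNTK.

```python
import numpy as np


def one_hot(indices, num_classes):
    """Return a dense float32 one-hot matrix of shape (len(indices), num_classes)."""
    indices = np.asarray(indices, dtype=np.int64)
    out = np.zeros((indices.size, num_classes), dtype=np.float32)
    out[np.arange(indices.size), indices] = 1.0
    return out


def top_class(scores):
    """Pick the index of the highest score in each row of a (batch, classes) array."""
    return np.argmax(np.asarray(scores), axis=-1)


# Hypothetical stand-in for the trained model: a fixed linear layer.
# With a real CNTK model this call would be replaced by model.eval(...).
weights = np.array([[0.1, 0.9, 0.0],
                    [0.8, 0.1, 0.1],
                    [0.2, 0.2, 0.6],
                    [0.3, 0.3, 0.4]], dtype=np.float32)


def fake_eval(batch):
    """Map a (batch, 4) one-hot input to (batch, 3) class scores."""
    return batch @ weights


batch = one_hot([0, 2, 1], num_classes=4)   # three samples, vocabulary of four
scores = fake_eval(batch)                   # one row of class scores per sample
predicted = top_class(scores)               # one predicted class index per sample
```

With the real model, the batch would presumably be passed as model.eval({model.arguments[0]: batch}), as in the first option above. The exact array shape the model expects depends on how its input variable was declared, so treat the shapes here as an assumption.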



Source: https://stackoverflow.com/questions/42646842/cntk-python-api-how-to-get-predictions-from-the-trained-model
