Cannot get predictions of tensorflow DNNClassifier

无人及你 2021-01-05 13:24

I'm using the code from the MNIST tutorial:

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNCla


        
5 Answers
  • 2021-01-05 14:07

    To be as close as possible to the tutorial use:

    print('Predictions: {}'.format(list(ds_predict_tf)))
    
  • 2021-01-05 14:16

    The answer is simple: the value returned by predict is a generator object, so you have to iterate over it:

    g1 = ds_predict_tf
    
    [next(g1) for _ in range(100)]
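
Pulling a fixed number of items with `__next__()` (or `next()`) raises `StopIteration` if the generator yields fewer predictions than requested. A minimal sketch of a more forgiving pattern uses `itertools.islice`. Here `fake_predictions` is a hypothetical stand-in for `ds_predict_tf`, not real TensorFlow output:

```python
from itertools import islice


def fake_predictions(labels):
    """Hypothetical stand-in for the generator returned by classifier.predict."""
    for label in labels:
        yield label


g1 = fake_predictions([1, 0, 2])

# Take at most 100 predictions; islice simply stops when the generator ends.
first_batch = list(islice(g1, 100))
print(first_batch)  # [1, 0, 2]
```

With the real generator, the same `islice` call should work unchanged, since it only relies on the iterator protocol.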
    
  • 2021-01-05 14:25

    The DNNClassifier predict function has as_iterable=True by default, so it returns a generator. To get the prediction values directly instead of a generator, pass as_iterable=False to the classifier.predict method.

    For example,

    classifier.predict(input_fn=_my_predict_data, as_iterable=False)



    To understand the classifier's methods and arguments better, here is the part of the documentation that covers the predict method.

    From DNNClassifier documentation:

    Predict

    Args:

    • x: features.
    • input_fn: Input function. If set, x must be None.
    • batch_size: Override default batch size.
    • outputs: list of str, name of the output to predict. If None, returns classes.
    • as_iterable: If True, return an iterable which keeps yielding predictions for each example until inputs are exhausted. Note: The inputs must terminate if you want the iterable to terminate (e.g. be sure to pass num_epochs=1 if you are using something like read_batch_features).

    Returns:

    • Numpy array of predicted classes with shape [batch_size] (or an iterable of predicted classes if as_iterable is True). Each predicted class is represented by its class index (i.e. integer from 0 to n_classes-1). If outputs is set, returns a dict of predictions.
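
The difference between the two return modes, and the note about inputs that must terminate, can be sketched without TensorFlow. `toy_predict` below is a hypothetical stand-in that only mimics the documented return shapes (it returns a plain list where the real method returns a NumPy array); it is not the library's implementation:

```python
from itertools import cycle, islice


def toy_predict(inputs, as_iterable=True):
    """Hypothetical stand-in mimicking only the documented return shapes."""
    # The "model" here just labels each input by its sign.
    predictions = (int(x > 0) for x in inputs)
    if as_iterable:
        return predictions       # lazy: yields one prediction per example
    return list(predictions)     # eager: all predictions at once


finite_inputs = [-1.0, 2.5, 0.3]
eager = toy_predict(finite_inputs, as_iterable=False)
lazy = toy_predict(finite_inputs)  # as_iterable defaults to True

# Inputs that never terminate give an iterable that never terminates,
# so take a bounded slice instead of calling list() on it.
endless = toy_predict(cycle(finite_inputs))
bounded = list(islice(endless, 5))

print(eager)    # [0, 1, 1]
print(bounded)  # [0, 1, 1, 0, 1]
```

This is why the documentation suggests num_epochs=1 for the input pipeline: with a finite input, `list()` on the returned iterable is safe.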
  • 2021-01-05 14:25

    Solution:

    pred = classifier.fit(x=training_set.data, y=training_set.target, steps=2000).predict(test_set.data)
    
    print("Predictions:")
    
    print(list(pred))
    

    That's it...

  • 2021-01-05 14:28

    What you received and saved to ds_predict_tf is a generator. To print its contents you can do:

    for i in ds_predict_tf:
        print(i)
    

    or

    print(list(ds_predict_tf))
    

    You can read more about generators and generator expressions (genexpr) in the Python documentation.
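
One thing to keep in mind with either approach: a generator can only be consumed once, so the two options above are alternatives, not steps to run one after the other. A small sketch, with `fake_predictions` as a hypothetical stand-in for `ds_predict_tf`:

```python
def fake_predictions():
    """Hypothetical stand-in for the generator stored in ds_predict_tf."""
    yield from [2, 0, 1]


ds_predict = fake_predictions()

first_pass = list(ds_predict)   # consumes every prediction
second_pass = list(ds_predict)  # the generator is already exhausted

print(first_pass)   # [2, 0, 1]
print(second_pass)  # []
```

If the predictions are needed more than once, storing the result of `list(...)` in a variable avoids having to run the prediction again.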
