TensorFlow Invalid Argument: Assertion Failed [Label IDs must < n_classes]

渐次进展 2021-01-12 06:04

I get an error when implementing a DNNClassifier in TensorFlow 1.3.0 with Python 2.7. I got the sample code from the TensorFlow tf.estimator Quickstart Tutorial and

1 Answer
  • 2021-01-12 06:28

    So the solution, as Ishant Mrinal pointed out:

    TensorFlow expects the class labels to be integers from 0 to n_classes - 1 (i.e. range(0, num_classes)), not "arbitrary" numbers as in my case. Thanks! :)
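    One way to apply this fix is to remap the arbitrary labels to contiguous integer ids before they reach the estimator. The sketch below is plain Python and independent of TensorFlow; encode_labels and check_label_ids are hypothetical helper names, and the range check only mirrors the condition in the error message so the problem shows up before training starts. The value len(classes) would then be passed as n_classes, and classes[i] maps a predicted id i back to the original label.

```python
def encode_labels(labels):
    """Map arbitrary (hashable, sortable) labels to contiguous ids 0..n_classes-1.

    Returns the encoded ids and the sorted list of distinct classes,
    so that classes[i] is the original label for id i.
    """
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return [index[label] for label in labels], classes


def check_label_ids(label_ids, n_classes):
    """Raise early if any id falls outside range(0, n_classes)."""
    bad = [i for i in label_ids if not 0 <= i < n_classes]
    if bad:
        raise ValueError("Label IDs must be in range(0, %d), got %r" % (n_classes, bad))


# Hypothetical raw labels that are not 0..n_classes-1 (e.g. class codes 7, 10, 42).
raw = [10, 42, 10, 7]
ids, classes = encode_labels(raw)
check_label_ids(ids, n_classes=len(classes))
# ids == [1, 2, 1, 0], classes == [7, 10, 42]
```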

    Another option I just came across is to pass a label_vocabulary to the classifier definition:

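    import tensorflow as tf

    # label_vocabulary: list of the distinct label values, as strings.
    # With it, the labels fed in may be those strings instead of 0..n_classes-1.
    classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=numClasses,
                                            model_dir=saveAt,
                                            label_vocabulary=uniqueTrain)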
    

    With this option I can keep my original labels; they just have to be converted to strings, and every label fed to the estimator must be one of the values in label_vocabulary.
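    A minimal sketch of preparing the values for the label_vocabulary variant, assuming the raw labels are plain Python integers. The names train_labels and train_label_strings are hypothetical; uniqueTrain and numClasses are the names the snippet above uses. Each string's position in uniqueTrain is the class id the estimator uses internally, so the order only affects that internal mapping, not correctness.

```python
# Hypothetical raw training labels with arbitrary class codes.
train_labels = [10, 42, 10, 7]

# label_vocabulary expects strings, so convert every label first;
# the labels fed to the estimator must be these same strings.
train_label_strings = [str(y) for y in train_labels]

# Distinct label values, to be passed as label_vocabulary.
uniqueTrain = sorted(set(train_label_strings))
numClasses = len(uniqueTrain)

# uniqueTrain == ['10', '42', '7'] (lexicographic order), numClasses == 3
```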
