WEKA classification likelihood of the classes

轻奢々 2021-01-12 19:58

I would like to know if there is a way in WEKA to output a number of 'best guesses' for a classification.

My scenario is: I classify the data with cross-validation…

4 answers
    小鲜肉 (OP)
    2021-01-12 20:36

    Weka's API has a method called Classifier.distributionForInstance() that returns the predicted probability distribution over the class values. You can then sort that distribution by decreasing probability to get your top-N predictions.
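A minimal sketch of that sorting step in plain Java, assuming you already have the `double[]` returned by `distributionForInstance()`. The class `TopNPredictions`, the method `topN`, and the example labels and probabilities are hypothetical, chosen only for illustration. It ranks class indices by decreasing probability and keeps the first N; in Weka each index would then be mapped to a label with `classAttribute().value(index)`.

```java
import java.util.Arrays;
import java.util.Comparator;

public class TopNPredictions {

    // Returns the indices of the n largest entries of dist, highest probability first.
    // If n exceeds the number of classes, all indices are returned.
    public static int[] topN(double[] dist, int n) {
        Integer[] order = new Integer[dist.length];
        for (int i = 0; i < dist.length; i++) {
            order[i] = i;
        }
        // Sort class indices by decreasing probability. Arrays.sort on objects is
        // stable, so tied probabilities keep the lower class index first.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> dist[i]).reversed());

        int k = Math.min(n, dist.length);
        int[] result = new int[k];
        for (int i = 0; i < k; i++) {
            result[i] = order[i];
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical stand-ins for classAttribute().value(i) and
        // classifier.distributionForInstance(instance).
        String[] labels = {"setosa", "versicolor", "virginica"};
        double[] dist = {0.05, 0.70, 0.25};

        // Print the two best guesses with their probabilities.
        for (int idx : topN(dist, 2)) {
            System.out.printf("%s : %.3f%n", labels[idx], dist[idx]);
        }
    }
}
```

With the example values this prints the two most likely labels, versicolor followed by virginica. Boxing the indices into an `Integer[]` is what allows `Arrays.sort` to take a comparator; for the small number of classes in a typical dataset the overhead is negligible.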

    Below is a function that prints out: (1) the test instance's ground truth label; (2) the predicted label from classifyInstance(); and (3) the prediction distribution from distributionForInstance(). I have used this with J48, but it should work with other classifiers.

    The input parameters are the serialized model file (which you can create during training by adding the -d option) and the test file in ARFF format.

    import weka.classifiers.Classifier;
    import weka.core.Instances;
    import weka.core.SerializationHelper;
    import weka.core.converters.ConverterUtils.DataSource;

    // Wrapper class (name is arbitrary) so the method compiles on its own.
    public class PredictionPrinter
    {
        public void test(String modelFileSerialized, String testFileARFF)
            throws Exception
        {
            // Deserialize the classifier.
            Classifier classifier =
                (Classifier) SerializationHelper.read(modelFileSerialized);

            // Load the test instances.
            Instances testInstances = DataSource.read(testFileARFF);

            // Mark the last attribute in each instance as the true class.
            testInstances.setClassIndex(testInstances.numAttributes() - 1);

            int numTestInstances = testInstances.numInstances();
            System.out.printf("There are %d test instances%n", numTestInstances);

            // Loop over each test instance.
            for (int i = 0; i < numTestInstances; i++)
            {
                // Get the true class label from the instance's own classIndex.
                String trueClassLabel =
                    testInstances.instance(i).toString(testInstances.classIndex());

                // Make the prediction here.
                double predictionIndex =
                    classifier.classifyInstance(testInstances.instance(i));

                // Get the predicted class label from the predictionIndex.
                String predictedClassLabel =
                    testInstances.classAttribute().value((int) predictionIndex);

                // Get the prediction probability distribution.
                double[] predictionDistribution =
                    classifier.distributionForInstance(testInstances.instance(i));

                // Print out the true label, predicted label, and the distribution.
                System.out.printf("%5d: true=%-10s, predicted=%-10s, distribution=",
                                  i, trueClassLabel, predictedClassLabel);

                // Loop over all the prediction labels in the distribution.
                for (int predictionDistributionIndex = 0;
                     predictionDistributionIndex < predictionDistribution.length;
                     predictionDistributionIndex++)
                {
                    // Get this distribution index's class label.
                    String predictionDistributionIndexAsClassLabel =
                        testInstances.classAttribute().value(
                            predictionDistributionIndex);

                    // Get the probability.
                    double predictionProbability =
                        predictionDistribution[predictionDistributionIndex];

                    System.out.printf("[%10s : %6.3f]",
                                      predictionDistributionIndexAsClassLabel,
                                      predictionProbability);
                }

                // End the output line for this instance.
                System.out.println();
            }
        }
    }
    
