Folding in (estimating topics for new documents) in LDA using Mallet in Java


Inference is actually also shown in the example linked in the question (the last few lines).

For anyone interested in the complete code for saving/loading the trained model and then using it to infer topic distributions for new documents - here are some snippets:

After model.estimate() has completed, you have the actual trained model, so you can serialize it using a standard Java ObjectOutputStream (since ParallelTopicModel implements Serializable):

try (ObjectOutputStream oos = new ObjectOutputStream(
        new FileOutputStream("model.ser"))) {
    // write the trained ParallelTopicModel to disk
    oos.writeObject(model);
} catch (IOException ex) {
    // handle this error (FileNotFoundException is a subclass of IOException)
}
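As an aside, unless I'm misremembering the API, ParallelTopicModel also ships with its own write/read helpers, so you can skip the manual streams (the static read method declares a checked Exception):

// assumes ParallelTopicModel.write(File) and the static read(File) helper
// are available in your Mallet version
model.write(new File("model.ser"));
ParallelTopicModel restored = ParallelTopicModel.read(new File("model.ser"));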

Note though, when you infer, you also need to pass the new sentences (as an Instance) through the same pipeline used for training, so they get the same pre-processing (tokenization etc.). That means you also need to save the pipe list (since we're using SerialPipes, we can create an instance of it and serialize that too; a sketch of a typical pipeList follows the snippet below):

// wrap the pipe list that was used in model training
SerialPipes pipes = new SerialPipes(pipeList);

try (ObjectOutputStream oos = new ObjectOutputStream(
        new FileOutputStream("pipes.ser"))) {
    oos.writeObject(pipes);
} catch (IOException ex) {
    // handle error
}
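For completeness, pipeList above is whatever list of pipes you built for training. A minimal sketch, roughly following Mallet's standard topic-model example (the stoplist path is just a placeholder for your own setup):

import java.io.File;
import java.util.ArrayList;
import java.util.regex.Pattern;

import cc.mallet.pipe.*;

// lowercase, tokenize, remove stopwords, map tokens to feature indices
ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
pipeList.add(new CharSequenceLowercase());
pipeList.add(new CharSequence2TokenSequence(
        Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
pipeList.add(new TokenSequenceRemoveStopwords(
        new File("stoplists/en.txt"), "UTF-8", false, false, false));
pipeList.add(new TokenSequence2FeatureSequence());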

In order to load the model/pipeline and use them for inference we need to de-serialize:

private static void inferByModel(String sentence) {
    // define model and pipeline
    ParallelTopicModel model = null;
    SerialPipes pipes = null;

    // load the model
    try (ObjectInputStream ois = new ObjectInputStream(
            new FileInputStream("model.ser"))) {
        model = (ParallelTopicModel) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read model from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the model: " + ex);
    }

    // load the pipeline
    try (ObjectInputStream ois = new ObjectInputStream(
            new FileInputStream("pipes.ser"))) {
        pipes = (SerialPipes) ois.readObject();
    } catch (IOException ex) {
        System.out.println("Could not read pipes from file: " + ex);
    } catch (ClassNotFoundException ex) {
        System.out.println("Could not load the pipes: " + ex);
    }

    // if both are properly loaded
    if (model != null && pipes != null) {

        // Create a new instance named "test instance" with empty target
        // and source fields; note we are using the loaded pipes here
        InstanceList testing = new InstanceList(pipes);   
        testing.addThruPipe(
            new Instance(sentence, null, "test instance", null));

        // get an inferencer from the loaded model and use it
        // (arguments: instance, sampling iterations, thinning, burn-in)
        TopicInferencer inferencer = model.getInferencer();
        double[] testProbabilities = inferencer
                .getSampledDistribution(testing.get(0), 10, 1, 5);
        System.out.println("0\t" + testProbabilities[0]);
    }
}
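If you want to see the whole distribution rather than just topic 0, you can loop over the returned array inside the if block above - it has one entry per topic:

        for (int topic = 0; topic < testProbabilities.length; topic++) {
            System.out.println(topic + "\t" + testProbabilities[topic]);
        }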

For some reason I am not getting exactly the same inference with the loaded model as with the original one - but that's a matter for another question (if anyone knows why, though, I'd be happy to hear).
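One thing that may explain at least part of the difference: getSampledDistribution is a Gibbs-sampling estimate, so it is stochastic by nature. If I remember the API correctly, TopicInferencer lets you fix the random seed, and more iterations/burn-in give a more stable estimate - something along these lines:

TopicInferencer inferencer = model.getInferencer();
inferencer.setRandomSeed(42);  // fixed seed for repeatable sampling (assuming this setter is available)
double[] probs = inferencer.getSampledDistribution(testing.get(0), 100, 10, 10);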

And I've found the answer hidden in a slide-deck from Mallet's lead developer:

TopicInferencer inferencer = model.getInferencer();
// 100 sampling iterations, thinning of 10, burn-in of 10
double[] topicProbs = inferencer.getSampledDistribution(newInstance, 100, 10, 10);