/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;

public class LDAHyper
implements Serializable {
    // Documents added via addInstances, each wrapped together with its topic assignments.
    protected ArrayList<Topication> data = new ArrayList();
    // Data alphabet of the training instances; fixed on the first addInstances call.
    protected Alphabet alphabet;
    // Topic labels; the convenience constructors populate it with "topic0", "topic1", ...
    protected LabelAlphabet topicAlphabet;
    // Size of topicAlphabet.
    protected int numTopics;
    // Size of the data alphabet as of the last initializeForTypes call.
    protected int numTypes;
    // Per-topic alpha values; initialised uniformly to alphaSum / numTopics.
    protected double[] alpha;
    protected double alphaSum;
    // Beta value applied to every (type, topic) pair; betaSum is beta * numTypes.
    protected double beta;
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    // Sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum); kept up to date by the sampler.
    protected double smoothingOnlyMass = 0.0;
    // Between documents holds alpha[t] / (tokensPerTopic[t] + betaSum) for each topic.
    protected double[] cachedCoefficients;
    // Reset to zero at the start of every sweep in estimate(); not otherwise updated in the code visible here.
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
    // Optional held-out instances; when non-null, estimate() prints extra diagnostics.
    protected InstanceList testing = null;
    // Scratch per-document topic counts used by oldSampleTopicsForOneDoc.
    protected int[] oneDocTopicCounts;
    // Indexed by type; each map goes from topic index to token count.
    protected TIntIntHashMap[] typeTopicCounts;
    // Number of tokens currently assigned to each topic.
    protected int[] tokensPerTopic;
    // Histogram of document lengths, filled when a sample is saved for alpha optimisation.
    protected int[] docLengthCounts;
    // topicDocCounts[t][c] counts saved documents in which topic t had exactly c tokens.
    protected int[][] topicDocCounts;
    // Sampling schedule; an interval of 0 disables the corresponding periodic action in estimate().
    public int iterationsSoFar = 0;
    public int numIterations = 1000;
    public int burninPeriod = 20;
    public int saveSampleInterval = 5;
    public int optimizeInterval = 20;
    public int showTopicsInterval = 10;
    public int wordsPerTopic = 7;
    // Set by setModelOutput; not read in the portion of the file visible here.
    protected int outputModelInterval = 0;
    protected String outputModelFilename;
    // Set by setSaveState; estimate() writes the state to stateFilename + '.' + iteration.
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    protected Randoms random;
    // Number formatter limited to 5 fraction digits; used when printing alpha values.
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;
    private static final long serialVersionUID = 1L;
    // NOTE(review): presumably used by custom serialization code outside this view - confirm.
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model whose alpha sum equals the number of topics and whose beta
     * is {@link #DEFAULT_BETA}.
     *
     * @param numberOfTopics number of topics to create
     */
    public LDAHyper(int numberOfTopics) {
        this(numberOfTopics, (double)numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model with the given hyperparameters and a freshly constructed
     * {@code Randoms} instance.
     *
     * @param numberOfTopics number of topics to create
     * @param alphaSum       total alpha mass, spread evenly over the topics
     * @param beta           beta value used for every (type, topic) pair
     */
    public LDAHyper(int numberOfTopics, double alphaSum, double beta) {
        this(numberOfTopics, alphaSum, beta, new Randoms());
    }

    /**
     * Builds a label alphabet containing the entries "topic0", "topic1", ... in order.
     *
     * @param numTopics number of labels to register
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet topicLabels = new LabelAlphabet();
        int topic = 0;
        while (topic < numTopics) {
            topicLabels.lookupIndex("topic" + topic);
            topic++;
        }
        return topicLabels;
    }

    /**
     * Creates a model with generated topic labels ("topic0", "topic1", ...) and the
     * supplied random source.
     *
     * @param numberOfTopics number of topics to create
     * @param alphaSum       total alpha mass, spread evenly over the topics
     * @param beta           beta value used for every (type, topic) pair
     * @param random         random number source used by the sampler
     */
    public LDAHyper(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, beta, random);
    }

    public LDAHyper(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("LDA: " + this.numTopics + " topics");
    }

    /** @return the data alphabet, or {@code null} if no instances have been added yet */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** @return the alphabet holding the topic labels */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the number of topics in this model */
    public int getNumTopics() {
        return numTopics;
    }

    /** @return the live (not copied) list of documents with their topic assignments */
    public ArrayList<Topication> getData() {
        return data;
    }

    /**
     * Looks up the count stored for one feature in one topic.
     *
     * @param featureIndex index of the feature (type) in the data alphabet
     * @param topicIndex   index of the topic
     * @return the value stored for that topic in the feature's count map
     */
    public int getCountFeatureTopic(int featureIndex, int topicIndex) {
        TIntIntHashMap countsForFeature = typeTopicCounts[featureIndex];
        return countsForFeature.get(topicIndex);
    }

    /**
     * @param topicIndex index of the topic
     * @return the number of tokens currently assigned to that topic
     */
    public int getCountTokensPerTopic(int topicIndex) {
        return tokensPerTopic[topicIndex];
    }

    /**
     * Registers held-out instances. When non-null, {@code estimate} prints additional
     * diagnostics each time it shows the topics.
     *
     * @param testing held-out instance list, or {@code null} for none
     */
    public void setTestingInstances(InstanceList testing) {
        this.testing = testing;
    }

    /**
     * Sets the iteration count used by the no-argument {@code estimate()}.
     *
     * @param numIterations number of iterations to request
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Sets the burn-in length; {@code estimate} neither optimises alpha nor saves
     * samples for it before this many iterations have passed.
     *
     * @param burninPeriod number of initial iterations treated as burn-in
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures the periodic topic printout performed by {@code estimate}.
     *
     * @param interval print every {@code interval} iterations; 0 disables the printout
     * @param n        word limit passed to {@code printTopWords}
     */
    public void setTopicDisplay(int interval, int n) {
        showTopicsInterval = interval;
        wordsPerTopic = n;
    }

    /**
     * Replaces the random source with a new one constructed from the given seed.
     *
     * @param seed seed for the new {@code Randoms} instance
     */
    public void setRandomSeed(int seed) {
        random = new Randoms(seed);
    }

    /**
     * Sets how often {@code estimate} re-learns the alpha parameters.
     *
     * @param interval optimise every {@code interval} iterations; 0 disables optimisation
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
    }

    /**
     * Records the model-output interval and file name.
     *
     * @param interval value stored in {@code outputModelInterval}
     * @param filename value stored in {@code outputModelFilename}
     */
    public void setModelOutput(int interval, String filename) {
        outputModelInterval = interval;
        outputModelFilename = filename;
    }

    /**
     * Configures periodic state dumps written by {@code estimate}.
     *
     * @param interval dump every {@code interval} iterations; 0 disables dumps
     * @param filename base file name; the iteration number is appended after a '.'
     */
    public void setSaveState(int interval, String filename) {
        saveStateInterval = interval;
        stateFilename = filename;
    }

    /**
     * Returns the size of an instance's data, which must be a {@code FeatureSequence}.
     *
     * @param instance instance whose data is a feature sequence
     * @return the size reported by that sequence
     */
    protected int instanceLength(Instance instance) {
        FeatureSequence tokens = (FeatureSequence)instance.getData();
        return tokens.size();
    }

    /**
     * Prepares the per-type count tables for the given data alphabet.
     *
     * <p>On the first call the alphabet is adopted and one empty count map is created
     * per type. On later calls the same alphabet object must be passed; if it has
     * grown since the last call, the count table is extended with empty maps for the
     * new types and {@code betaSum} is recomputed.
     *
     * @param alphabet data alphabet of the instances about to be added
     * @throws IllegalArgumentException if a different alphabet object is supplied
     */
    private void initializeForTypes(Alphabet alphabet) {
        if (this.alphabet == null) {
            this.alphabet = alphabet;
            this.numTypes = alphabet.size();
            this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
            for (int fi = 0; fi < this.numTypes; ++fi) {
                this.typeTopicCounts[fi] = new TIntIntHashMap();
            }
            this.betaSum = this.beta * (double)this.numTypes;
            return;
        }
        if (alphabet != this.alphabet) {
            throw new IllegalArgumentException("Cannot change Alphabet.");
        }
        if (alphabet.size() != this.numTypes) {
            this.numTypes = alphabet.size();
            TIntIntHashMap[] grownTypeTopicCounts = new TIntIntHashMap[this.numTypes];
            System.arraycopy(this.typeTopicCounts, 0, grownTypeTopicCounts, 0, this.typeTopicCounts.length);
            for (int i = this.typeTopicCounts.length; i < this.numTypes; ++i) {
                grownTypeTopicCounts[i] = new TIntIntHashMap();
            }
            // Install the enlarged table. Without this assignment the old, shorter
            // array stays in place and any token of a newly added type indexes past
            // its end when counts are updated.
            this.typeTopicCounts = grownTypeTopicCounts;
            this.betaSum = this.beta * (double)this.numTypes;
        }
    }

    /**
     * Resizes {@code typeTopicCounts} to {@code numTypes} entries, keeping the existing
     * maps and creating empty ones for the additional slots.
     */
    private void initializeTypeTopicCounts() {
        int oldLength = typeTopicCounts.length;
        TIntIntHashMap[] resized = new TIntIntHashMap[numTypes];
        System.arraycopy(typeTopicCounts, 0, resized, 0, oldLength);
        int slot = oldLength;
        while (slot < numTypes) {
            resized[slot] = new TIntIntHashMap();
            slot++;
        }
        typeTopicCounts = resized;
    }

    /**
     * Adds training instances, assigning every token a topic drawn uniformly at random.
     *
     * <p>The draws come from this model's own {@code random} source, so a seed supplied
     * through the constructor or {@code setRandomSeed} also governs the initial
     * assignment. Previously a fresh, unseeded generator was created per document,
     * which made initialisation ignore the configured seed.
     *
     * @param training instances whose data are feature sequences over the model's alphabet
     */
    public void addInstances(InstanceList training) {
        this.initializeForTypes(training.getDataAlphabet());
        ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
        for (Instance instance : training) {
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[this.instanceLength(instance)]);
            // getFeatures() is written through here, as in the sampler, to fill in the topic of each token.
            int[] topics = topicSequence.getFeatures();
            for (int i = 0; i < topics.length; ++i) {
                topics[i] = this.random.nextInt(this.numTopics);
            }
            topicSequences.add(topicSequence);
        }
        this.addInstances(training, topicSequences);
    }

    /**
     * Adds training instances together with pre-computed topic assignments and folds
     * those assignments into the type/topic and per-topic counts.
     *
     * @param training instances whose data are feature sequences over the model's alphabet
     * @param topics   one topic sequence per instance, in the same order
     */
    public void addInstances(InstanceList training, List<LabelSequence> topics) {
        initializeForTypes(training.getDataAlphabet());
        assert (training.size() == topics.size());
        int docIndex = 0;
        while (docIndex < training.size()) {
            Topication topication = new Topication((Instance)training.get(docIndex), this, topics.get(docIndex));
            data.add(topication);
            FeatureSequence tokens = (FeatureSequence)topication.instance.getData();
            LabelSequence assignments = topication.topicSequence;
            // Accumulate the sufficient statistics contributed by this document.
            for (int position = 0; position < assignments.getLength(); position++) {
                int topic = assignments.getIndexAtPosition(position);
                int type = tokens.getIndexAtPosition(position);
                typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
                tokensPerTopic[topic]++;
            }
            docIndex++;
        }
        initializeHistogramsAndCachedValues();
    }

    /**
     * Recomputes {@code smoothingOnlyMass} and {@code cachedCoefficients} from the
     * current counts and allocates the histograms used for alpha optimisation, sized
     * by the longest document. Also reports the longest document length and the total
     * token count on standard error.
     */
    protected void initializeHistogramsAndCachedValues() {
        int longestDoc = 0;
        int tokenTotal = 0;
        for (Topication document : data) {
            int length = ((FeatureSequence)document.instance.getData()).getLength();
            longestDoc = Math.max(longestDoc, length);
            tokenTotal += length;
        }

        smoothingOnlyMass = 0.0;
        cachedCoefficients = new double[numTopics];
        for (int t = 0; t < numTopics; t++) {
            double denominator = (double)tokensPerTopic[t] + betaSum;
            smoothingOnlyMass += alpha[t] * beta / denominator;
            cachedCoefficients[t] = alpha[t] / denominator;
        }

        System.err.println("max tokens: " + longestDoc);
        System.err.println("total tokens: " + tokenTotal);

        int histogramSize = longestDoc + 1;
        docLengthCounts = new int[histogramSize];
        topicDocCounts = new int[numTopics][histogramSize];
    }

    /**
     * Runs sampling using the configured {@code numIterations}.
     *
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate() throws IOException {
        estimate(numIterations);
    }

    /**
     * Runs sampling sweeps over all documents, with periodic reporting, state dumps
     * and alpha optimisation, then prints the total elapsed time.
     *
     * <p>The loop bound is inclusive, so a call performs {@code iterationsThisRound + 1}
     * sweeps and leaves {@code iterationsSoFar} one past the bound. The periodic
     * actions run at the top of the loop body, before the sweep for that iteration.
     *
     * @param iterationsThisRound number of iterations to add to {@code iterationsSoFar}
     * @throws IOException if a periodic state dump cannot be written
     */
    public void estimate(int iterationsThisRound) throws IOException {
        long startTime = System.currentTimeMillis();
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        // NOTE(review): "<=" makes this inclusive; it also lets the report/save for the
        // final iteration number fire. Confirm before tightening to "<".
        while (this.iterationsSoFar <= maxIteration) {
            long iterationStart = System.currentTimeMillis();
            // Periodic topic printout (skipped at iteration 0).
            if (this.showTopicsInterval != 0 && this.iterationsSoFar != 0 && this.iterationsSoFar % this.showTopicsInterval == 0) {
                System.out.println();
                this.printTopWords(System.out, this.wordsPerTopic, false);
                if (this.testing != null) {
                    double el = this.empiricalLikelihood(1000, this.testing);
                    double ll = this.modelLogLikelihood();
                    double mi = this.topicLabelMutualInformation();
                    System.out.println(ll + "\t" + el + "\t" + mi);
                }
            }
            // Periodic state dump to "<stateFilename>.<iteration>".
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(this.stateFilename + '.' + this.iterationsSoFar));
            }
            // Periodic alpha optimisation, only strictly after burn-in.
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0 && this.iterationsSoFar % this.optimizeInterval == 0) {
                // Presumably learnParameters updates this.alpha in place and returns the new sum - confirm.
                this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
                // Rebuild the cached values that depend on alpha.
                this.smoothingOnlyMass = 0.0;
                for (int topic = 0; topic < this.numTopics; ++topic) {
                    this.smoothingOnlyMass += this.alpha[topic] * this.beta / ((double)this.tokensPerTopic[topic] + this.betaSum);
                    this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
                }
                this.clearHistograms();
            }
            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;
            // One sweep: resample every document; histograms are saved only on sample iterations at or after burn-in.
            int numDocs = this.data.size();
            for (int di = 0; di < numDocs; ++di) {
                FeatureSequence tokenSequence = (FeatureSequence)this.data.get((int)di).instance.getData();
                LabelSequence topicSequence = this.data.get((int)di).topicSequence;
                this.sampleTopicsForOneDoc(tokenSequence, topicSequence, this.iterationsSoFar >= this.burninPeriod && this.iterationsSoFar % this.saveSampleInterval == 0, true);
            }
            // Per-iteration timing, in ms below one second and whole seconds otherwise.
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            if (elapsedMillis < 1000L) {
                System.out.print(elapsedMillis + "ms ");
            } else {
                System.out.print(elapsedMillis / 1000L + "s ");
            }
            // Every tenth iteration: print the iteration marker and, if enabled, the log likelihood.
            if (this.iterationsSoFar % 10 == 0) {
                System.out.println("<" + this.iterationsSoFar + "> ");
                if (this.printLogLikelihood) {
                    System.out.println(this.modelLogLikelihood());
                }
            }
            System.out.flush();
            ++this.iterationsSoFar;
        }
        // Break the total running time down into days / hours / minutes / seconds.
        long seconds = Math.round((double)(System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Zeroes the document-length histogram and every per-topic count histogram. */
    private void clearHistograms() {
        Arrays.fill(docLengthCounts, 0);
        for (int[] histogramForTopic : topicDocCounts) {
            Arrays.fill(histogramForTopic, 0);
        }
    }

    /**
     * Legacy per-document sampler: for each token, builds a distribution over the
     * topics that currently have a non-zero count for the token's type and draws a
     * new topic from it.
     *
     * @param featureSequence             tokens of the document
     * @param topicSequence               topic assignments; updated in place through {@code getFeatures()}
     * @param saveStateForAlphaEstimation whether to record this document in the alpha histograms
     * @param readjustTopicsAndStats      whether to write the new assignments and update the counts
     * @throws IllegalStateException if a type/topic count would go negative
     */
    private void oldSampleTopicsForOneDoc(FeatureSequence featureSequence, FeatureSequence topicSequence, boolean saveStateForAlphaEstimation, boolean readjustTopicsAndStats) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLen = featureSequence.getLength();
        Arrays.fill(this.oneDocTopicCounts, 0);
        if (readjustTopicsAndStats) {
            for (int token = 0; token < docLen; ++token) {
                this.oneDocTopicCounts[oneDocTopics[token]]++;
            }
        }
        for (int token = 0; token < docLen; ++token) {
            int type = featureSequence.getIndexAtPosition(token);
            int oldTopic = oneDocTopics[token];
            TIntIntHashMap currentTypeTopicCounts = this.typeTopicCounts[type];
            assert (currentTypeTopicCounts.size() != 0);
            if (readjustTopicsAndStats) {
                // Take the current token out of all counts before resampling it.
                this.oneDocTopicCounts[oldTopic]--;
                int adjustedValue = currentTypeTopicCounts.adjustOrPutValue(oldTopic, -1, -1);
                if (adjustedValue == 0) {
                    currentTypeTopicCounts.remove(oldTopic);
                } else if (adjustedValue == -1) {
                    throw new IllegalStateException("Token count in topic went negative.");
                }
                this.tokensPerTopic[oldTopic]--;
            }
            // NOTE(review): if this was the only token of its type, the map is now empty
            // and the distribution below has length zero - confirm how nextDiscrete copes.
            int[] topicIndices = currentTypeTopicCounts.keys();
            int[] topicCounts = currentTypeTopicCounts.getValues();
            double[] topicDistribution = new double[topicIndices.length];
            double topicDistributionSum = 0.0;
            for (int i = 0; i < topicCounts.length; ++i) {
                int topic = topicIndices[i];
                double weight = ((double)topicCounts[i] + this.beta) / ((double)this.tokensPerTopic[topic] + this.betaSum) * ((double)this.oneDocTopicCounts[topic] + this.alpha[topic]);
                topicDistributionSum += weight;
                // Index by position, not by topic id: the array has one slot per map
                // entry and the sampled position is mapped back through topicIndices.
                topicDistribution[i] = weight;
            }
            int newTopic = topicIndices[this.random.nextDiscrete(topicDistribution, topicDistributionSum)];
            if (!readjustTopicsAndStats) {
                continue;
            }
            // Record the new assignment in all counts.
            oneDocTopics[token] = newTopic;
            this.oneDocTopicCounts[newTopic]++;
            this.typeTopicCounts[type].adjustOrPutValue(newTopic, 1, 1);
            this.tokensPerTopic[newTopic]++;
        }
        if (saveStateForAlphaEstimation) {
            this.docLengthCounts[docLen]++;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                this.topicDocCounts[topic][this.oneDocTopicCounts[topic]]++;
            }
        }
    }

    /**
     * Resamples the topic of every token in one document.
     *
     * <p>The sampling mass for a token is split into three parts: the global
     * {@code smoothingOnlyMass}, a document-dependent {@code topicBetaMass}, and a
     * type-dependent {@code topicTermMass}. A uniform draw over their sum selects the
     * part, and a linear scan within that part selects the topic. All three masses and
     * {@code cachedCoefficients} are updated incrementally as counts change.
     *
     * @param tokenSequence          tokens of the document
     * @param topicSequence          topic assignments; updated in place through {@code getFeatures()}
     * @param shouldSaveState        whether to record this document in the alpha histograms
     * @param readjustTopicsAndStats not read by this method; counts are always updated
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean shouldSaveState, boolean readjustTopicsAndStats) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();
        // Count how many tokens of this document are assigned to each topic.
        TIntIntHashMap localTopicCounts = new TIntIntHashMap();
        for (int position = 0; position < docLength; ++position) {
            localTopicCounts.adjustOrPutValue(oneDocTopics[position], 1, 1);
        }
        // Document-dependent mass, and coefficients that include this document's counts.
        double topicBetaMass = 0.0;
        for (int topic : localTopicCounts.keys()) {
            int n = localTopicCounts.get(topic);
            topicBetaMass += this.beta * (double)n / ((double)this.tokensPerTopic[topic] + this.betaSum);
            this.cachedCoefficients[topic] = (this.alpha[topic] + (double)n) / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        double topicTermMass = 0.0;
        // Scratch buffer; only the first currentTypeTopicCounts.size() entries are meaningful per token.
        double[] topicTermScores = new double[this.numTopics];
        for (int position = 0; position < docLength; ++position) {
            double sample;
            int i;
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = oneDocTopics[position];
            TIntIntHashMap currentTypeTopicCounts = this.typeTopicCounts[type];
            assert (currentTypeTopicCounts.get(oldTopic) >= 0);
            // Remove this token from the type/topic counts, dropping the entry when it reaches zero.
            if (currentTypeTopicCounts.get(oldTopic) == 1) {
                currentTypeTopicCounts.remove(oldTopic);
            } else {
                currentTypeTopicCounts.adjustValue(oldTopic, -1);
            }
            // Subtract the old topic's contributions, change the counts, then add the updated contributions back.
            this.smoothingOnlyMass -= this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts.get(oldTopic) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            if (localTopicCounts.get(oldTopic) == 1) {
                localTopicCounts.remove(oldTopic);
            } else {
                localTopicCounts.adjustValue(oldTopic, -1);
            }
            int n = oldTopic;
            this.tokensPerTopic[n] = this.tokensPerTopic[n] - 1;
            this.smoothingOnlyMass += this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            // Presumably get() yields 0 for a topic just removed from the map - confirm against Trove.
            topicBetaMass += this.beta * (double)localTopicCounts.get(oldTopic) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            this.cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts.get(oldTopic)) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            // Type-dependent mass: one score per topic with a non-zero count for this type.
            topicTermMass = 0.0;
            int[] topicTermIndices = currentTypeTopicCounts.keys();
            int[] topicTermValues = currentTypeTopicCounts.getValues();
            for (i = 0; i < topicTermIndices.length; ++i) {
                int topic = topicTermIndices[i];
                double score = this.cachedCoefficients[topic] * (double)topicTermValues[i];
                topicTermMass += score;
                topicTermScores[i] = score;
            }
            double origSample = sample = this.random.nextUniform() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
            int newTopic = -1;
            if (sample < topicTermMass) {
                // The draw fell in the type-dependent part: scan the per-type scores.
                i = -1;
                while (sample > 0.0) {
                    sample -= topicTermScores[++i];
                }
                newTopic = topicTermIndices[i];
            } else if ((sample -= topicTermMass) < topicBetaMass) {
                // The draw fell in the document-dependent part: scan the topics present in this document.
                sample /= this.beta;
                topicTermIndices = localTopicCounts.keys();
                topicTermValues = localTopicCounts.getValues();
                // newTopic is assigned inside the loop condition; the loop stops at the first topic that exhausts the sample.
                for (i = 0; i < topicTermIndices.length && !((sample -= (double)topicTermValues[i] / ((double)this.tokensPerTopic[newTopic = topicTermIndices[i]] + this.betaSum)) <= 0.0); ++i) {
                }
            } else {
                // The draw fell in the smoothing-only part: scan all topics.
                sample -= topicBetaMass;
                sample /= this.beta;
                for (int topic = 0; topic < this.numTopics; ++topic) {
                    if (!((sample -= this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum)) <= 0.0)) continue;
                    newTopic = topic;
                    break;
                }
            }
            // Fallback when no topic was selected: report and use the last topic.
            if (newTopic == -1) {
                System.err.println("LDAHyper sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                newTopic = this.numTopics - 1;
            }
            // Add the token back under its new topic, updating masses and coefficients the same way as above.
            oneDocTopics[position] = newTopic;
            currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
            this.smoothingOnlyMass -= this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts.get(newTopic) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            localTopicCounts.adjustOrPutValue(newTopic, 1, 1);
            int n2 = newTopic;
            this.tokensPerTopic[n2] = this.tokensPerTopic[n2] + 1;
            this.cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts.get(newTopic)) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            this.smoothingOnlyMass += this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts.get(newTopic) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            assert (currentTypeTopicCounts.get(newTopic) >= 0);
        }
        // Restore the document-independent coefficients for the topics still present in this document.
        for (int topic : localTopicCounts.keys()) {
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        if (shouldSaveState) {
            // Record the document length and, for each topic present, its count in this document.
            int n = docLength;
            this.docLengthCounts[n] = this.docLengthCounts[n] + 1;
            for (int topic : localTopicCounts.keys()) {
                int[] nArray = this.topicDocCounts[topic];
                int n3 = localTopicCounts.get(topic);
                nArray[n3] = nArray[n3] + 1;
            }
        }
    }

    /**
     * Returns one {@code IDSorter} per type, weighted by that type's count in the
     * given topic and sorted by {@code IDSorter}'s natural ordering.
     *
     * @param topic index of the topic
     * @return a newly allocated, sorted array of length {@code numTypes}
     */
    public IDSorter[] getSortedTopicWords(int topic) {
        // Declared as IDSorter[] (not Object[]) so it can be returned without a cast.
        IDSorter[] sortedTypes = new IDSorter[this.numTypes];
        for (int type = 0; type < this.numTypes; ++type) {
            sortedTypes[type] = new IDSorter(type, this.typeTopicCounts[type].get(topic));
        }
        Arrays.sort(sortedTypes);
        return sortedTypes;
    }

    /**
     * Writes the top words of every topic to a file.
     *
     * @param file        destination file
     * @param numWords    word limit per topic
     * @param useNewLines whether to print one word per line
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out2 = new PrintStream(file);
        try {
            this.printTopWords(out2, numWords, useNewLines);
        } finally {
            // Close even if printing fails so the file handle is not leaked.
            out2.close();
        }
    }

    /**
     * Prints up to {@code numWords} words for every topic, in the iteration order of
     * a {@code TreeSet} of {@code IDSorter}s built from the type/topic counts.
     *
     * <p>With {@code usingNewLines} each topic gets a "Topic n" header followed by one
     * "word&lt;TAB&gt;count" line per word; otherwise each topic is a single line of
     * topic index, alpha, token count and the words separated by spaces.
     *
     * @param out2          destination stream
     * @param numWords      maximum number of words printed per topic
     * @param usingNewLines selects the multi-line format
     */
    public void printTopWords(PrintStream out2, int numWords, boolean usingNewLines) {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
            for (int type = 0; type < this.numTypes; ++type) {
                if (!this.typeTopicCounts[type].containsKey(topic)) continue;
                sortedWords.add(new IDSorter(type, this.typeTopicCounts[type].get(topic)));
            }
            if (usingNewLines) {
                out2.println("Topic " + topic);
            } else {
                out2.print(topic + "\t" + this.formatter.format(this.alpha[topic]) + "\t" + this.tokensPerTopic[topic] + "\t");
            }
            Iterator<IDSorter> iterator = sortedWords.iterator();
            // Counting from 0 prints exactly numWords entries; the earlier
            // "word = 1; word < numWords" form stopped one word short.
            for (int word = 0; word < numWords && iterator.hasNext(); ++word) {
                IDSorter info = iterator.next();
                if (usingNewLines) {
                    out2.println(this.alphabet.lookupObject(info.getID()) + "\t" + (int)info.getWeight());
                } else {
                    out2.print(this.alphabet.lookupObject(info.getID()) + " ");
                }
            }
            if (!usingNewLines) {
                out2.println();
            }
        }
    }

    /**
     * Writes an XML report with one {@code <topic>} element per topic, each holding
     * up to {@code numWords} {@code <word>} elements ranked from 1.
     *
     * <p>Words are written verbatim; no XML escaping is applied.
     *
     * @param out2     destination writer (not flushed or closed here)
     * @param numWords maximum number of words listed per topic
     */
    public void topicXMLReport(PrintWriter out2, int numWords) {
        out2.println("<?xml version='1.0' ?>");
        out2.println("<topicModel>");
        for (int topic = 0; topic < this.numTopics; ++topic) {
            out2.println("  <topic id='" + topic + "' alpha='" + this.alpha[topic] + "' totalTokens='" + this.tokensPerTopic[topic] + "'>");
            TreeSet<IDSorter> sortedWords = new TreeSet<IDSorter>();
            for (int type = 0; type < this.numTypes; ++type) {
                if (!this.typeTopicCounts[type].containsKey(topic)) continue;
                sortedWords.add(new IDSorter(type, this.typeTopicCounts[type].get(topic)));
            }
            Iterator<IDSorter> iterator = sortedWords.iterator();
            // Ranks are 1-based, so the bound must be inclusive to emit numWords entries;
            // "word < numWords" stopped one word short.
            for (int word = 1; iterator.hasNext() && word <= numWords; ++word) {
                IDSorter info = iterator.next();
                out2.println("    <word rank='" + word + "'>" + this.alphabet.lookupObject(info.getID()) + "</word>");
            }
            out2.println("  </topic>");
        }
        out2.println("</topicModel>");
    }

    /**
     * Writes an XML report that lists, for every topic, its top terms, the phrases
     * found in the training data for that topic, and a "titles" attribute built from
     * the highest-weighted terms and phrases.
     *
     * <p>A phrase is a run of adjacent tokens assigned to the same topic; for
     * {@code FeatureSequenceWithBigrams} data a token whose bigram index is -1 also
     * ends the run. Text is written verbatim; no XML escaping is applied.
     *
     * @param out2     destination stream
     * @param numWords maximum number of terms and of phrases listed per topic
     */
    public void topicXMLReportPhrases(PrintStream out2, int numWords) {
        int numTopics = this.getNumTopics();
        TObjectIntHashMap[] phrases = new TObjectIntHashMap[numTopics];
        Alphabet alphabet = this.getAlphabet();
        for (int ti = 0; ti < numTopics; ++ti) {
            phrases[ti] = new TObjectIntHashMap();
        }
        // Pass 1: collect phrase counts per topic from every document.
        for (int di = 0; di < this.getData().size(); ++di) {
            Topication t = this.getData().get(di);
            Instance instance = t.instance;
            FeatureSequence fvs = (FeatureSequence)instance.getData();
            boolean withBigrams = false;
            if (fvs instanceof FeatureSequenceWithBigrams) {
                withBigrams = true;
            }
            int prevtopic = -1;
            int prevfeature = -1;
            int topic = -1;
            StringBuffer sb = null;
            int feature = -1;
            int doclen = fvs.size();
            for (int pi = 0; pi < doclen; ++pi) {
                feature = fvs.getIndexAtPosition(pi);
                topic = this.getData().get((int)di).topicSequence.getIndexAtPosition(pi);
                // Same topic as the previous token (and, with bigrams, a valid bigram): extend the phrase.
                if (!(topic != prevtopic || withBigrams && ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) == -1)) {
                    if (sb == null) {
                        sb = new StringBuffer(alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
                        continue;
                    }
                    sb.append(" ");
                    sb.append(alphabet.lookupObject(feature));
                    continue;
                }
                // The run ended: record the finished phrase and reset the tracking state.
                if (sb != null) {
                    String sbs = sb.toString();
                    if (phrases[prevtopic].get(sbs) == 0) {
                        phrases[prevtopic].put(sbs, 0);
                    }
                    phrases[prevtopic].increment(sbs);
                    prevfeature = -1;
                    prevtopic = -1;
                    sb = null;
                    continue;
                }
                prevtopic = topic;
                prevfeature = feature;
            }
        }
        // Pass 2: emit one <topic> element per topic.
        out2.println("<?xml version='1.0' ?>");
        out2.println("<topics>");
        double[] probs = new double[alphabet.size()];
        for (int ti = 0; ti < numTopics; ++ti) {
            out2.print("  <topic id=\"" + ti + "\" alpha=\"" + this.alpha[ti] + "\" totalTokens=\"" + this.tokensPerTopic[ti] + "\" ");
            // The element body is buffered because the "titles" attribute of the opening
            // tag is only known after the terms and phrases have been ranked.
            ByteArrayOutputStream bout = new ByteArrayOutputStream();
            PrintStream pout = new PrintStream(bout);
            AugmentableFeatureVector titles = new AugmentableFeatureVector(new Alphabet());
            for (int type = 0; type < alphabet.size(); ++type) {
                probs[type] = (double)this.getCountFeatureTopic(type, ti) / (double)this.getCountTokensPerTopic(ti);
            }
            RankedFeatureVector rfv = new RankedFeatureVector(alphabet, probs);
            // Never ask for more ranks than the vector holds (numWords may exceed the vocabulary size).
            int numTerms = Math.min(numWords, rfv.numLocations());
            for (int ri = 0; ri < numTerms; ++ri) {
                int fi = rfv.getIndexAtRank(ri);
                pout.println("      <term weight=\"" + probs[fi] + "\" count=\"" + this.getCountFeatureTopic(fi, ti) + "\">" + alphabet.lookupObject(fi) + "</term>");
                if (ri >= 20) continue;
                titles.add(alphabet.lookupObject(fi), (double)this.getCountFeatureTopic(fi, ti));
            }
            Object[] keys = phrases[ti].keys();
            int[] values = phrases[ti].getValues();
            double[] counts = new double[keys.length];
            for (int i = 0; i < counts.length; ++i) {
                counts[i] = values[i];
            }
            double countssum = MatrixOps.sum(counts);
            Alphabet alph = new Alphabet(keys);
            rfv = new RankedFeatureVector(alph, counts);
            int max = Math.min(rfv.numLocations(), numWords);
            for (int ri = 0; ri < max; ++ri) {
                int fi = rfv.getIndexAtRank(ri);
                pout.println("      <phrase weight=\"" + counts[fi] / countssum + "\" count=\"" + values[fi] + "\">" + alph.lookupObject(fi) + "</phrase>");
                // Only frequent, top-ranked phrases are title candidates; they are weighted 100x.
                if (ri >= 20 || values[fi] <= 20) continue;
                titles.add(alph.lookupObject(fi), (double)(100 * values[fi]));
            }
            StringBuffer titlesStringBuffer = new StringBuffer();
            rfv = new RankedFeatureVector(titles.getAlphabet(), titles);
            int numTitles = 10;
            for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ++ri) {
                if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri).toString()) == -1) {
                    titlesStringBuffer.append(rfv.getObjectAtRank(ri));
                    if (ri >= numTitles - 1) continue;
                    titlesStringBuffer.append(", ");
                    continue;
                }
                // Candidate already contained in the title text: skip it and allow one more rank.
                ++numTitles;
            }
            out2.println("titles=\"" + titlesStringBuffer.toString() + "\">");
            // Emit the buffered body. Printing pout.toString() would write the stream
            // object's identity string instead of the XML collected in bout.
            pout.flush();
            out2.print(bout.toString());
            out2.println("  </topic>");
        }
        out2.println("</topics>");
    }

    /**
     * Writes the per-document topic proportions to a file.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File f) throws IOException {
        PrintWriter pw = new PrintWriter(new FileWriter(f));
        try {
            this.printDocumentTopics(pw);
        } finally {
            // The writer is buffered; without closing it the output may never reach the file.
            pw.close();
        }
    }

    /**
     * Prints the topic proportions of every document with a threshold of 0.0 and no
     * limit on the number of topics listed.
     *
     * @param pw destination writer (not flushed or closed here)
     */
    public void printDocumentTopics(PrintWriter pw) {
        printDocumentTopics(pw, 0.0, -1);
    }

    /**
     * Prints, for every document, its topics ordered by {@link IDSorter} together
     * with the fraction of the document's tokens assigned to each topic.
     *
     * <p>Output format: a header line, then one line per document of the form
     * {@code <docIndex> <source> <topicId> <proportion> <topicId> <proportion> ...}.
     * Documents whose instance has no source are printed with the placeholder
     * {@code null-source}.
     *
     * <p>A document with no tokens reports a proportion of 0 for every topic
     * instead of the {@code NaN} that a 0/0 division would produce.
     *
     * @param pw        destination writer; it is neither flushed nor closed here
     * @param threshold printing of a document's topics stops at the first topic
     *                  (in sorted order) whose proportion is below this value
     * @param max       maximum number of topics to print per document; a negative
     *                  value, or one larger than the number of topics, means all
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            // Placeholder weights; the real values are set per document below.
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        if (max < 0 || max > this.numTopics) {
            max = this.numTopics;
        }
        for (int di = 0; di < this.data.size(); ++di) {
            int[] currentDocTopics = this.data.get(di).topicSequence.getFeatures();
            Object source = this.data.get(di).instance.getSource();
            pw.print(di);
            pw.print(' ');
            pw.print(source != null ? source : "null-source");
            pw.print(' ');

            // Count how many tokens of this document are assigned to each topic.
            int docLen = currentDocTopics.length;
            for (int token = 0; token < docLen; ++token) {
                topicCounts[currentDocTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                // Guard the empty document: 0f / 0f is NaN, which would be sorted and printed.
                float proportion = docLen == 0 ? 0.0f : (float)topicCounts[topic] / (float)docLen;
                sortedTopics[topic].set(topic, proportion);
            }
            Arrays.sort(sortedTopics);
            for (int i = 0; i < max; ++i) {
                IDSorter ranked = sortedTopics[i];
                if (ranked.getWeight() < threshold) {
                    break;
                }
                pw.print(ranked.getID() + " " + ranked.getWeight() + " ");
            }
            pw.print(" \n");
            // Reset the shared counter array for the next document.
            Arrays.fill(topicCounts, 0);
        }
    }

    /**
     * Writes the token-level sampling state to {@code f} as gzip-compressed text.
     *
     * <p>The stream is closed in a {@code finally} block so that the gzip trailer
     * is written and the file handle released even if printing the state fails.
     *
     * @param f the file to (over)write
     * @throws IOException if the file cannot be opened or the gzip stream cannot be created
     */
    public void printState(File f) throws IOException {
        PrintStream out2 = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out2);
        } finally {
            out2.close();
        }
    }

    /**
     * Dumps the current topic assignment of every token, one token per line.
     *
     * <p>After a header line, each line has the space-separated columns
     * {@code doc source pos typeindex type topic}. The source column is
     * {@code NA} when the document's instance has no source.
     *
     * @param out2 destination stream; it is not closed here
     */
    public void printState(PrintStream out2) {
        out2.println("#doc source pos typeindex type topic");
        for (int docIndex = 0; docIndex < this.data.size(); docIndex++) {
            FeatureSequence tokens = (FeatureSequence)this.data.get(docIndex).instance.getData();
            LabelSequence topics = this.data.get(docIndex).topicSequence;
            Object rawSource = this.data.get(docIndex).instance.getSource();
            String source = rawSource == null ? "NA" : rawSource.toString();

            // The first two columns are identical for every token of the document.
            String linePrefix = docIndex + " " + source + " ";
            for (int position = 0; position < topics.getLength(); position++) {
                int type = tokens.getIndexAtPosition(position);
                int topic = topics.getIndexAtPosition(position);
                StringBuilder line = new StringBuilder(linePrefix);
                line.append(position).append(' ');
                line.append(type).append(' ');
                line.append(this.alphabet.lookupObject(type)).append(' ');
                line.append(topic);
                out2.println(line.toString());
            }
        }
    }

    /**
     * Serializes this model to {@code f} using Java object serialization.
     *
     * <p>This is a best-effort operation: an {@link IOException} is reported on
     * {@code System.err} and not rethrown. The underlying file stream is always
     * closed, including when serialization fails part-way through.
     *
     * @param f the file to (over)write
     */
    public void write(File f) {
        try {
            FileOutputStream fos = new FileOutputStream(f);
            try {
                ObjectOutputStream oos = new ObjectOutputStream(fos);
                oos.writeObject(this);
                // Normal path: closing the object stream flushes its buffered data.
                oos.close();
            } finally {
                // Releases the file handle on the failure path; harmless if already closed.
                fos.close();
            }
        }
        catch (IOException e) {
            System.err.println("LDAHyper.write: Exception writing LDAHyper to file " + f + ": " + e);
        }
    }

    /**
     * Deserializes a model previously stored with Java object serialization.
     *
     * <p>After the object is read, {@code initializeTypeTopicCounts()} is invoked
     * on it. Failures are reported on {@code System.err} and not rethrown. The
     * underlying file stream is always closed, including on failure.
     *
     * @param f the file to read
     * @return the deserialized model, or {@code null} if the object could not be read
     */
    public static LDAHyper read(File f) {
        LDAHyper lda = null;
        try {
            FileInputStream fis = new FileInputStream(f);
            try {
                ObjectInputStream ois = new ObjectInputStream(fis);
                lda = (LDAHyper)ois.readObject();
                lda.initializeTypeTopicCounts();
                ois.close();
            } finally {
                // Releases the file handle on the failure path; harmless if already closed.
                fis.close();
            }
        }
        catch (IOException e) {
            System.err.println("Exception reading file " + f + ": " + e);
        }
        catch (ClassNotFoundException e) {
            System.err.println("Exception reading file " + f + ": " + e);
        }
        return lda;
    }

    /**
     * Custom serialization hook: writes the model's state field by field.
     *
     * <p>The order of the writes below defines the stream format and must stay in
     * lock-step with {@code readObject}. Note that {@code numTypes} and
     * {@code alphaSum} are not written.
     *
     * @param out2 the object stream to write to
     * @throws IOException if any write fails
     */
    private void writeObject(ObjectOutputStream out2) throws IOException {
        // Leading int is a format version tag; readObject consumes it first.
        out2.writeInt(0);
        out2.writeObject(this.data);
        out2.writeObject(this.alphabet);
        out2.writeObject(this.topicAlphabet);
        out2.writeInt(this.numTopics);
        out2.writeObject(this.alpha);
        out2.writeDouble(this.beta);
        out2.writeDouble(this.betaSum);
        out2.writeDouble(this.smoothingOnlyMass);
        out2.writeObject(this.cachedCoefficients);
        // Training schedule and output settings.
        out2.writeInt(this.iterationsSoFar);
        out2.writeInt(this.numIterations);
        out2.writeInt(this.burninPeriod);
        out2.writeInt(this.saveSampleInterval);
        out2.writeInt(this.optimizeInterval);
        out2.writeInt(this.showTopicsInterval);
        out2.writeInt(this.wordsPerTopic);
        out2.writeInt(this.outputModelInterval);
        out2.writeObject(this.outputModelFilename);
        out2.writeInt(this.saveStateInterval);
        out2.writeObject(this.stateFilename);
        out2.writeObject(this.random);
        out2.writeObject(this.formatter);
        out2.writeBoolean(this.printLogLikelihood);
        out2.writeObject(this.docLengthCounts);
        out2.writeObject(this.topicDocCounts);
        // One count map per type, written individually rather than as an array.
        for (int fi = 0; fi < this.numTypes; ++fi) {
            out2.writeObject(this.typeTopicCounts[fi]);
        }
        // Exactly numTopics ints follow; no length prefix is written.
        for (int ti = 0; ti < this.numTopics; ++ti) {
            out2.writeInt(this.tokensPerTopic[ti]);
        }
    }

    /**
     * Custom deserialization hook: restores the model's state in exactly the
     * order in which {@code writeObject} stored it.
     *
     * <p>Two fields are not part of the stream and are rebuilt here:
     * {@code numTypes} is taken from the size of the restored alphabet, and
     * {@code alphaSum} is recomputed as the sum of the restored {@code alpha}
     * entries. Without the latter, {@code alphaSum} would be left at 0 after
     * deserialization even though {@code modelLogLikelihood()} reads it.
     *
     * @param in the object stream to read from
     * @throws IOException            if any read fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Format version tag; the value is currently unused but must be consumed.
        in.readInt();
        this.data = (ArrayList)in.readObject();
        this.alphabet = (Alphabet)in.readObject();
        this.topicAlphabet = (LabelAlphabet)in.readObject();
        this.numTopics = in.readInt();
        this.alpha = (double[])in.readObject();
        // NOTE(review): assumes alphaSum is maintained as the sum of alpha elsewhere in the class.
        this.alphaSum = 0.0;
        if (this.alpha != null) {
            for (int ti = 0; ti < this.alpha.length; ++ti) {
                this.alphaSum += this.alpha[ti];
            }
        }
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[])in.readObject();
        this.iterationsSoFar = in.readInt();
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.outputModelInterval = in.readInt();
        this.outputModelFilename = (String)in.readObject();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String)in.readObject();
        this.random = (Randoms)in.readObject();
        this.formatter = (NumberFormat)in.readObject();
        this.printLogLikelihood = in.readBoolean();
        this.docLengthCounts = (int[])in.readObject();
        this.topicDocCounts = (int[][])in.readObject();
        // numTypes is not stored; it determines how many count maps follow.
        this.numTypes = this.alphabet.size();
        this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
        for (int fi = 0; fi < this.numTypes; ++fi) {
            this.typeTopicCounts[fi] = (TIntIntHashMap)in.readObject();
        }
        this.tokensPerTopic = new int[this.numTopics];
        for (int ti = 0; ti < this.numTopics; ++ti) {
            this.tokensPerTopic[ti] = in.readInt();
        }
    }

    /**
     * Computes the mutual information, in bits, between token topic assignments
     * and document labels: {@code H(topic) + H(label) - H(topic, label)}.
     *
     * <p>Counts are taken at the token level; every token of a document is paired
     * with the best label index of that document's labeling.
     *
     * @return the mutual information, or {@code 0.0} if the model holds no
     *         documents or the first document's instance has no target alphabet
     */
    public double topicLabelMutualInformation() {
        // Nothing to measure without documents; also avoids indexing into an empty list.
        if (this.data.isEmpty() || this.data.get(0).instance.getTargetAlphabet() == null) {
            return 0.0;
        }
        int targetAlphabetSize = this.data.get(0).instance.getTargetAlphabet().size();
        int[][] topicLabelCounts = new int[this.numTopics][targetAlphabetSize];
        int[] topicCounts = new int[this.numTopics];
        int[] labelCounts = new int[targetAlphabetSize];
        int total = 0;
        for (int doc = 0; doc < this.data.size(); ++doc) {
            int label = this.data.get(doc).instance.getLabeling().getBestIndex();
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; ++token) {
                int topic = docTopics[token];
                topicLabelCounts[topic][label]++;
                topicCounts[topic]++;
                labelCounts[label]++;
                ++total;
            }
        }

        double log2 = Math.log(2.0);
        // Zero counts are skipped throughout: they contribute nothing and log(0) is undefined.
        double topicEntropy = 0.0;
        for (int topic = 0; topic < topicCounts.length; ++topic) {
            if (topicCounts[topic] == 0) continue;
            double p = (double)topicCounts[topic] / (double)total;
            topicEntropy -= p * Math.log(p) / log2;
        }
        double labelEntropy = 0.0;
        for (int label = 0; label < labelCounts.length; ++label) {
            if (labelCounts[label] == 0) continue;
            double p = (double)labelCounts[label] / (double)total;
            labelEntropy -= p * Math.log(p) / log2;
        }
        double jointEntropy = 0.0;
        for (int topic = 0; topic < topicCounts.length; ++topic) {
            for (int label = 0; label < labelCounts.length; ++label) {
                if (topicLabelCounts[topic][label] == 0) continue;
                double p = (double)topicLabelCounts[topic][label] / (double)total;
                jointEntropy -= p * Math.log(p) / log2;
            }
        }
        return topicEntropy + labelEntropy - jointEntropy;
    }

    /**
     * Estimates the log likelihood of held-out documents by sampling.
     *
     * <p>For each of {@code numSamples} topic distributions drawn from a Dirichlet
     * with parameters {@code alpha}, a smoothed word distribution is formed from
     * the current type/topic counts and every test document is scored under it.
     * Per document, the sample likelihoods are averaged in log space (shifting by
     * the maximum for numerical stability), and the per-document results are
     * added up. Tokens whose type index is not below {@code numTypes} are ignored.
     *
     * @param numSamples number of topic distributions to draw
     * @param testing    documents to score; each instance's data is cast to a {@code FeatureSequence}
     * @return the sum over all test documents of the estimated log likelihood
     */
    public double empiricalLikelihood(int numSamples, InstanceList testing) {
        double[][] docSampleLogLik = new double[testing.size()][numSamples];
        double[] wordLogProb = new double[this.numTypes];
        Dirichlet topicPrior = new Dirichlet(this.alpha);

        for (int s = 0; s < numSamples; s++) {
            double[] theta = topicPrior.nextDistribution();

            // Mix the smoothed per-topic word distributions according to theta.
            Arrays.fill(wordLogProb, 0.0);
            for (int k = 0; k < this.numTopics; k++) {
                for (int w = 0; w < this.numTypes; w++) {
                    wordLogProb[w] += theta[k] * (this.beta + (double)this.typeTopicCounts[w].get(k)) / (this.betaSum + (double)this.tokensPerTopic[k]);
                }
            }
            // Switch to log space so that token scores can be summed.
            for (int w = 0; w < this.numTypes; w++) {
                assert (wordLogProb[w] > 0.0);
                wordLogProb[w] = Math.log(wordLogProb[w]);
            }

            for (int d = 0; d < testing.size(); d++) {
                FeatureSequence tokens = (FeatureSequence)((Instance)testing.get(d)).getData();
                int length = tokens.getLength();
                for (int pos = 0; pos < length; pos++) {
                    int w = tokens.getIndexAtPosition(pos);
                    if (w < this.numTypes) {
                        docSampleLogLik[d][s] += wordLogProb[w];
                    }
                }
            }
        }

        double logNumSamples = Math.log(numSamples);
        double result = 0.0;
        for (int d = 0; d < testing.size(); d++) {
            double[] perSample = docSampleLogLik[d];

            double largest = Double.NEGATIVE_INFINITY;
            for (int s = 0; s < numSamples; s++) {
                if (perSample[s] > largest) {
                    largest = perSample[s];
                }
            }
            // Subtract the largest value before exponentiating.
            double shiftedSum = 0.0;
            for (int s = 0; s < numSamples; s++) {
                shiftedSum += Math.exp(perSample[s] - largest);
            }
            result += Math.log(shiftedSum) + largest - logNumSamples;
        }
        return result;
    }

    /**
     * Computes the joint log likelihood of the current topic assignments and the
     * observed words, {@code log P(z | alpha) + log P(w | z, beta)}, with both
     * multinomials integrated out under their Dirichlet priors.
     *
     * <p>Document part, summed over documents {@code d}:
     * {@code logG(alphaSum) - logG(alphaSum + N_d) + sum_k [logG(alpha_k + n_dk) - logG(alpha_k)]}.
     * <br>Word part, summed over topics {@code k}, with {@code V = numTypes}:
     * {@code logG(V*beta) - logG(V*beta + n_k) + sum_w [logG(beta + n_kw) - logG(beta)]}.
     * Terms with a zero count cancel and are skipped.
     *
     * <p>The word part normalises over the vocabulary, so it uses
     * {@code beta * numTypes} (not {@code beta * numTopics}) and contributes
     * {@code logG(V*beta)} once per topic.
     *
     * @return the model log likelihood (natural log, via {@code Dirichlet.logGammaStirling})
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // ---- Documents: P(z | alpha) ----
        int[] topicCounts = new int[this.numTopics];
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }
        for (int doc = 0; doc < this.data.size(); ++doc) {
            LabelSequence topicSequence = this.data.get(doc).topicSequence;
            int[] docTopics = topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; ++token) {
                topicCounts[docTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (topicCounts[topic] <= 0) continue;
                logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double)topicCounts[topic]) - topicLogGammas[topic];
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double)docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        // logG(alphaSum) once per document.
        logLikelihood += (double)this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // ---- Words: P(w | z, beta) ----
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; ++type) {
            int[] usedTopics = this.typeTopicCounts[type].keys();
            for (int topic : usedTopics) {
                int count = this.typeTopicCounts[type].get(topic);
                if (count <= 0) continue;
                ++nonZeroTypeTopics;
                logLikelihood += Dirichlet.logGammaStirling(this.beta + (double)count);
            }
        }
        // Total mass of the symmetric Dirichlet over the vocabulary.
        double vocabularyBeta = this.beta * (double)this.numTypes;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            logLikelihood -= Dirichlet.logGammaStirling(vocabularyBeta + (double)this.tokensPerTopic[topic]);
        }
        // logG(V*beta) once per topic.
        logLikelihood += (double)this.numTopics * Dirichlet.logGammaStirling(vocabularyBeta);
        // logG(beta) once per non-zero (type, topic) pair.
        logLikelihood -= Dirichlet.logGammaStirling(this.beta) * (double)nonZeroTypeTopics;
        return logLikelihood;
    }

    /**
     * Command-line entry point: trains a model on a serialized instance list.
     *
     * <p>Arguments: {@code <training-instances> [numTopics] [testing-instances]}.
     * The number of topics defaults to 200. The model is created with the
     * constructor arguments {@code (numTopics, 50.0, 0.01)}, log-likelihood
     * printing is switched on, and {@code estimate()} is run.
     *
     * @param args command-line arguments as described above
     * @throws IOException propagated from the calls below
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            // Fail with a readable message instead of an ArrayIndexOutOfBoundsException.
            System.err.println("Usage: LDAHyper <training-instances> [numTopics] [testing-instances]");
            System.exit(1);
        }
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        // NOTE(review): the testing list is loaded but never handed to the model -- confirm intent.
        InstanceList testing = args.length > 2 ? InstanceList.load(new File(args[2])) : null;
        LDAHyper lda = new LDAHyper(numTopics, 50.0, 0.01);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(50, 7);
        lda.addInstances(training);
        lda.estimate();
    }

    /**
     * Pairs one document ({@link Instance}) with the model it belongs to and the
     * topic assigned to each of its tokens.
     *
     * <p>Serialization is hand-written: a version int followed by the four public
     * fields, in the order {@code instance, model, topicSequence, topicDistribution}.
     */
    public class Topication
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Version tag written at the start of the serialized form. */
        private static final int CURRENT_SERIAL_VERSION = 0;

        /** The document. */
        public Instance instance;
        /** The model passed to the constructor. */
        public LDAHyper model;
        /** Topic assignments for the document. */
        public LabelSequence topicSequence;
        /** Not set by the constructor; remains {@code null} until assigned elsewhere. */
        public Labeling topicDistribution;

        /**
         * Creates a record for one document; {@code topicDistribution} is left unset.
         *
         * @param instance      the document
         * @param model         the owning model
         * @param topicSequence the document's topic assignments
         */
        public Topication(Instance instance, LDAHyper model, LabelSequence topicSequence) {
            this.instance = instance;
            this.model = model;
            this.topicSequence = topicSequence;
        }

        /** Writes the version tag and then the four fields in declaration order. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(this.instance);
            out.writeObject(this.model);
            out.writeObject(this.topicSequence);
            out.writeObject(this.topicDistribution);
        }

        /** Reads back exactly what {@code writeObject} wrote, in the same order. */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            // The version tag is consumed to keep the stream aligned; its value is not inspected.
            in.readInt();
            this.instance = (Instance)in.readObject();
            this.model = (LDAHyper)in.readObject();
            this.topicSequence = (LabelSequence)in.readObject();
            this.topicDistribution = (Labeling)in.readObject();
        }
    }
}

