/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Latent Dirichlet Allocation topic model estimated with collapsed Gibbs sampling.
 *
 * <p>Each token of each document carries a topic assignment ({@link #topics}). The sampler keeps three
 * count tables consistent with those assignments:
 * <ul>
 *   <li>per-document topic counts ({@link #docTopicCounts}),</li>
 *   <li>per-word-type topic counts ({@link #typeTopicCounts}),</li>
 *   <li>total tokens per topic ({@link #tokensPerTopic}).</li>
 * </ul>
 * Documents must carry {@link FeatureSequence} data.
 *
 * <p>Serialization uses a custom format (see {@code writeObject}/{@code readObject}).
 */
public class LDA implements Serializable {
    /** Number of topics K. */
    int numTopics;
    /** Per-topic Dirichlet prior on document-topic distributions (alphaSum / numTopics). */
    double alpha;
    /** Dirichlet prior on topic-word distributions. */
    double beta;
    /** alpha * numTopics. */
    double tAlpha;
    /** beta * numTypes; the smoothing term in the per-topic denominator of the sampler. */
    double vBeta;
    /** The documents being modeled (a shallow clone of the caller's list). */
    InstanceList ilist;
    /** topics[doc][position] = topic currently assigned to that token. */
    int[][] topics;
    /** Vocabulary size (size of the data alphabet). */
    int numTypes;
    /** Total number of tokens across all documents. */
    int numTokens;
    /** docTopicCounts[doc][topic] = tokens in doc assigned to topic. */
    int[][] docTopicCounts;
    /** typeTopicCounts[type][topic] = tokens of word type assigned to topic. */
    int[][] typeTopicCounts;
    /** tokensPerTopic[topic] = total tokens assigned to topic. */
    int[] tokensPerTopic;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with the given number of topics, alphaSum = 50.0 and beta = 0.01.
     *
     * @param numberOfTopics number of topics K
     */
    public LDA(int numberOfTopics) {
        this(numberOfTopics, 50.0, 0.01);
    }

    /**
     * Creates a model with explicit hyperparameters.
     *
     * @param numberOfTopics number of topics K
     * @param alphaSum total document-topic prior mass; each topic receives alphaSum / K
     * @param beta topic-word prior
     */
    public LDA(int numberOfTopics, double alphaSum, double beta) {
        this.numTopics = numberOfTopics;
        this.alpha = alphaSum / (double)this.numTopics;
        this.beta = beta;
    }

    /**
     * Initializes the model on {@code documents} with random topic assignments and runs Gibbs sampling.
     *
     * <p>All previous model state is discarded.
     *
     * @param documents documents whose data are {@link FeatureSequence}s
     * @param numIterations number of Gibbs sweeps over all documents
     * @param showTopicsInterval print the top 5 words per topic every this many iterations (0 disables)
     * @param outputModelInterval serialize the model every this many iterations (0 disables)
     * @param outputModelFilename base filename for periodic snapshots ({@code <name>.<iteration>})
     * @param r random source used for initialization and sampling
     * @throws ClassCastException if a document's data is not a {@link FeatureSequence}
     */
    public void estimate(InstanceList documents, int numIterations, int showTopicsInterval, int outputModelInterval, String outputModelFilename, Randoms r) {
        this.ilist = documents.shallowClone();
        this.numTypes = this.ilist.getDataAlphabet().size();
        int numDocs = this.ilist.size();
        this.topics = new int[numDocs][];
        this.docTopicCounts = new int[numDocs][this.numTopics];
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.tAlpha = this.alpha * (double)this.numTopics;
        this.vBeta = this.beta * (double)this.numTypes;
        // Reset so that re-estimating the same object does not double-count tokens.
        this.numTokens = 0;
        for (int di = 0; di < numDocs; ++di) {
            this.initializeDocument(di, featureSequenceOf((Instance)this.ilist.get(di)), r);
        }
        this.estimate(0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
    }

    /**
     * Appends documents to an already-estimated model and gives their tokens random initial topics.
     *
     * <p>This method does not run any sampling iterations itself. {@code numIterations},
     * {@code showTopicsInterval}, {@code outputModelInterval} and {@code outputModelFilename} are
     * accepted but currently unused.
     *
     * @param additionalDocuments documents sharing this model's data alphabet
     * @param r random source for the initial topic assignments
     * @throws IllegalStateException if the model has no documents yet
     * @throws ClassCastException if a document's data is not a {@link FeatureSequence}
     */
    public void addDocuments(InstanceList additionalDocuments, int numIterations, int showTopicsInterval, int outputModelInterval, String outputModelFilename, Randoms r) {
        if (this.ilist == null) {
            throw new IllegalStateException("Must already have some documents first.");
        }
        // Validate before mutating anything so a failed assertion leaves the model untouched.
        assert (this.ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
        assert (additionalDocuments.getDataAlphabet().size() >= this.numTypes);
        for (Instance inst : additionalDocuments) {
            this.ilist.add(inst);
        }
        int numOldTypes = this.typeTopicCounts.length;
        this.numTypes = additionalDocuments.getDataAlphabet().size();
        // The vocabulary may have grown, so the sampler's smoothing denominator must follow it.
        this.vBeta = this.beta * (double)this.numTypes;
        int numOldDocs = this.topics.length;
        int numDocs = numOldDocs + additionalDocuments.size();

        // Grow the per-document tables; new rows are filled in by initializeDocument below.
        this.topics = Arrays.copyOf(this.topics, numDocs);
        this.docTopicCounts = Arrays.copyOf(this.docTopicCounts, numDocs);
        for (int di = numOldDocs; di < numDocs; ++di) {
            this.docTopicCounts[di] = new int[this.numTopics];
        }

        // Grow the per-type table and actually install it, so that word types first
        // seen in the new documents have a row to count into.
        this.typeTopicCounts = Arrays.copyOf(this.typeTopicCounts, this.numTypes);
        for (int type = numOldTypes; type < this.numTypes; ++type) {
            this.typeTopicCounts[type] = new int[this.numTopics];
        }

        for (int di = numOldDocs; di < numDocs; ++di) {
            this.initializeDocument(di, featureSequenceOf((Instance)additionalDocuments.get(di - numOldDocs)), r);
        }
    }

    /**
     * Runs Gibbs sampling over the contiguous document range [docIndexStart, docIndexStart + docIndexLength).
     *
     * <p>Progress is printed to stdout: the iteration number every 10 iterations, otherwise a dot.
     * The total elapsed time is printed at the end.
     *
     * @param docIndexStart index of the first document to resample
     * @param docIndexLength number of documents to resample
     * @param numIterations number of sweeps
     * @param showTopicsInterval print the top 5 words per topic every this many iterations (0 disables)
     * @param outputModelInterval serialize the model every this many iterations (0 disables)
     * @param outputModelFilename base filename for periodic snapshots
     * @param r random source
     */
    public void estimate(int docIndexStart, int docIndexLength, int numIterations, int showTopicsInterval, int outputModelInterval, String outputModelFilename, Randoms r) {
        long startTime = System.currentTimeMillis();
        for (int iterations = 0; iterations < numIterations; ++iterations) {
            if (iterations % 10 == 0) {
                System.out.print(iterations);
            } else {
                System.out.print(".");
            }
            System.out.flush();
            if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) {
                System.out.println();
                this.printTopWords(5, false);
            }
            if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) {
                this.write(new File(outputModelFilename + '.' + iterations));
            }
            this.sampleTopicsForDocs(docIndexStart, docIndexLength, r);
        }
        printTotalTime(System.currentTimeMillis() - startTime);
    }

    /** Prints an elapsed duration as "Total time: [D days ][H hours ][M minutes ]S seconds". */
    private static void printTotalTime(long elapsedMillis) {
        long seconds = Math.round((double)elapsedMillis / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Performs one Gibbs sweep over every document in the model. */
    public void sampleTopicsForAllDocs(Randoms r) {
        this.sampleTopicsForDocs(0, this.topics.length, r);
    }

    /**
     * Performs one Gibbs sweep over documents [start, start + length).
     *
     * @param start index of the first document
     * @param length number of documents
     * @param r random source
     */
    public void sampleTopicsForDocs(int start, int length, Randoms r) {
        assert (start + length <= this.docTopicCounts.length);
        // Scratch buffer reused across documents to avoid per-token allocation.
        double[] topicWeights = new double[this.numTopics];
        for (int di = start; di < start + length; ++di) {
            this.sampleTopicsForOneDoc((FeatureSequence)((Instance)this.ilist.get(di)).getData(), this.topics[di], this.docTopicCounts[di], topicWeights, r);
        }
    }

    /**
     * Resamples the topic of every token in one document.
     *
     * <p>For each token, its current assignment is first removed from all counts. A new topic t is then
     * drawn with weight proportional to
     * {@code (n[type][t] + beta) / (n[t] + vBeta) * (n[doc][t] + alpha)}, and the counts are
     * restored with the new topic.
     */
    private void sampleTopicsForOneDoc(FeatureSequence oneDocTokens, int[] oneDocTopics, int[] oneDocTopicCounts, double[] topicWeights, Randoms r) {
        int docLen = oneDocTokens.getLength();
        for (int si = 0; si < docLen; ++si) {
            int type = oneDocTokens.getIndexAtPosition(si);
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];

            int oldTopic = oneDocTopics[si];
            --oneDocTopicCounts[oldTopic];
            --currentTypeTopicCounts[oldTopic];
            --this.tokensPerTopic[oldTopic];

            // Every slot of topicWeights is overwritten here, so no clearing is needed.
            double topicWeightsSum = 0.0;
            for (int ti = 0; ti < this.numTopics; ++ti) {
                double tw = ((double)currentTypeTopicCounts[ti] + this.beta) / ((double)this.tokensPerTopic[ti] + this.vBeta) * ((double)oneDocTopicCounts[ti] + this.alpha);
                topicWeightsSum += tw;
                topicWeights[ti] = tw;
            }

            int newTopic = r.nextDiscrete(topicWeights, topicWeightsSum);
            oneDocTopics[si] = newTopic;
            ++oneDocTopicCounts[newTopic];
            ++currentTypeTopicCounts[newTopic];
            ++this.tokensPerTopic[newTopic];
        }
    }

    /**
     * Gives each token of document {@code di} a uniformly random topic and records it in all count tables.
     *
     * <p>Requires {@code docTopicCounts[di]} to already exist.
     */
    private void initializeDocument(int di, FeatureSequence fs, Randoms r) {
        int seqLen = fs.getLength();
        this.numTokens += seqLen;
        int[] docTopics = new int[seqLen];
        this.topics[di] = docTopics;
        int[] docCounts = this.docTopicCounts[di];
        for (int si = 0; si < seqLen; ++si) {
            int topic = r.nextInt(this.numTopics);
            docTopics[si] = topic;
            ++docCounts[topic];
            ++this.typeTopicCounts[fs.getIndexAtPosition(si)][topic];
            ++this.tokensPerTopic[topic];
        }
    }

    /**
     * Returns the instance's data as a {@link FeatureSequence}.
     *
     * <p>If the cast fails, explains on stderr how to obtain sequence data and rethrows.
     */
    private static FeatureSequence featureSequenceOf(Instance instance) {
        try {
            return (FeatureSequence)instance.getData();
        }
        catch (ClassCastException e) {
            System.err.println("LDA and other topic models expect FeatureSequence data, not FeatureVector data.  With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
            throw e;
        }
    }

    /** Returns the live (not copied) document-topic count table. */
    public int[][] getDocTopicCounts() {
        return this.docTopicCounts;
    }

    /** Returns the live (not copied) word-type-topic count table. */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** Returns the live (not copied) per-topic token totals. */
    public int[] getTokensPerTopic() {
        return this.tokensPerTopic;
    }

    /** A word-type index paired with its probability under one topic; sorts by descending probability. */
    private static final class WordProb implements Comparable<WordProb> {
        final int wi;
        final double p;

        WordProb(int wi, double p) {
            this.wi = wi;
            this.p = p;
        }

        /** Descending by {@code p}; Double.compare keeps the ordering consistent even for NaN. */
        @Override
        public int compareTo(WordProb other) {
            return Double.compare(other.p, this.p);
        }
    }

    /**
     * Prints the most probable words of every topic to stdout.
     *
     * @param numWords number of words per topic (clamped to the vocabulary size)
     * @param useNewLines if true, print one word per line together with its probability;
     *                    otherwise print one line per topic
     */
    public void printTopWords(int numWords, boolean useNewLines) {
        Alphabet alphabet = this.ilist.getDataAlphabet();
        int limit = Math.min(numWords, this.numTypes);
        WordProb[] wp = new WordProb[this.numTypes];
        for (int ti = 0; ti < this.numTopics; ++ti) {
            for (int wi = 0; wi < this.numTypes; ++wi) {
                wp[wi] = new WordProb(wi, (double)this.typeTopicCounts[wi][ti] / (double)this.tokensPerTopic[ti]);
            }
            Arrays.sort(wp);
            if (useNewLines) {
                System.out.println("\nTopic " + ti);
                for (int i = 0; i < limit; ++i) {
                    System.out.println(alphabet.lookupObject(wp[i].wi).toString() + " " + wp[i].p);
                }
            } else {
                System.out.print("Topic " + ti + ": ");
                for (int i = 0; i < limit; ++i) {
                    System.out.print(alphabet.lookupObject(wp[i].wi).toString() + " ");
                }
                System.out.println();
            }
        }
    }

    /**
     * Writes every document's topic proportions to {@code f} and closes the file.
     *
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printDocumentTopics(File f) throws IOException {
        try (PrintWriter pw = new PrintWriter(new FileWriter(f))) {
            this.printDocumentTopics(pw);
            // PrintWriter swallows I/O errors; surface them since this method declares IOException.
            if (pw.checkError()) {
                throw new IOException("Error writing document topics to " + f);
            }
        }
    }

    /** Writes all topics of every document, in descending proportion, to {@code pw} (not closed). */
    public void printDocumentTopics(PrintWriter pw) {
        this.printDocumentTopics(pw, 0.0, -1);
    }

    /**
     * Writes, for each document, its index, its source, and (topic, proportion) pairs in descending proportion.
     *
     * @param pw destination (not closed)
     * @param threshold stop listing a document's topics once the next proportion is below this value
     * @param max maximum number of topics listed per document; negative means all topics
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.println("#doc source topic proportion ...");
        int maxTopics = max < 0 ? this.numTopics : max;
        // One slot per topic (reused across documents).
        double[] topicDist = new double[this.numTopics];
        for (int di = 0; di < this.topics.length; ++di) {
            Instance inst = (Instance)this.ilist.get(di);
            pw.print(di);
            pw.print(' ');
            Object source = inst.getSource();
            if (source != null) {
                pw.print(source.toString());
            } else {
                pw.print("null-source");
            }
            pw.print(' ');
            int docLen = this.topics[di].length;
            for (int ti = 0; ti < this.numTopics; ++ti) {
                topicDist[ti] = (double)this.docTopicCounts[di][ti] / (double)docLen;
            }
            // Repeatedly select and zero out the largest remaining proportion (selection by max).
            for (int tp = 0; tp < maxTopics; ++tp) {
                double maxvalue = 0.0;
                int maxindex = -1;
                for (int ti = 0; ti < this.numTopics; ++ti) {
                    if (topicDist[ti] > maxvalue) {
                        maxvalue = topicDist[ti];
                        maxindex = ti;
                    }
                }
                if (maxindex == -1 || topicDist[maxindex] < threshold) {
                    break;
                }
                pw.print(maxindex + " " + topicDist[maxindex] + " ");
                topicDist[maxindex] = 0.0;
            }
            pw.println(' ');
        }
    }

    /**
     * Writes the per-token sampling state to {@code f} and closes the file.
     *
     * @throws IOException if the file cannot be opened or an error occurs while writing
     */
    public void printState(File f) throws IOException {
        try (PrintWriter writer = new PrintWriter(new FileWriter(f))) {
            this.printState(writer);
            if (writer.checkError()) {
                throw new IOException("Error writing sampling state to " + f);
            }
        }
    }

    /** Writes one line per token: "doc pos typeindex type topic" (pw is not closed). */
    public void printState(PrintWriter pw) {
        Alphabet a = this.ilist.getDataAlphabet();
        pw.println("#doc pos typeindex type topic");
        for (int di = 0; di < this.topics.length; ++di) {
            FeatureSequence fs = (FeatureSequence)((Instance)this.ilist.get(di)).getData();
            int[] docTopics = this.topics[di];
            for (int si = 0; si < docTopics.length; ++si) {
                int type = fs.getIndexAtPosition(si);
                pw.print(di);
                pw.print(' ');
                pw.print(si);
                pw.print(' ');
                pw.print(type);
                pw.print(' ');
                pw.print(a.lookupObject(type));
                pw.print(' ');
                pw.print(docTopics[si]);
                pw.println();
            }
        }
    }

    /**
     * Serializes this model to {@code f}.
     *
     * <p>This is best-effort: I/O errors are reported on stderr rather than thrown.
     */
    public void write(File f) {
        try (FileOutputStream fos = new FileOutputStream(f);
             ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            oos.writeObject(this);
        }
        catch (IOException e) {
            System.err.println("Exception writing file " + f + ": " + e);
        }
    }

    /**
     * Custom serialization. The stream contains, in order:
     * <ol>
     *   <li>the version,</li>
     *   <li>the instance list,</li>
     *   <li>the hyperparameters,</li>
     *   <li>all token topics,</li>
     *   <li>the doc-topic counts,</li>
     *   <li>the type-topic counts,</li>
     *   <li>the tokens-per-topic totals.</li>
     * </ol>
     * The order must match {@code readObject}.
     */
    private void writeObject(ObjectOutputStream out2) throws IOException {
        out2.writeInt(CURRENT_SERIAL_VERSION);
        out2.writeObject(this.ilist);
        out2.writeInt(this.numTopics);
        out2.writeDouble(this.alpha);
        out2.writeDouble(this.beta);
        out2.writeDouble(this.tAlpha);
        out2.writeDouble(this.vBeta);
        for (int[] docTopics : this.topics) {
            for (int topic : docTopics) {
                out2.writeInt(topic);
            }
        }
        for (int di = 0; di < this.topics.length; ++di) {
            for (int ti = 0; ti < this.numTopics; ++ti) {
                out2.writeInt(this.docTopicCounts[di][ti]);
            }
        }
        for (int fi = 0; fi < this.numTypes; ++fi) {
            for (int ti = 0; ti < this.numTopics; ++ti) {
                out2.writeInt(this.typeTopicCounts[fi][ti]);
            }
        }
        for (int ti = 0; ti < this.numTopics; ++ti) {
            out2.writeInt(this.tokensPerTopic[ti]);
        }
    }

    /**
     * Reads the format produced by {@code writeObject}.
     *
     * <p>Document lengths and the vocabulary size are recovered from the deserialized instance list,
     * which also restores the {@code numTypes} and {@code numTokens} fields.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Serial version; only version 0 exists, so the value is consumed but not otherwise used.
        in.readInt();
        this.ilist = (InstanceList)in.readObject();
        this.numTopics = in.readInt();
        this.alpha = in.readDouble();
        this.beta = in.readDouble();
        this.tAlpha = in.readDouble();
        this.vBeta = in.readDouble();
        int numDocs = this.ilist.size();
        this.numTokens = 0;
        this.topics = new int[numDocs][];
        for (int di = 0; di < numDocs; ++di) {
            int docLen = ((FeatureSequence)((Instance)this.ilist.get(di)).getData()).getLength();
            this.numTokens += docLen;
            int[] docTopics = new int[docLen];
            for (int si = 0; si < docLen; ++si) {
                docTopics[si] = in.readInt();
            }
            this.topics[di] = docTopics;
        }
        this.docTopicCounts = new int[numDocs][this.numTopics];
        for (int di = 0; di < numDocs; ++di) {
            for (int ti = 0; ti < this.numTopics; ++ti) {
                this.docTopicCounts[di][ti] = in.readInt();
            }
        }
        // Assign the field (the original used a shadowing local, leaving numTypes == 0 after reading).
        this.numTypes = this.ilist.getDataAlphabet().size();
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        for (int fi = 0; fi < this.numTypes; ++fi) {
            for (int ti = 0; ti < this.numTopics; ++ti) {
                this.typeTopicCounts[fi][ti] = in.readInt();
            }
        }
        this.tokensPerTopic = new int[this.numTopics];
        for (int ti = 0; ti < this.numTopics; ++ti) {
            this.tokensPerTopic[ti] = in.readInt();
        }
    }

    /** Returns the model's own (shallow-cloned) instance list. */
    public InstanceList getInstanceList() {
        return this.ilist;
    }

    /**
     * Command-line entry point: {@code <instance-list-file> [numIterations=1000 [numTopWords=20]]}.
     *
     * <p>Trains a 10-topic model, prints the top words, and writes document topics to
     * {@code <instance-list-file>.lda}.
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("Usage: java cc.mallet.topics.LDA <instance-list-file> [numIterations [numTopWords]]");
            return;
        }
        InstanceList ilist = InstanceList.load(new File(args[0]));
        int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
        int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
        System.out.println("Data loaded.");
        LDA lda = new LDA(10);
        lda.estimate(ilist, numIterations, 50, 0, null, new Randoms());
        lda.printTopWords(numTopWords, true);
        lda.printDocumentTopics(new File(args[0] + ".lda"));
    }
}

