/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.MarginalProbEstimator;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetFactory;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import gnu.trove.TIntDoubleHashMap;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

/**
 * LDA-style topic model whose prior couples related word types.
 *
 * <p>Every time a token of type {@code v} is assigned to topic {@code t}, each type {@code w}
 * listed in {@code typeTypeWeights[v]} (including {@code v} itself) receives
 * {@code weight(v, w)} extra pseudo-count in topic {@code t}. The sampling score of topic
 * {@code t} for a token of type {@code w} in document {@code d} is
 * {@code (alpha + n(t|d)) * W[w][t] / sum_w' W[w'][t]}. Here {@code W[w][t]} starts at
 * {@code beta} and accumulates those weighted pseudo-counts.
 *
 * <p>Expected call order:
 * <ol>
 *   <li>{@link #addInstances(InstanceList)}</li>
 *   <li>{@link #readTypeTypeWeights(File)}</li>
 *   <li>{@link #sample(int, boolean, int)}</li>
 * </ol>
 *
 * <p>The class performs no internal synchronization.
 */
public class WeightedTopicModel implements Serializable {
    private static Logger logger = MalletLogger.getLogger(WeightedTopicModel.class.getName());
    static CommandOption.String inputFile = new CommandOption.String(WeightedTopicModel.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureSequence or FeatureSequenceWithBigrams, not FeatureVector", null);
    static CommandOption.String weightsFile = new CommandOption.String(WeightedTopicModel.class, "weights-filename", "FILENAME", true, null, "The filename for the word-word weights file.", null);
    static CommandOption.String evaluatorFilename = new CommandOption.String(WeightedTopicModel.class, "evaluator-filename", "FILENAME", true, null, "A held-out likelihood evaluator for new documents.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.String stateFile = new CommandOption.String(WeightedTopicModel.class, "state-filename", "FILENAME", true, null, "The filename in which to write the Gibbs sampling state at the end of the iterations.  By default this is null, indicating that no file will be written.", null);
    static CommandOption.Integer numTopicsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-topics", "INTEGER", true, 10, "The number of topics to fit.", null);
    static CommandOption.Integer numEpochsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-epochs", "INTEGER", true, 1, "The number of cycles of training. Evaluators and state files will be saved after each epoch.", null);
    static CommandOption.Integer numIterationsOption = new CommandOption.Integer(WeightedTopicModel.class, "num-iterations", "INTEGER", true, 1000, "The number of iterations of Gibbs sampling PER EPOCH.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(WeightedTopicModel.class, "random-seed", "INTEGER", true, 0, "The random seed for the Gibbs sampler.  Default is 0, which will use the clock.", null);
    // NOTE: main() passes this value as the alpha SUM; the per-topic alpha is value / numTopics.
    static CommandOption.Double alphaOption = new CommandOption.Double(WeightedTopicModel.class, "alpha", "DECIMAL", true, 50.0, "Alpha parameter: smoothing over topic distribution.", null);
    static CommandOption.Double betaOption = new CommandOption.Double(WeightedTopicModel.class, "beta", "DECIMAL", true, 0.01, "Beta parameter: smoothing over word distribution.", null);
    // Neither pattern is referenced in this file; kept public for compatibility with external users.
    public static Pattern sourceWordPattern = Pattern.compile("(.*) \\((\\d+)\\)");
    public static Pattern targetWordPattern = Pattern.compile("  (\\d+)\t(\\d+)\t([\\d\\.]+)\t(.*)");
    /** One entry per training document: the instance and its current topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    protected Alphabet alphabet;
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    protected int numTypes;
    /** Per-topic Dirichlet parameter, i.e. {@code alphaSum / numTopics}. */
    protected double alpha;
    protected double alphaSum;
    protected double beta;
    /** {@code beta * numTypes}; the initial value of every entry of {@link #totalTopicWeights}. */
    protected double betaSum;
    protected int[] oneDocTopicCounts;
    /** Raw token counts: {@code typeTopicCounts[type][topic]}. */
    protected int[][] typeTopicCounts;
    /** Tokens currently assigned to each topic. */
    protected int[] tokensPerTopic;
    /** For each source type, the (normalized) weight it contributes to each connected type. */
    protected TIntDoubleHashMap[] typeTypeWeights;
    protected double[][] logTypeTopicWeights;
    /** Smoothed, weighted pseudo-counts {@code W[type][topic]} used in the sampling score. */
    protected double[][] typeTopicWeights;
    /** Column sums of {@link #typeTopicWeights}, maintained incrementally. */
    protected double[] totalTopicWeights;
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 10;
    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;
    protected double[] logCountRatioCache;

    /**
     * Creates an empty model.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum total Dirichlet concentration over topics; split evenly across topics
     * @param beta per-type smoothing added to every topic-word weight
     * @param random random source used by the sampler
     */
    public WeightedTopicModel(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = AlphabetFactory.labelAlphabetOfSize(numberOfTopics);
        this.numTopics = this.topicAlphabet.size();
        this.alphaSum = alphaSum;
        this.alpha = alphaSum / (double) this.numTopics;
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Weighted LDA: " + this.numTopics + " topics");
    }

    public Alphabet getAlphabet() {
        return this.alphabet;
    }

    public LabelAlphabet getTopicAlphabet() {
        return this.topicAlphabet;
    }

    public int getNumTopics() {
        return this.numTopics;
    }

    /** Returns the live (not copied) list of per-document topic assignments. */
    public ArrayList<TopicAssignment> getData() {
        return this.data;
    }

    /**
     * Configures periodic topic display during {@link #sample(int, boolean, int)}.
     *
     * @param interval print topics every {@code interval} iterations; 0 disables printing
     * @param n number of top words shown per topic
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /** Returns the live {@code [type][topic]} count matrix (not a copy). */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** Returns the live per-topic token totals (not a copy). */
    public int[] getTopicTotals() {
        return this.tokensPerTopic;
    }

    /**
     * Adopts the training data's alphabet and allocates the count and weight tables.
     *
     * <p>Topic assignments start as all zeros. Call {@link #sample(int, boolean, int)} with
     * {@code shouldInitialize == true} to draw the initial assignments.
     *
     * @param training instances whose data are {@link FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double) this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        this.typeTopicWeights = new double[this.numTypes][this.numTopics];
        this.totalTopicWeights = new double[this.numTopics];
        for (int type = 0; type < this.numTypes; ++type) {
            Arrays.fill(this.typeTopicWeights[type], this.beta);
        }
        Arrays.fill(this.totalTopicWeights, this.betaSum);
        for (Instance instance : training) {
            FeatureSequence tokenSequence = (FeatureSequence) instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokenSequence.size()]);
            this.data.add(new TopicAssignment(instance, topicSequence));
        }
    }

    /**
     * Reads word-to-word weights into {@link #typeTypeWeights}.
     *
     * <p>Each line is tab-separated:
     * {@code source TAB selfWeight [TAB target TAB weight]...}.
     * Every weight on a line is divided by the sum of all weights listed on that line. Each
     * type starts with a self-weight of 1.0, which a line for that type overrides.
     *
     * <p>Must be called after {@link #addInstances(InstanceList)}. Words not present in the
     * training alphabet are skipped and are never added to it. Lines that are too short,
     * or whose weights do not sum to a positive value, are also skipped. Skips are reported
     * in a single warning.
     *
     * @param weightsFile the weights file to read
     * @throws IllegalStateException if {@link #addInstances(InstanceList)} has not been called
     * @throws NumberFormatException if a weight field is not a number
     */
    public void readTypeTypeWeights(File weightsFile) throws Exception {
        if (this.alphabet == null) {
            throw new IllegalStateException("addInstances() must be called before readTypeTypeWeights()");
        }
        this.typeTypeWeights = new TIntDoubleHashMap[this.numTypes];
        logger.info("num types: " + this.numTypes);
        for (int type = 0; type < this.numTypes; ++type) {
            this.typeTypeWeights[type] = new TIntDoubleHashMap();
            this.typeTypeWeights[type].put(type, 1.0);
        }
        int skippedLines = 0;
        int skippedTargets = 0;
        BufferedReader reader = new BufferedReader(new FileReader(weightsFile));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\t");
                if (fields.length < 2) {
                    ++skippedLines;
                    continue;
                }
                int sourceType = this.lookupKnownType(fields[0]);
                if (sourceType < 0) {
                    ++skippedLines;
                    continue;
                }
                double sum = 0.0;
                for (int i = 1; i < fields.length; i += 2) {
                    sum += Double.parseDouble(fields[i]);
                }
                // A non-positive (or NaN) sum would produce NaN/Infinity weights and corrupt sampling.
                if (!(sum > 0.0)) {
                    ++skippedLines;
                    continue;
                }
                TIntDoubleHashMap sourceWeights = this.typeTypeWeights[sourceType];
                sourceWeights.put(sourceType, Double.parseDouble(fields[1]) / sum);
                // i + 1 < length: ignore a trailing target that has no weight column.
                for (int i = 2; i + 1 < fields.length; i += 2) {
                    int targetType = this.lookupKnownType(fields[i]);
                    if (targetType < 0) {
                        ++skippedTargets;
                        continue;
                    }
                    sourceWeights.put(targetType, Double.parseDouble(fields[i + 1]) / sum);
                }
            }
        } finally {
            reader.close();
        }
        if (skippedLines > 0 || skippedTargets > 0) {
            logger.warning("Weights file " + weightsFile + ": skipped " + skippedLines
                    + " line(s) and " + skippedTargets + " unknown target word(s)");
        }
    }

    /**
     * Returns the training-alphabet index of {@code word}, or -1 if the word is not a known type.
     * Never adds to the alphabet.
     */
    private int lookupKnownType(String word) {
        int type = this.alphabet.lookupIndex(word, false);
        return type < this.numTypes ? type : -1;
    }

    /**
     * Runs Gibbs sampling over every document.
     *
     * @param iterations number of passes over the corpus
     * @param shouldInitialize if true, the first iteration draws fresh assignments instead of
     *     removing existing ones first
     * @param docCycleCount number of consecutive sweeps over each document per iteration;
     *     values below 1 behave like 1
     * @throws IllegalStateException if {@link #readTypeTypeWeights(File)} has not been called
     */
    public void sample(int iterations, boolean shouldInitialize, int docCycleCount) throws IOException {
        if (this.typeTypeWeights == null) {
            throw new IllegalStateException("readTypeTypeWeights() must be called before sample()");
        }
        for (int iteration = 1; iteration <= iterations; ++iteration) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.data.size(); ++doc) {
                TopicAssignment assignment = this.data.get(doc);
                FeatureSequence tokenSequence = (FeatureSequence) assignment.instance.getData();
                LabelSequence topicSequence = assignment.topicSequence;
                this.sampleTopicsForOneDoc(tokenSequence, topicSequence, shouldInitialize && iteration == 1, false);
                for (int i = 1; i < docCycleCount; ++i) {
                    this.sampleTopicsForOneDoc(tokenSequence, topicSequence, false, false);
                }
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            logger.info(iteration + "\t" + elapsedMillis + "ms\t");
            if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                logger.info("<" + iteration + ">\n" + this.topWords(this.wordsPerTopic));
            }
        }
    }

    /**
     * Resamples the topic of every token in one document, updating all global counts and
     * weights in place.
     *
     * @param tokenSequence the document's word types
     * @param topicSequence the document's topic assignments; overwritten in place
     * @param initializing if true, existing assignments are ignored (not removed from counts)
     * @param debugging if true, prints sampling diagnostics to stdout
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean initializing, boolean debugging) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();
        int[] localTopicCounts = new int[this.numTopics];
        if (!initializing) {
            for (int position = 0; position < docLength; ++position) {
                ++localTopicCounts[oneDocTopics[position]];
            }
        }
        double[] topicTermScores = new double[this.numTopics];
        for (int position = 0; position < docLength; ++position) {
            int type = tokenSequence.getIndexAtPosition(position);
            TIntDoubleHashMap typeFactors = this.typeTypeWeights[type];
            int[] connectedTypes = typeFactors.keys();
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];
            double[] currentTypeTopicWeights = this.typeTopicWeights[type];

            // Remove the token's current assignment from every count it contributes to.
            if (!initializing) {
                int oldTopic = oneDocTopics[position];
                --localTopicCounts[oldTopic];
                --this.tokensPerTopic[oldTopic];
                assert (this.tokensPerTopic[oldTopic] >= 0);
                --currentTypeTopicCounts[oldTopic];
                this.adjustConnectedWeights(typeFactors, connectedTypes, oldTopic, -1.0);
            }

            double sum = 0.0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                double score = (this.alpha + (double) localTopicCounts[topic]) * (currentTypeTopicWeights[topic] / this.totalTopicWeights[topic]);
                sum += score;
                topicTermScores[topic] = score;
                // Hard-coded type id 68 looks like a leftover debugging filter.
                if (debugging && type == 68) {
                    System.out.println(type + "\t" + topic + "\t" + localTopicCounts[topic] + "\t" + currentTypeTopicCounts[topic] + "\t" + currentTypeTopicWeights[topic] + "\t" + this.tokensPerTopic[topic] + "\t" + sum);
                }
            }
            double sample = this.random.nextUniform() * sum;
            if (debugging) {
                System.out.println("sample " + sample + " / " + sum);
            }

            // Inverse-CDF draw. Always consume at least topic 0, so a draw of exactly 0.0 is
            // valid. Stop at the last topic, so round-off cannot index past the array.
            int newTopic = 0;
            sample -= topicTermScores[0];
            while (sample > 0.0 && newTopic < this.numTopics - 1) {
                ++newTopic;
                sample -= topicTermScores[newTopic];
            }

            oneDocTopics[position] = newTopic;
            ++localTopicCounts[newTopic];
            ++this.tokensPerTopic[newTopic];
            ++currentTypeTopicCounts[newTopic];
            this.adjustConnectedWeights(typeFactors, connectedTypes, newTopic, 1.0);
        }
    }

    /**
     * Adds {@code sign * weight} to {@code typeTopicWeights[other][topic]} and to
     * {@code totalTopicWeights[topic]} for every connected type {@code other}.
     */
    private void adjustConnectedWeights(TIntDoubleHashMap typeFactors, int[] connectedTypes, int topic, double sign) {
        for (int otherType : connectedTypes) {
            double delta = sign * typeFactors.get(otherType);
            this.typeTopicWeights[otherType][topic] += delta;
            this.totalTopicWeights[topic] += delta;
        }
    }

    /**
     * Formats one line per topic. Each line contains the topic index, its token count, its
     * total weight, then its {@code numWords} most frequent words by raw count.
     * {@code numWords} is capped at the vocabulary size.
     */
    public String topWords(int numWords) {
        StringBuilder output = new StringBuilder();
        IDSorter[] sortedWords = new IDSorter[this.numTypes];
        int wordsToShow = Math.min(numWords, this.numTypes);
        for (int topic = 0; topic < this.numTopics; ++topic) {
            for (int type = 0; type < this.numTypes; ++type) {
                sortedWords[type] = new IDSorter(type, this.typeTopicCounts[type][topic]);
            }
            Arrays.sort(sortedWords);
            output.append(topic).append('\t')
                    .append(this.tokensPerTopic[topic]).append('\t')
                    .append(this.formatter.format(this.totalTopicWeights[topic])).append('\t');
            for (int i = 0; i < wordsToShow; ++i) {
                output.append(this.alphabet.lookupObject(sortedWords[i].getID())).append(' ');
            }
            output.append('\n');
        }
        return output.toString();
    }

    /**
     * Builds a held-out likelihood estimator from the current raw counts.
     *
     * <p>Each type's nonzero counts are packed as {@code (count << topicBits) + topic} and sorted
     * in descending order. {@code topicBits} is the number of bits needed to store a topic index.
     * The weighted pseudo-counts are not used; only raw counts and the symmetric alpha are.
     */
    public MarginalProbEstimator getEstimator() {
        int topicMask = Integer.bitCount(this.numTopics) == 1
                ? this.numTopics - 1
                : Integer.highestOneBit(this.numTopics) * 2 - 1;
        int topicBits = Integer.bitCount(topicMask);
        int[][] sparseTypeTopicCounts = new int[this.numTypes][];
        for (int type = 0; type < this.numTypes; ++type) {
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];
            int numNonZeros = 0;
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (currentTypeTopicCounts[topic] > 0) {
                    ++numNonZeros;
                }
            }
            int[] sparseCounts = new int[numNonZeros];
            for (int topic = 0; topic < this.numTopics; ++topic) {
                if (currentTypeTopicCounts[topic] <= 0) {
                    continue;
                }
                int value = (currentTypeTopicCounts[topic] << topicBits) + topic;
                // Insertion into a descending array whose unused tail is still zero.
                int i = 0;
                while (sparseCounts[i] > value) {
                    ++i;
                }
                while (i < sparseCounts.length && value > sparseCounts[i]) {
                    int displaced = sparseCounts[i];
                    sparseCounts[i] = value;
                    value = displaced;
                    ++i;
                }
            }
            sparseTypeTopicCounts[type] = sparseCounts;
        }
        double[] alphas = new double[this.numTopics];
        Arrays.fill(alphas, this.alpha);
        return new MarginalProbEstimator(this.numTopics, alphas, this.alphaSum, this.beta, sparseTypeTopicCounts, this.tokensPerTopic);
    }

    /**
     * Writes the sampling state to {@code f}, gzip-compressed, in the format of
     * {@link #printState(PrintStream)}.
     *
     * @throws IOException if the file cannot be opened or any write fails
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
        } finally {
            out.close();
        }
        // PrintStream swallows IOExceptions; surface them, since this method declares IOException.
        if (out.checkError()) {
            throw new IOException("Error writing Gibbs sampling state to " + f);
        }
    }

    /**
     * Prints one line per token: {@code doc source pos typeindex type topic}.
     * The source column is always {@code NA}.
     */
    public void printState(PrintStream stream) {
        stream.println("#doc source pos typeindex type topic");
        String source = "NA";
        for (int doc = 0; doc < this.data.size(); ++doc) {
            TopicAssignment assignment = this.data.get(doc);
            FeatureSequence tokenSequence = (FeatureSequence) assignment.instance.getData();
            LabelSequence topicSequence = assignment.topicSequence;
            StringBuilder out = new StringBuilder();
            for (int position = 0; position < topicSequence.getLength(); ++position) {
                int type = tokenSequence.getIndexAtPosition(position);
                int topic = topicSequence.getIndexAtPosition(position);
                out.append(doc).append(' ')
                        .append(source).append(' ')
                        .append(position).append(' ')
                        .append(type).append(' ')
                        .append(this.alphabet.lookupObject(type)).append(' ')
                        .append(topic).append("\n");
            }
            stream.print(out.toString());
        }
    }

    /**
     * Command-line entry point. Trains for the configured number of epochs. After each epoch
     * it optionally writes the state file and the serialized held-out estimator, suffixed
     * with {@code .<epoch>}.
     */
    public static void main(String[] args) throws Exception {
        CommandOption.setSummary(WeightedTopicModel.class, "Train topics with weights between word types encoded in the prior");
        CommandOption.process(WeightedTopicModel.class, args);
        InstanceList training = InstanceList.load(new File(inputFile.value));
        Randoms random = randomSeedOption.value != 0 ? new Randoms(randomSeedOption.value) : new Randoms();
        WeightedTopicModel lda = new WeightedTopicModel(numTopicsOption.value, alphaOption.value, betaOption.value, random);
        lda.addInstances(training);
        lda.readTypeTypeWeights(new File(weightsFile.value));
        int docCycleCount = 1;
        for (int epoch = 1; epoch <= numEpochsOption.value; ++epoch) {
            lda.sample(numIterationsOption.value, epoch == 1, docCycleCount);
            if (stateFile.wasInvoked()) {
                lda.printState(new File(stateFile.value + "." + epoch));
            }
            if (evaluatorFilename.wasInvoked()) {
                writeEstimator(lda, evaluatorFilename.value + "." + epoch);
            }
        }
    }

    /**
     * Serializes {@code model.getEstimator()} to {@code filename} on a best-effort basis.
     * Failures are printed rather than thrown, so training can continue.
     */
    private static void writeEstimator(WeightedTopicModel model, String filename) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(filename);
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(model.getEstimator());
            // Closing the object stream flushes it and closes fileOut.
            oos.close();
            fileOut = null;
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (fileOut != null) {
                try {
                    fileOut.close();
                } catch (IOException ignored) {
                    // The primary failure has already been reported above.
                }
            }
        }
    }
}

