/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SimpleExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DataRow;
import com.rapidminer.example.table.DataRowFactory;
import com.rapidminer.example.table.ExampleTable;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.RandomGenerator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Operator that scores documents against already-trained LDA topic counts using a
 * left-to-right particle estimator.
 *
 * <p>Inputs (see {@link #doWork()}):
 * <ul>
 *   <li>"example set input" - the documents, either as one text column or as a
 *       word-frequency matrix (one attribute per word);</li>
 *   <li>"example set words assignments" - one row per word with {@code Topic_*} count columns;</li>
 *   <li>"example set topic assignments" - {@code Topic_*} columns holding per-topic token totals.</li>
 * </ul>
 * Output: an example set with one numeric column and one row per test run.
 *
 * <p>NOTE(review): the estimator code looks like an adaptation of MALLET's
 * {@code MarginalProbEstimator} - confirm before relying on that for semantics.
 *
 * <p>Instances keep mutable estimator state in fields ({@code cachedCoefficients},
 * {@code random}, ...) that {@link #leftToRight(int[], boolean)} modifies while running.
 */
public class LDAEvaluationOperator extends Operator {
    static final String PARAMETER_NUMITERATIONS = "iterations";
    static final String PARAMETER_NUMTOPICS = "number_of_topics";
    static final String PARAMETER_NUMTESTS = "tests";
    static final String PARAMETER_ALPHA = "alpha";
    static final String PARAMETER_BETA = "beta";
    static final String PARAMETER_TEXT_ATTRIBUTE = "text_attribute";

    /** Value of the "iterations" parameter; doWork passes it on as the number of particles. */
    int iters = 2000;
    protected int numTopics;
    protected double alpha;
    protected double alphaSum;
    protected double beta;
    /** beta times the number of rows of {@link #typeTopicCounts}. */
    protected double betaSum;
    /** Sum over topics of alpha * beta / (tokensPerTopic[t] + betaSum). */
    protected double smoothingOnlyMass = 0.0;
    /** Per-topic (alpha + local count) / (tokensPerTopic[t] + betaSum); updated during sampling. */
    protected double[] cachedCoefficients;
    /** Indexed [word type][topic]; filled densely by topic in doWork. */
    protected int[][] typeTopicCounts;
    protected int[] tokensPerTopic;
    /** Generator actually used by the sampler. */
    protected Random random;
    private final InputPort input = (InputPort)this.getInputPorts().createPort("example set input");
    private final InputPort inputWords = (InputPort)this.getInputPorts().createPort("example set words assignments");
    private final InputPort inputTopics = (InputPort)this.getInputPorts().createPort("example set topic assignments");
    private final OutputPort output = (OutputPort)this.getOutputPorts().createPort("output neg log likelihoods");
    /** Generator configured from the operator's random-seed parameters in doWork; null before that. */
    Random rn = null;

    public LDAEvaluationOperator(OperatorDescription description) {
        super(description);
    }

    /**
     * Initializes the estimator state from trained counts. Despite the capitalized name
     * this is an ordinary method, not a constructor.
     *
     * @param numTopics       number of topics
     * @param alpha           per-topic document-topic smoothing
     * @param alphaSum        sum of alpha over all topics
     * @param beta            topic-word smoothing
     * @param typeTopicCounts counts indexed [word type][topic]; its length is used as vocabulary size
     * @param tokensPerTopic  total token count per topic
     */
    public void MarginalProbEstimator(int numTopics, double alpha, double alphaSum, double beta, int[][] typeTopicCounts, int[] tokensPerTopic) {
        this.numTopics = numTopics;
        this.typeTopicCounts = typeTopicCounts;
        this.tokensPerTopic = tokensPerTopic;
        this.alphaSum = alphaSum;
        this.alpha = alpha;
        this.beta = beta;
        this.betaSum = beta * (double)typeTopicCounts.length;
        this.random = new Random();
        this.cachedCoefficients = new double[numTopics];
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < numTopics; ++topic) {
            double denominator = (double)tokensPerTopic[topic] + this.betaSum;
            this.smoothingOnlyMass += alpha * beta / denominator;
            this.cachedCoefficients[topic] = alpha / denominator;
        }
    }

    /** @return the per-topic token totals (the internal array, not a copy) */
    public int[] getTokensPerTopic() {
        return this.tokensPerTopic;
    }

    /** @return the word-type/topic counts (the internal array, not a copy) */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /**
     * Returns the generator to sample with: {@link #rn} when doWork has configured it,
     * otherwise the existing {@link #random}, creating one if neither is set.
     */
    private Random activeRandom() {
        if (this.rn != null) {
            return this.rn;
        }
        if (this.random == null) {
            this.random = new Random();
        }
        return this.random;
    }

    /**
     * Estimates the summed log likelihood of the given documents.
     *
     * <p>Each document is shuffled in place, then {@code numParticles} independent
     * left-to-right passes are run; per position the particle probabilities are
     * averaged and the log of that average is added to the document score.
     * Positions whose summed probability is not positive are skipped.
     *
     * @param testing              one token-id list per document; each list is shuffled in place
     * @param numParticles         number of independent passes per document
     * @param usingResampling      forwarded to {@link #leftToRight(int[], boolean)}
     * @param docProbabilityStream if non-null, receives one line per document with its log likelihood
     * @return the sum of the document log likelihoods
     */
    public double evaluateLeftToRight(TIntArrayList[] testing, int numParticles, boolean usingResampling, PrintStream docProbabilityStream) {
        // Previously this assigned rn unconditionally, which made the shuffle below fail
        // with a NullPointerException when doWork had not run yet.
        this.random = this.activeRandom();
        double logNumParticles = Math.log(numParticles);
        double totalLogLikelihood = 0.0;
        for (TIntArrayList instance : testing) {
            instance.shuffle(this.random);
            int[] tokenSequence = instance.toArray();
            double docLogLikelihood = 0.0;
            double[][] particleProbabilities = new double[numParticles][];
            for (int particle = 0; particle < numParticles; ++particle) {
                particleProbabilities[particle] = this.leftToRight(tokenSequence, usingResampling);
            }
            for (int position = 0; position < particleProbabilities[0].length; ++position) {
                double sum = 0.0;
                for (int particle = 0; particle < numParticles; ++particle) {
                    sum += particleProbabilities[particle][position];
                }
                if (!(sum > 0.0)) {
                    continue;
                }
                // log of the particle average: log(sum / numParticles)
                docLogLikelihood += Math.log(sum) - logNumParticles;
            }
            if (docProbabilityStream != null) {
                docProbabilityStream.println(docLogLikelihood);
            }
            totalLogLikelihood += docLogLikelihood;
        }
        return totalLogLikelihood;
    }

    /**
     * Runs one left-to-right pass over a document and returns, per position, the
     * unnormalized-mass-based probability recorded for the token at that position.
     *
     * <p>For each prefix end {@code limit}: optionally resample the topics of all earlier
     * positions, record the probability of the token at {@code limit}, then sample a topic
     * for it. Sampling mass is split into three buckets: {@code topicTermMass} (word/topic
     * counts), {@code topicBetaMass} (topics already used in this document) and
     * {@code smoothingOnlyMass} (pure smoothing).
     *
     * <p>Tokens whose id is outside {@code typeTopicCounts} (or whose row is null) are skipped
     * and keep probability 0. The method temporarily modifies {@code cachedCoefficients}
     * and restores the entries of the topics still in use at the end.
     *
     * @param tokenSequence   word-type ids of the document
     * @param usingResampling whether to resample earlier positions before each new token
     * @return per-position probabilities, same length as {@code tokenSequence}
     */
    protected double[] leftToRight(int[] tokenSequence, boolean usingResampling) {
        int denseIndex;
        int[] oneDocTopics = new int[tokenSequence.length];
        double[] wordProbabilities = new double[tokenSequence.length];
        int docLength = tokenSequence.length;
        int tokensSoFar = 0;
        // Per-document topic counts, plus a sorted list of the topics with non-zero count.
        int[] localTopicCounts = new int[this.numTopics];
        int[] localTopicIndex = new int[this.numTopics];
        int nonZeroTopics = denseIndex = 0;
        double topicBetaMass = 0.0;
        double topicTermMass = 0.0;
        double[] topicTermScores = new double[this.numTopics];
        double logLikelihood = 0.0;
        for (int limit = 0; limit < docLength; ++limit) {
            double sample;
            int i;
            int newTopic;
            double score;
            int[] currentTypeTopicCounts;
            int type;
            if (usingResampling) {
                // Resample the topic of every position before the current limit.
                for (int position = 0; position < limit; ++position) {
                    double sample2;
                    type = tokenSequence[position];
                    int oldTopic = oneDocTopics[position];
                    if (type >= this.typeTopicCounts.length || this.typeTopicCounts[type] == null) continue;
                    currentTypeTopicCounts = this.typeTopicCounts[type];
                    // Remove this token's old topic from the document-level statistics.
                    topicBetaMass -= this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                    int n = oldTopic;
                    localTopicCounts[n] = localTopicCounts[n] - 1;
                    if (localTopicCounts[oldTopic] == 0) {
                        // Topic no longer used in this document: drop it from the dense index
                        // by shifting the following entries left.
                        denseIndex = 0;
                        while (localTopicIndex[denseIndex] != oldTopic) {
                            ++denseIndex;
                        }
                        while (denseIndex < nonZeroTopics) {
                            if (denseIndex < localTopicIndex.length - 1) {
                                localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
                            }
                            ++denseIndex;
                        }
                        --nonZeroTopics;
                    }
                    topicBetaMass += this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                    this.cachedCoefficients[oldTopic] = (this.alpha + (double)localTopicCounts[oldTopic]) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                    boolean alreadyDecremented = false;
                    topicTermMass = 0.0;
                    // NOTE(review): this loop stops at the first topic with a zero count, while
                    // doWork fills these rows densely by topic; topics after a zero entry are
                    // therefore not scored. Confirm whether that is intended.
                    for (int index = 0; index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0; ++index) {
                        int currentTopic = index;
                        int currentValue = currentTypeTopicCounts[index];
                        score = this.cachedCoefficients[currentTopic] * (double)currentValue;
                        topicTermMass += score;
                        topicTermScores[index] = score;
                    }
                    double origSample = sample2 = this.random.nextDouble() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
                    newTopic = -1;
                    if (sample2 < topicTermMass) {
                        // Bucket 1: word/topic count mass.
                        i = -1;
                        while (sample2 > 0.0) {
                            sample2 -= topicTermScores[++i];
                        }
                        newTopic = i;
                    } else if ((sample2 -= topicTermMass) < topicBetaMass) {
                        // Bucket 2: topics already present in this document.
                        sample2 /= this.beta;
                        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                            int topic = localTopicIndex[denseIndex];
                            if (!((sample2 -= (double)localTopicCounts[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum)) <= 0.0)) continue;
                            newTopic = topic;
                            break;
                        }
                    } else {
                        // Bucket 3: smoothing-only mass, walked over all topics.
                        sample2 -= topicBetaMass;
                        sample2 /= this.beta;
                        newTopic = 0;
                        sample2 -= this.alpha / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                        while (sample2 > 0.0) {
                            sample2 -= this.alpha / ((double)this.tokensPerTopic[++newTopic] + this.betaSum);
                        }
                    }
                    if (newTopic == -1) {
                        // No bucket produced a topic; report and fall back to the last topic.
                        System.err.println("sampling error: " + origSample + " " + sample2 + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                        newTopic = this.numTopics - 1;
                    }
                    oneDocTopics[position] = newTopic;
                    // Add the newly sampled topic back into the document-level statistics.
                    topicBetaMass -= this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    int n2 = newTopic;
                    localTopicCounts[n2] = localTopicCounts[n2] + 1;
                    if (localTopicCounts[newTopic] == 1) {
                        // First use of this topic: insert it into the dense index, keeping it sorted.
                        for (denseIndex = nonZeroTopics; denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic; --denseIndex) {
                            localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                        }
                        localTopicIndex[denseIndex] = newTopic;
                        ++nonZeroTopics;
                    }
                    this.cachedCoefficients[newTopic] = (this.alpha + (double)localTopicCounts[newTopic]) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    topicBetaMass += this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                }
            }
            // Score and sample the new token at position `limit`; unknown word types are skipped.
            if ((type = tokenSequence[limit]) >= this.typeTopicCounts.length || this.typeTopicCounts[type] == null) continue;
            currentTypeTopicCounts = this.typeTopicCounts[type];
            topicTermMass = 0.0;
            // Same early-stop-on-zero behavior as the resampling loop above.
            for (int index = 0; index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0; ++index) {
                int currentTopic = index;
                int currentValue = currentTypeTopicCounts[index];
                score = this.cachedCoefficients[currentTopic] * (double)currentValue;
                topicTermMass += score;
                topicTermScores[index] = score;
            }
            double origSample = sample = this.random.nextDouble() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
            int n = limit;
            // Probability of this token: total mass over (alphaSum + tokens counted so far).
            wordProbabilities[n] = wordProbabilities[n] + (this.smoothingOnlyMass + topicBetaMass + topicTermMass) / (this.alphaSum + (double)tokensSoFar);
            ++tokensSoFar;
            newTopic = -1;
            if (sample < topicTermMass) {
                i = -1;
                while (sample > 0.0) {
                    sample -= topicTermScores[++i];
                }
                newTopic = i;
            } else if ((sample -= topicTermMass) < topicBetaMass) {
                sample /= this.beta;
                for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                    int topic = localTopicIndex[denseIndex];
                    if (!((sample -= (double)localTopicCounts[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum)) <= 0.0)) continue;
                    newTopic = topic;
                    break;
                }
            } else {
                sample -= topicBetaMass;
                sample /= this.beta;
                newTopic = 0;
                sample -= this.alpha / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                while (sample > 0.0) {
                    sample -= this.alpha / ((double)this.tokensPerTopic[++newTopic] + this.betaSum);
                }
            }
            if (newTopic == -1) {
                System.err.println("sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                newTopic = this.numTopics - 1;
            }
            oneDocTopics[limit] = newTopic;
            topicBetaMass -= this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            int n3 = newTopic;
            localTopicCounts[n3] = localTopicCounts[n3] + 1;
            if (localTopicCounts[newTopic] == 1) {
                for (denseIndex = nonZeroTopics; denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic; --denseIndex) {
                    localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                }
                localTopicIndex[denseIndex] = newTopic;
                ++nonZeroTopics;
            }
            this.cachedCoefficients[newTopic] = (this.alpha + (double)localTopicCounts[newTopic]) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
        }
        // Reset the shared coefficients of the topics used by this document to their
        // document-independent value so the next pass starts clean.
        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
            int topic = localTopicIndex[denseIndex];
            this.cachedCoefficients[topic] = this.alpha / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        return wordProbabilities;
    }

    /**
     * Reads parameters and the three input example sets, initializes the estimator and
     * runs it "tests" times, delivering one output row per run.
     *
     * <p>Documents are taken from the text attribute when one is found (tokens are split on
     * single spaces, trimmed, lower-cased and mapped to the index of the equally named
     * attribute); otherwise every attribute is treated as a word whose value is its frequency.
     *
     * <p>NOTE(review): the output column is named "negloglikelihood" but the value written is
     * the un-negated result of {@link #evaluateLeftToRight}. Confirm which is intended.
     *
     * @throws OperatorException if reading parameters or port data fails
     */
    public void doWork() throws OperatorException {
        this.iters = this.getParameterAsInt(PARAMETER_NUMITERATIONS);
        this.numTopics = this.getParameterAsInt(PARAMETER_NUMTOPICS);
        this.alpha = this.getParameterAsDouble(PARAMETER_ALPHA);
        this.beta = this.getParameterAsDouble(PARAMETER_BETA);
        int numTests = this.getParameterAsInt(PARAMETER_NUMTESTS);
        boolean locSeed = this.getParameterAsBoolean("use_local_random_seed");
        int seed = this.getParameterAsInt("local_random_seed");
        this.rn = locSeed ? new Random(seed) : new Random();

        // Result table with a single column; value type 2 is presumably Ontology.NUMERICAL - confirm.
        ArrayList<Attribute> attributeList = new ArrayList<Attribute>();
        attributeList.add(AttributeFactory.createAttribute((String)"negloglikelihood", (int)2));
        MemoryExampleTable table = new MemoryExampleTable(attributeList);
        DataRowFactory factory = new DataRowFactory(0, '.');

        ExampleSet exampleSet = (ExampleSet)this.input.getData(ExampleSet.class);
        Example ex = exampleSet.getExample(0);
        Attributes attr = ex.getAttributes();
        TIntArrayList[] documentTokens = new TIntArrayList[exampleSet.size()];

        // Locate the text attribute: with no name configured the last matching attribute wins,
        // otherwise the one with the configured name. Value type 5 is presumably
        // Ontology.STRING - confirm.
        Attribute textAttribute = null;
        String colName = this.getParameterAsString(PARAMETER_TEXT_ATTRIBUTE);
        boolean anyTextColumn = colName == null || colName.equals("");
        for (Attribute att : attr) {
            if (!Ontology.ATTRIBUTE_VALUE_TYPE.isA(att.getValueType(), 5)) {
                continue;
            }
            if (anyTextColumn || colName.equals(att.getName())) {
                textAttribute = att;
            }
        }

        if (textAttribute != null) {
            // Word id = position of the equally named attribute, not counting the text attribute.
            TObjectIntHashMap<String> attToId = new TObjectIntHashMap<String>();
            int id = 0;
            for (Attribute att : attr) {
                if (att == textAttribute) {
                    continue;
                }
                attToId.put(att.getName().trim().toLowerCase(), id);
                ++id;
            }
            for (int i = 0; i < exampleSet.size(); ++i) {
                // Fixed: this used to read example 0 on every iteration, so all documents
                // were tokenized from the first row's text.
                ex = exampleSet.getExample(i);
                String text = ex.getValueAsString(textAttribute);
                String[] tokens = text.split(" ");
                documentTokens[i] = new TIntArrayList();
                for (String token : tokens) {
                    String key = token.trim().toLowerCase();
                    if (!attToId.contains(key)) {
                        continue;
                    }
                    documentTokens[i].add(attToId.get(key));
                }
                documentTokens[i].shuffle(this.rn);
            }
        } else {
            // Frequency matrix: attribute j with value f contributes (int) f copies of token j.
            for (int i = 0; i < exampleSet.size(); ++i) {
                documentTokens[i] = new TIntArrayList();
                ex = exampleSet.getExample(i);
                int j = 0;
                for (Attribute att : ex.getAttributes()) {
                    double frequ = ex.getValue(att);
                    if (frequ != 0.0) {
                        for (int k = 0; k < (int)frequ; ++k) {
                            documentTokens[i].add(j);
                        }
                    }
                    ++j;
                }
            }
        }

        // Word/topic counts: one row per word, Topic_* columns in iteration order.
        ExampleSet examplesProbs = (ExampleSet)this.inputWords.getData(ExampleSet.class);
        Example ex2 = examplesProbs.getExample(0);
        int numWords = examplesProbs.size();
        // Overrides the "number_of_topics" parameter: attribute count minus two.
        this.numTopics = ex2.getAttributes().size() - 2;
        int[][] wordtopicassigns = new int[numWords][this.numTopics];
        for (int i = 0; i < examplesProbs.size(); ++i) {
            ex2 = examplesProbs.getExample(i);
            int j = 0;
            for (Attribute att : ex2.getAttributes()) {
                if (!att.getName().contains("Topic_")) {
                    continue;
                }
                wordtopicassigns[i][j] = (int)ex2.getValue(att);
                ++j;
            }
        }

        // Per-topic totals; if the set has several rows, later rows overwrite earlier ones.
        int[] topicassigns = new int[this.numTopics];
        examplesProbs = (ExampleSet)this.inputTopics.getData(ExampleSet.class);
        for (int i = 0; i < examplesProbs.size(); ++i) {
            ex2 = examplesProbs.getExample(i);
            int j = 0;
            for (Attribute att : ex2.getAttributes()) {
                if (!att.getName().contains("Topic_")) {
                    continue;
                }
                topicassigns[j] = (int)ex2.getValue(att);
                ++j;
            }
        }

        this.MarginalProbEstimator(this.numTopics, this.alpha, (double)this.numTopics * this.alpha, this.beta, wordtopicassigns, topicassigns);
        for (int te = 0; te < numTests; ++te) {
            int allCounts = 0;
            PrintStream docProbabilityStream = null;
            int numParticle = this.iters;
            double perplexity = this.evaluateLeftToRight(documentTokens, numParticle, true, docProbabilityStream);
            // Console output kept from the original implementation.
            System.out.println(allCounts);
            System.out.println(perplexity);
            DataRow row = factory.create(table.getNumberOfAttributes());
            table.addDataRow(row);
            row.set((Attribute)attributeList.get(0), perplexity);
        }
        SimpleExampleSet set = new SimpleExampleSet((ExampleTable)table);
        this.output.deliver((IOObject)set);
    }

    /**
     * Draws {@code num} indices from the discrete distribution proportional to {@code probs}.
     *
     * <p>{@code probs} is normalized in place (each entry divided by the total) before sampling.
     *
     * @param num   number of samples to draw
     * @param probs non-negative weights; modified by this call
     * @return array of {@code num} sampled indices into {@code probs}
     */
    public int[] getDiscrete(int num, double[] probs) {
        double sum = 0.0;
        for (int i = 0; i < probs.length; ++i) {
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; ++i) {
            probs[i] = probs[i] / sum;
        }
        Random generator = this.activeRandom();
        int[] res = new int[num];
        for (int i2 = 0; i2 < num; ++i2) {
            int j = 0;
            double p = generator.nextDouble();
            double pr = probs[0];
            // Bounded scan: rounding can leave the cumulative sum slightly below p, which
            // previously walked past the end of the array.
            while (pr < p && j < probs.length - 1) {
                pr += probs[++j];
            }
            res[i2] = j;
        }
        return res;
    }

    /**
     * Adds this operator's parameters (iterations, tests, topics, alpha, beta, text attribute
     * and the random-generator parameters) to those of the superclass.
     */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = super.getParameterTypes();
        types.add(new ParameterTypeInt(PARAMETER_NUMITERATIONS, "Number of Iterations for Samplings.", 1, Integer.MAX_VALUE, 2000));
        types.add(new ParameterTypeInt(PARAMETER_NUMTESTS, "Number of evaluation runs; each run adds one row to the output.", 1, Integer.MAX_VALUE, 20));
        types.add(new ParameterTypeInt(PARAMETER_NUMTOPICS, "Number of Topics.", 1, Integer.MAX_VALUE, 5));
        types.add(new ParameterTypeDouble(PARAMETER_ALPHA, "Alpha", 0.0, Double.MAX_VALUE, 0.25));
        types.add(new ParameterTypeDouble(PARAMETER_BETA, "Beta", 0.0, Double.MAX_VALUE, 0.1));
        types.add(new ParameterTypeString(PARAMETER_TEXT_ATTRIBUTE, "Attribute name of text columns of interest.", ""));
        types.addAll(RandomGenerator.getRandomGeneratorParameters((Operator)this));
        return types;
    }

    /** Intentionally empty entry point. */
    public static void main(String[] args) {
    }
}

