/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.ConjugateGradient;
import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.topicmodels.MyEtaOptimizable;
import com.rapidminer.kobra.topicmodels.SamplersLDA;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;

public class SamplersNBLDA
extends SamplersLDA {
    double[][] z_bar;
    int[] doc_lengths;
    double[] labels;
    int start_test = 2000;
    double[] labels_train;
    double[] predictions;

    public void init(int[] docIds, int[] wordIds, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha, double[] labels, int[] doc_lengths) {
        this.labels = labels;
        this.labels_train = new double[this.start_test];
        for (int i = 0; i < this.start_test; ++i) {
            this.labels_train[i] = labels[i];
        }
        this.doc_lengths = doc_lengths;
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        int topic = 0;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        for (int i = 0; i < wordIds.length; ++i) {
            int wi = this.words[i];
            int di = this.docs[i];
            this.topics[i] = topic = new Random().nextInt(numTopics);
            int n = wi * numTopics + topic;
            this.wordtopiccounts[n] = this.wordtopiccounts[n] + 1;
            int n2 = di * numTopics + topic;
            this.doctopiccounts[n2] = this.doctopiccounts[n2] + 1;
            int n3 = topic;
            this.topiccounts[n3] = this.topiccounts[n3] + 1;
        }
    }

    @Override
    public void GibbsSampling() {
        int iter;
        int burnIn = (int)((double)this.maxIter * 0.9);
        int topic = 0;
        this.WBETA = (double)this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        TIntArrayList ll = new TIntArrayList(this.numTokens);
        for (int i = 0; i < this.numTokens; ++i) {
            ll.add(i);
        }
        ll.shuffle(new Random(2000L));
        double[] eta = new double[this.numTopics];
        for (int j = 0; j < this.numTopics; ++j) {
            eta[j] = 1.0 / (double)this.numTopics;
        }
        double logStandardDeviationPlusHalfLog2Pi = FastMath.log(4.0) + 0.5 * FastMath.log(Math.PI * 2);
        int[] numPos = new int[this.numTopics];
        int[] numNeg = new int[this.numTopics];
        for (iter = 0; iter < this.maxIter; ++iter) {
            int i;
            this.z_bar = new double[this.numDocs][this.numTopics];
            for (int ii = 0; ii < this.numTokens; ++ii) {
                int i2 = ll.get(ii);
                int wi = this.words[i2];
                int di = this.docs[i2];
                if (di > this.start_test) continue;
                int n = topic = this.topics[i2];
                this.topiccounts[n] = this.topiccounts[n] - 1;
                int wioffset = wi * this.numTopics;
                int dioffset = di * this.numTopics;
                int n2 = wioffset + topic;
                this.wordtopiccounts[n2] = this.wordtopiccounts[n2] - 1;
                int n3 = dioffset + topic;
                this.doctopiccounts[n3] = this.doctopiccounts[n3] - 1;
                double totprob = 0.0;
                double docCounts = 0.0;
                int assigedDoc = 0;
                boolean all = false;
                for (int j = 0; j < this.numTopics; ++j) {
                    docCounts += (double)this.doctopiccounts[dioffset + j] * eta[j];
                    assigedDoc += this.doctopiccounts[dioffset + j];
                }
                double mean = docCounts / (double)this.doc_lengths[di];
                double variance = 1.0;
                for (int j = 0; j < this.numTopics; ++j) {
                    mean = (docCounts + eta[j]) / (double)(assigedDoc + 1);
                    variance = 4.0;
                    double x = this.labels[di] - mean;
                    double normal = FastMath.exp(-0.5 * (x /= variance) * x - logStandardDeviationPlusHalfLog2Pi);
                    this.probs[j] = ((double)this.wordtopiccounts[wioffset + j] + this.BETA) / ((double)this.topiccounts[j] + this.WBETA) * ((double)this.doctopiccounts[dioffset + j] + this.ALPHA) * normal;
                    totprob += this.probs[j];
                }
                double r = totprob * Math.random();
                topic = 0;
                for (double max = this.probs[0]; r > max; max += this.probs[++topic]) {
                }
                this.topics[i2] = topic;
                int n4 = wioffset + topic;
                this.wordtopiccounts[n4] = this.wordtopiccounts[n4] + 1;
                int n5 = dioffset + topic;
                this.doctopiccounts[n5] = this.doctopiccounts[n5] + 1;
                int n6 = topic;
                this.topiccounts[n6] = this.topiccounts[n6] + 1;
                if (this.labels[di] > 0.0) {
                    int n7 = topic;
                    numPos[n7] = numPos[n7] + 1;
                } else {
                    int n8 = topic;
                    numNeg[n8] = numNeg[n8] + 1;
                }
                double[] dArray = this.z_bar[di];
                int n9 = topic;
                dArray[n9] = dArray[n9] + 1.0 / (double)this.doc_lengths[di];
                this.tokenToTopic[ii] = topic;
            }
            for (i = 0; i < this.numWords * this.numTopics; ++i) {
                if (this.wordtopiccounts[i] >= 0) continue;
                this.wordtopiccounts[i] = 0;
            }
            for (i = 0; i < this.numDocs * this.numTopics; ++i) {
                if (this.doctopiccounts[i] >= 0) continue;
                this.doctopiccounts[i] = 0;
            }
            if (iter >= burnIn && iter % 2 == 0) {
                this.updateDistributions();
            }
            if (iter != 0 && iter % 10 != 0 || iter >= burnIn) continue;
            MyEtaOptimizable opt = new MyEtaOptimizable(this.labels_train, this.numTopics, this.z_bar);
            ConjugateGradient optimizer = null;
            optimizer = new ConjugateGradient(opt);
            boolean converged = false;
            try {
                converged = optimizer.optimize(1000);
            }
            catch (OptimizationException e) {
                e.printStackTrace();
            }
            eta = new double[this.numTopics];
            opt.getParameters(eta);
        }
        iter = 0;
        if (iter < this.maxIter / 10) {
            // empty if block
        }
        this.predictions = new double[this.labels.length];
        for (int i = 0; i < this.z_bar.length; ++i) {
            double pred = 0.0;
            for (int j = 0; j < this.z_bar[i].length; ++j) {
                pred += eta[j] * this.z_bar[i][j];
            }
            this.predictions[i] = pred > 0.0 ? 1.0 : -1.0;
        }
    }

    public double[] getPredictions() {
        return this.predictions;
    }

    public static void main(String[] args) {
    }
}

