/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import com.rapidminer.kobra.topicmodels.SamplersSLDA;
import gnu.trove.TIntHashSet;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import java.util.Random;
import org.apache.commons.math3.distribution.BetaDistribution;
import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.Variance;

/**
 * Gibbs sampler for an LDA variant in which every topic additionally owns a Beta distribution
 * over document time values.
 *
 * <p>When a token is resampled, each topic's weight is the LDA word/document term multiplied by
 * that topic's Beta density evaluated at the token's document time. After every sweep, each
 * topic's Beta shape parameters are re-estimated by the method of moments from the times of the
 * tokens assigned to it.
 *
 * <p>NOTE(review): {@link BetaDistribution} densities are zero outside [0, 1], so the document
 * times passed to {@link #init} are presumably normalized to that range — confirm with callers.
 */
public class SamplersDTLDA extends SamplersSLDA {
    /** Number of jittered copies added for a topic that received exactly one time value. */
    private static final int SINGLETON_JITTER_POINTS = 10;
    /** Standard deviation of the Gaussian jitter applied to time values. */
    private static final double JITTER_SD = 0.01;
    /** A topic whose time variance is at or below this value has its times re-jittered. */
    private static final double LOW_VARIANCE_THRESHOLD = 0.01;
    /** Rejection budget for drawing a jittered point strictly inside (0, 1). */
    private static final int MAX_JITTER_ATTEMPTS = 1000;
    /** Margin used when clamping a point into (0, 1) after the rejection budget is exhausted. */
    private static final double CLAMP_MARGIN = 1e-9;
    /** Fraction of {@code maxIter} after which {@code updateDistributions()} starts being called. */
    private static final double BURN_IN_FRACTION = 0.9;

    /** Time value per document, indexed by document id. */
    double[] times;
    /** Per topic: mean of the (possibly jittered) times from the latest sweep in which it was non-empty. */
    double[] meanTimes;
    /** Per topic: variance of the (possibly jittered) times from the latest sweep in which it was non-empty. */
    double[] varianceTimes;
    /** Ids passed to {@link #init}; stored but not read in this class. */
    int[] uniqueIds;
    /** Per-topic Beta distribution over document time used while sampling. */
    BetaDistribution[] pBeta = null;
    /** Per-topic Beta shape parameters: {@code pi[k][0]} is alpha, {@code pi[k][1]} is beta. */
    double[][] pi;
    /** Per topic: time value -> number of tokens with that time assigned to the topic in the latest sweep. */
    TDoubleIntHashMap[] hsValues = null;

    /**
     * Initializes the sampler state and assigns every token a uniformly random topic.
     *
     * @param docIds    document id of each token
     * @param wordIds   word id of each token; its length defines the number of tokens
     * @param ts        time value per document, indexed by document id
     * @param ids       ids stored in {@link #uniqueIds}; not otherwise used here
     * @param numTopics number of topics
     * @param numWords  vocabulary size
     * @param numDocs   number of documents
     * @param iter      number of sweeps performed by {@link #GibbsSampling()}
     * @param beta      smoothing constant added to every word-topic count
     * @param alpha     smoothing constant added to every document-topic count
     * @param locSeed   if true, the random generator is seeded with {@code seed}
     * @param seed      seed used when {@code locSeed} is true
     */
    public void init(int[] docIds, int[] wordIds, double[] ts, int[] ids, int numTopics, int numWords, int numDocs, int iter, double beta, double alpha, boolean locSeed, int seed) {
        this.maxIter = iter;
        this.BETA = beta;
        this.ALPHA = alpha;
        this.numTokens = wordIds.length;
        this.numTopics = numTopics;
        this.numDocs = numDocs;
        this.numWords = numWords;
        this.topics = new int[this.numTokens];
        this.wordtopiccounts = new int[numWords * numTopics];
        this.doctopiccounts = new int[numDocs * numTopics];
        this.topiccounts = new int[numTopics];
        this.words = wordIds;
        this.docs = docIds;
        if (locSeed) {
            this.seed = seed;
            this.rn = new Random(seed);
        } else {
            this.rn = new Random();
        }
        for (int token = 0; token < this.numTokens; ++token) {
            int topic = this.rn.nextInt(numTopics);
            this.topics[token] = topic;
            ++this.wordtopiccounts[this.words[token] * numTopics + topic];
            ++this.doctopiccounts[this.docs[token] * numTopics + topic];
            ++this.topiccounts[topic];
        }
        this.times = ts;
        this.pi = new double[numTopics][2];
        this.pBeta = new BetaDistribution[numTopics];
        this.meanTimes = new double[numTopics];
        this.varianceTimes = new double[numTopics];
        for (int k = 0; k < numTopics; ++k) {
            this.setUniform(k);
        }
        this.uniqueIds = ids;
        // Drop assigned-time statistics from a previous run; they may have a different topic count.
        this.hsValues = null;
    }

    /**
     * Runs {@code maxIter} Gibbs sweeps over all tokens, visiting them in one shuffled order that
     * is fixed for the whole run.
     *
     * <p>Each sweep resamples every token's topic and records the document time of every token
     * under its new topic. It then refits each topic's Beta time distribution. From iteration
     * {@code (int) (0.9 * maxIter)} onward, {@code updateDistributions()} is called on every even
     * iteration.
     */
    @Override
    public void GibbsSampling() {
        int burnIn = (int) ((double) this.maxIter * BURN_IN_FRACTION);
        this.WBETA = (double) this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        TIntArrayList order = new TIntArrayList(this.numTokens);
        for (int token = 0; token < this.numTokens; ++token) {
            order.add(token);
        }
        order.shuffle(this.rn);
        this.hsValues = new TDoubleIntHashMap[this.numTopics];
        for (int iter = 0; iter < this.maxIter; ++iter) {
            System.out.println("Current Gibbs Sampler iteration: " + iter);
            Mean[] currmeans = new Mean[this.numTopics];
            Variance[] currvars = new Variance[this.numTopics];
            for (int k = 0; k < this.numTopics; ++k) {
                this.hsValues[k] = new TDoubleIntHashMap();
                currmeans[k] = new Mean();
                currvars[k] = new Variance();
            }
            for (int pos = 0; pos < this.numTokens; ++pos) {
                int token = order.get(pos);
                int topic = this.resampleToken(token);
                double time = this.times[this.docs[token]];
                currmeans[topic].increment(time);
                currvars[topic].increment(time);
                this.hsValues[topic].adjustOrPutValue(time, 1, 1);
                // NOTE(review): indexed by position in the shuffled order, not by token index
                // (this.topics[token] holds the per-token assignment) — confirm readers expect this.
                this.tokenToTopic[pos] = topic;
            }
            // Defensive: each decrement in resampleToken is paired with a prior increment, so
            // counts should not go negative after init(); clamp in case they do.
            clampNegative(this.wordtopiccounts, this.numWords * this.numTopics);
            clampNegative(this.doctopiccounts, this.numDocs * this.numTopics);
            for (int k = 0; k < this.numTopics; ++k) {
                this.refitTimeDistribution(k, currmeans[k], currvars[k]);
            }
            if (iter >= burnIn && iter % 2 == 0) {
                this.updateDistributions();
            }
        }
    }

    /**
     * Removes {@code token} from its current topic and draws a new topic from the time-weighted
     * conditional. It then adds the token back to the counts under that topic.
     *
     * @param token index of the token to resample
     * @return the newly assigned topic
     */
    private int resampleToken(int token) {
        int docId = this.docs[token];
        int wordOffset = this.words[token] * this.numTopics;
        int docOffset = docId * this.numTopics;
        int oldTopic = this.topics[token];
        --this.topiccounts[oldTopic];
        --this.wordtopiccounts[wordOffset + oldTopic];
        --this.doctopiccounts[docOffset + oldTopic];

        double time = this.times[docId];
        double total = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            double wordTerm = ((double) this.wordtopiccounts[wordOffset + k] + this.BETA)
                    / ((double) this.topiccounts[k] + this.WBETA);
            double docTerm = (double) this.doctopiccounts[docOffset + k] + this.ALPHA;
            this.probs[k] = wordTerm * docTerm * this.pBeta[k].density(time);
            total += this.probs[k];
        }

        // Inverse-CDF draw over the unnormalized weights; the bound on newTopic keeps
        // floating-point rounding from walking past the last topic.
        double r = total * this.rn.nextDouble();
        int newTopic = 0;
        double cumulative = this.probs[0];
        while (r > cumulative && newTopic < this.numTopics - 1) {
            cumulative += this.probs[++newTopic];
        }

        this.topics[token] = newTopic;
        ++this.wordtopiccounts[wordOffset + newTopic];
        ++this.doctopiccounts[docOffset + newTopic];
        ++this.topiccounts[newTopic];
        return newTopic;
    }

    /**
     * Re-estimates topic {@code k}'s Beta time distribution by the method of moments from the
     * times collected for it during the latest sweep.
     *
     * <p>A topic that received exactly one time gets {@value #SINGLETON_JITTER_POINTS} extra
     * jittered copies of it. A topic whose time variance is at most
     * {@value #LOW_VARIANCE_THRESHOLD} has all its times replaced by jittered copies of the mean.
     * An empty topic falls back to the uniform Beta(1, 1), as does a topic whose moments do not
     * yield positive finite shape parameters.
     *
     * @param k        topic index
     * @param mean     accumulator of the times assigned to topic {@code k}; may be modified
     * @param variance accumulator of the same times; may be modified
     */
    private void refitTimeDistribution(int k, Mean mean, Variance variance) {
        if (mean.getN() == 1L) {
            double point = mean.getResult();
            for (int p = 0; p < SINGLETON_JITTER_POINTS; ++p) {
                double jittered = this.jitterIntoUnitInterval(point);
                mean.increment(jittered);
                variance.increment(jittered);
            }
        } else if (variance.getResult() <= LOW_VARIANCE_THRESHOLD) {
            double point = mean.getResult();
            long n = mean.getN();
            mean.clear();
            variance.clear();
            for (long p = 0; p < n; ++p) {
                double jittered = this.jitterIntoUnitInterval(point);
                mean.increment(jittered);
                variance.increment(jittered);
            }
        }
        if (mean.getN() < 1L) {
            // Empty topic: reset both pi and pBeta so getPi() matches the distribution used
            // for sampling.
            this.setUniform(k);
            return;
        }
        double x = mean.getResult();
        double v = variance.getResult();
        this.meanTimes[k] = x;
        this.varianceTimes[k] = v;
        // Method of moments for Beta(a, b): a = x*c, b = (1-x)*c with c = x(1-x)/v - 1.
        // c is positive only when v < x(1-x). The bias-corrected sample variance computed by
        // commons-math Variance can exceed that bound (e.g. times {0, 1}), and a mean outside
        // (0, 1) also gives non-positive shapes, so fall back to uniform in those cases.
        double c = x * (1.0 - x) / v - 1.0;
        double a = x * c;
        double b = (1.0 - x) * c;
        if (isValidShape(a) && isValidShape(b)) {
            this.pi[k][0] = a;
            this.pi[k][1] = b;
            this.pBeta[k] = new BetaDistribution(a, b);
        } else {
            this.setUniform(k);
        }
    }

    /** Returns true when {@code s} is a finite, strictly positive Beta shape parameter. */
    private static boolean isValidShape(double s) {
        // NaN fails the comparison, so it is rejected as well.
        return s > 0.0 && !Double.isInfinite(s);
    }

    /** Resets topic {@code k} to the uniform Beta(1, 1) time distribution. */
    private void setUniform(int k) {
        this.pi[k][0] = 1.0;
        this.pi[k][1] = 1.0;
        this.pBeta[k] = new BetaDistribution(1.0, 1.0);
    }

    /**
     * Returns {@code point} plus Gaussian noise with standard deviation {@value #JITTER_SD}. The
     * noise is redrawn until the result lies strictly inside (0, 1).
     *
     * <p>If all {@value #MAX_JITTER_ATTEMPTS} draws fall outside, the point is clamped into (0, 1)
     * instead so the method always terminates. This is only plausible when {@code point} lies
     * well outside [0, 1].
     */
    private double jitterIntoUnitInterval(double point) {
        for (int attempt = 0; attempt < MAX_JITTER_ATTEMPTS; ++attempt) {
            double candidate = this.rn.nextGaussian() * JITTER_SD + point;
            if (candidate > 0.0 && candidate < 1.0) {
                return candidate;
            }
        }
        return Math.min(Math.max(point, CLAMP_MARGIN), 1.0 - CLAMP_MARGIN);
    }

    /** Sets every negative entry among the first {@code length} elements of {@code counts} to zero. */
    private static void clampNegative(int[] counts, int length) {
        for (int i = 0; i < length; ++i) {
            if (counts[i] < 0) {
                counts[i] = 0;
            }
        }
    }

    /**
     * Returns, for each topic, the document times of the tokens assigned to it in the latest
     * sweep. Each time is repeated once per token that carried it.
     *
     * @return one list per topic; every list is empty if {@link #GibbsSampling()} has not run
     *     since the last {@link #init}
     */
    public TDoubleArrayList[] getAssignedTimes() {
        TDoubleArrayList[] res = new TDoubleArrayList[this.numTopics];
        for (int k = 0; k < this.numTopics; ++k) {
            TDoubleArrayList ls = new TDoubleArrayList();
            TDoubleIntHashMap counts = this.hsValues == null ? null : this.hsValues[k];
            if (counts != null) {
                for (double time : counts.keys()) {
                    int multiplicity = counts.get(time);
                    for (int copy = 0; copy < multiplicity; ++copy) {
                        ls.add(time);
                    }
                }
            }
            res[k] = ls;
        }
        return res;
    }

    /**
     * Returns the per-topic Beta shape parameters ({@code [k][0]} = alpha, {@code [k][1]} = beta).
     * The internal array is returned, not a copy.
     */
    public double[][] getPi() {
        return this.pi;
    }

    /** No-op entry point, retained for compatibility. */
    public static void main(String[] args) {
    }
}

