/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.kobra.topicmodels;

import cc.mallet.optimize.OptimizationException;
import com.rapidminer.kobra.opt.MyOrthantWiseLimitedMemoryBFGS;
import com.rapidminer.kobra.topicmodels.MyHashGompertzOptimizable;
import com.rapidminer.kobra.topicmodels.SamplersGTLDA;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TDoubleIntHashMap;
import gnu.trove.set.hash.TDoubleHashSet;

public class SamplersGBTLDA
extends SamplersGTLDA {
    int[] type = null;
    TDoubleIntHashMap[] hsValues2 = null;
    public boolean emp = false;

    @Override
    public void GibbsSampling() {
        int di;
        TIntArrayList t = new TIntArrayList();
        for (int i = 0; i < this.numTopics; ++i) {
            if (i % 2 == 0) {
                t.add(0);
                continue;
            }
            t.add(0);
        }
        this.type = t.toArray();
        int burnIn = (int)((double)this.maxIter * 0.9);
        int topic = 0;
        this.WBETA = (double)this.numWords * this.BETA;
        this.probs = new double[this.numTopics];
        this.tokenToTopic = new int[this.numTokens];
        TIntArrayList ll = new TIntArrayList(this.numTokens);
        for (int i = 0; i < this.numTokens; ++i) {
            ll.add(i);
        }
        ll.shuffle(this.rn);
        this.hsValues = new TDoubleIntHashMap[this.numTopics];
        this.hsValues2 = new TDoubleIntHashMap[this.numTopics];
        this.testStatistics = new double[this.maxIter][this.numTopics];
        TDoubleIntHashMap timeMap = new TDoubleIntHashMap();
        int sum = this.numTokens;
        for (int ii = 0; ii < this.numTokens; ++ii) {
            int i = ll.get(ii);
            di = this.docs[i];
            int n = 1;
            double time = this.times[di];
            if (timeMap.contains(time)) {
                n += timeMap.get(time);
            }
            timeMap.put(time, n);
        }
        for (int iter = 0; iter < this.maxIter; ++iter) {
            int i;
            int i2;
            System.out.println("Current Gibbs Sampler iteration: " + iter);
            for (i2 = 0; i2 < this.numTopics; ++i2) {
                this.hsValues[i2] = new TDoubleIntHashMap();
                this.hsValues2[i2] = new TDoubleIntHashMap();
            }
            for (int ii = 0; ii < this.numTokens; ++ii) {
                i = ll.get(ii);
                int wi = this.words[i];
                di = this.docs[i];
                int n = topic = this.topics[i];
                this.topiccounts[n] = this.topiccounts[n] - 1;
                int wioffset = wi * this.numTopics;
                int dioffset = di * this.numTopics;
                int n2 = wioffset + topic;
                this.wordtopiccounts[n2] = this.wordtopiccounts[n2] - 1;
                int n3 = dioffset + topic;
                this.doctopiccounts[n3] = this.doctopiccounts[n3] - 1;
                double totprob = 0.0;
                for (int j = 0; j < this.numTopics; ++j) {
                    this.probs[j] = this.type[j] == 0 ? ((double)this.wordtopiccounts[wioffset + j] + this.BETA) / ((double)this.topiccounts[j] + this.WBETA) * ((double)this.doctopiccounts[dioffset + j] + this.ALPHA) * 1.0 / (this.maxTime - this.minTime) : (this.type[j] == 1 ? ((double)this.wordtopiccounts[wioffset + j] + this.BETA) / ((double)this.topiccounts[j] + this.WBETA) * ((double)this.doctopiccounts[dioffset + j] + this.ALPHA) * (double)timeMap.get(this.times[di]) / (double)sum : ((double)this.wordtopiccounts[wioffset + j] + this.BETA) / ((double)this.topiccounts[j] + this.WBETA) * ((double)this.doctopiccounts[dioffset + j] + this.ALPHA) * this.dist(this.times[di], this.pGombertz[j][0], this.pGombertz[j][1]));
                    totprob += this.probs[j];
                }
                double r = totprob * this.rn.nextDouble();
                topic = 0;
                for (double max = this.probs[0]; r > max; max += this.probs[++topic]) {
                }
                this.topics[i] = topic;
                int n4 = wioffset + topic;
                this.wordtopiccounts[n4] = this.wordtopiccounts[n4] + 1;
                int n5 = dioffset + topic;
                this.doctopiccounts[n5] = this.doctopiccounts[n5] + 1;
                int n6 = topic;
                this.topiccounts[n6] = this.topiccounts[n6] + 1;
                int n7 = 1;
                if (this.hsValues[topic].contains(this.times[di])) {
                    n7 += this.hsValues[topic].get(this.times[di]);
                }
                this.hsValues[topic].put(this.times[di], n7);
                this.tokenToTopic[ii] = topic;
            }
            for (i2 = 0; i2 < this.numWords * this.numTopics; ++i2) {
                if (this.wordtopiccounts[i2] >= 0) continue;
                this.wordtopiccounts[i2] = 0;
            }
            for (i2 = 0; i2 < this.numDocs * this.numTopics; ++i2) {
                if (this.doctopiccounts[i2] >= 0) continue;
                this.doctopiccounts[i2] = 0;
            }
            TDoubleHashSet set = new TDoubleHashSet();
            for (i = 0; i < this.numTopics; ++i) {
                set.addAll(this.hsValues[i].keys());
            }
            for (i = 0; i < this.numTopics; ++i) {
                MyHashGompertzOptimizable opt = new MyHashGompertzOptimizable();
                opt.vals = this.hsValues[i];
                opt.alpha = this.rn.nextDouble();
                opt.beta = this.rn.nextDouble();
                opt.a = this.rn.nextDouble();
                opt.b = this.rn.nextDouble();
                MyOrthantWiseLimitedMemoryBFGS optimizer = new MyOrthantWiseLimitedMemoryBFGS(opt);
                boolean converged = false;
                double[] paras = new double[2];
                try {
                    converged = optimizer.optimize(100);
                }
                catch (OptimizationException e) {
                    e.printStackTrace();
                }
                opt.getParameters(paras);
                this.pGombertz[i][0] = Math.exp(paras[0]);
                this.pGombertz[i][1] = Math.exp(paras[1]);
                double optGomp = opt.getValue2();
                int nv = 0;
                int n = 0;
                double p = 1.0 / (double)set.size();
                double lL = 0.0;
                for (double k : this.hsValues[i].keys()) {
                    nv = this.hsValues[i].get(k);
                    lL += (double)nv * Math.log((double)timeMap.get(k) / (double)sum);
                    n += nv;
                }
                if (!this.emp) {
                    lL = (double)(-n) * Math.log(this.maxTime - this.minTime);
                }
                System.out.println("Gom: " + optGomp + " vs. Uni: " + lL);
                double AICg = -2.0 * optGomp + 4.0;
                double AICu = -2.0 * lL;
                double AICmin = 0.0;
                double AICmax = 0.0;
                if (AICg < AICu) {
                    AICmin = AICg;
                    AICmax = AICu;
                    this.testStatistics[iter][i] = 1.0 - Math.exp((AICmin - AICmax) / 2.0);
                } else {
                    AICmin = AICu;
                    AICmax = AICg;
                    this.testStatistics[iter][i] = Math.exp((AICmin - AICmax) / 2.0);
                }
                double ratio = Math.exp((AICmin - AICmax) / 2.0);
                System.out.println(i + ": Gom: " + optGomp + " vs. Uni: " + lL + " AIC ratio " + ratio);
                System.out.println(i + " AIC 1 " + AICmin + " AIC 2 " + AICmax);
                if (lL > optGomp) {
                    this.pGombertz[i][0] = 0.0;
                    this.pGombertz[i][1] = 0.0;
                    if (this.emp) {
                        this.type[i] = 1;
                        continue;
                    }
                    this.type[i] = 0;
                    continue;
                }
                this.type[i] = 2;
            }
            if (iter < burnIn || iter % 2 != 0) continue;
            this.updateDistributions();
        }
    }

    @Override
    public TDoubleArrayList[] getAssignedTimes() {
        TDoubleArrayList[] res = new TDoubleArrayList[this.numTopics];
        for (int i = 0; i < this.numTopics; ++i) {
            TDoubleArrayList ls = new TDoubleArrayList();
            int nv = 0;
            for (double k : this.hsValues[i].keys()) {
                nv = this.hsValues[i].get(k);
                for (int kn = 0; kn < nv; ++kn) {
                    ls.add(k);
                }
            }
            res[i] = ls;
        }
        return res;
    }

    private static long binomial(int n, int k) {
        if (k > n - k) {
            k = n - k;
        }
        long b = 1L;
        int i = 1;
        int m = n;
        while (i <= k) {
            b = b * (long)m / (long)i;
            ++i;
            --m;
        }
        return b;
    }

    public static void main(String[] args) {
    }
}

